diff --git a/egs/ami/s5b/conf/segmentation_speech.conf b/egs/ami/s5b/conf/segmentation_speech.conf new file mode 100644 index 00000000000..c4c75b212fc --- /dev/null +++ b/egs/ami/s5b/conf/segmentation_speech.conf @@ -0,0 +1,14 @@ +# General segmentation options +pad_length=20 # Pad speech segments by this many frames on either side +max_relabel_length=10 # Maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +max_intersegment_length=30 # Merge nearby speech segments if the silence + # between them is less than this many frames. +post_pad_length=10 # Pad speech segments by this many frames on either side + # after the merging process using max_intersegment_length +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=250 # Overlapping frames when segments are split. + # See the above option. +min_silence_length=20 # Min silence length at which to split very long segments diff --git a/egs/ami/s5b/local/ami_normalize_transcripts.pl b/egs/ami/s5b/local/ami_normalize_transcripts.pl new file mode 100644 index 00000000000..772e8b50fec --- /dev/null +++ b/egs/ami/s5b/local/ami_normalize_transcripts.pl @@ -0,0 +1,129 @@ +#!/usr/bin/env perl + +# Copyright 2014 University of Edinburgh (Author: Pawel Swietojanski) +# 2016 Vimal Manohar + +# The script - based on punctuation times - splits segments longer than #words (input parameter) +# and produces bit more more normalised form of transcripts, as follows +# MeetID Channel Spkr stime etime transcripts + +#use List::MoreUtils 'indexes'; +use strict; +use warnings; + +sub normalise_transcripts; + +sub merge_hashes { + my ($h1, $h2) = @_; + my %hash1 = %$h1; my %hash2 = %$h2; + foreach my $key2 ( keys %hash2 ) { + if( exists $hash1{$key2} ) { + warn "Key [$key2] is in both hashes!"; + next; + } else { + $hash1{$key2} = $hash2{$key2}; + } + } + return %hash1; +} + +sub 
print_hash { + my ($h) = @_; + my %hash = %$h; + foreach my $k (sort keys %hash) { + print "$k : $hash{$k}\n"; + } +} + +sub get_name { + #no warnings; + my $sname = sprintf("%07d_%07d", $_[0]*100, $_[1]*100) || die 'Input undefined!'; + #use warnings; + return $sname; +} + +sub split_on_comma { + + my ($text, $comma_times, $btime, $etime, $max_words_per_seg)= @_; + my %comma_hash = %$comma_times; + + print "Btime, Etime : $btime, $etime\n"; + + my $stime = ($etime+$btime)/2; #split time + my $skey = ""; + my $otime = $btime; + foreach my $k (sort {$comma_hash{$a} cmp $comma_hash{$b} } keys %comma_hash) { + print "Key : $k : $comma_hash{$k}\n"; + my $ktime = $comma_hash{$k}; + if ($ktime==$btime) { next; } + if ($ktime==$etime) { last; } + if (abs($stime-$ktime)/20) { + $st=$comma_hash{$skey}; + $et = $etime; + } + my (@utts) = split (' ', $utts1[$i]); + if ($#utts < $max_words_per_seg) { + my $nm = get_name($st, $et); + print "SplittedOnComma[$i]: $nm : $utts1[$i]\n"; + $transcripts{$nm} = $utts1[$i]; + } else { + print 'Continue splitting!'; + my %transcripts2 = split_on_comma($utts1[$i], \%comma_hash, $st, $et, $max_words_per_seg); + %transcripts = merge_hashes(\%transcripts, \%transcripts2); + } + } + return %transcripts; +} + +sub normalise_transcripts { + my $text = $_; + + #DO SOME ROUGH AND OBVIOUS PRELIMINARY NORMALISATION, AS FOLLOWS + #remove the remaining punctation labels e.g. some text ,0 some text ,1 + $text =~ s/[\.\,\?\!\:][0-9]+//g; + #there are some extra spurious puncations without spaces, e.g. UM,I, replace with space + $text =~ s/[A-Z']+,[A-Z']+/ /g; + #split words combination, ie. ANTI-TRUST to ANTI TRUST (None of them appears in cmudict anyway) + #$text =~ s/(.*)([A-Z])\s+(\-)(.*)/$1$2$3$4/g; + $text =~ s/\-/ /g; + #substitute X_M_L with X. M. L. etc. + $text =~ s/\_/. 
/g; + #normalise and trim spaces + $text =~ s/^\s*//g; + $text =~ s/\s*$//g; + $text =~ s/\s+/ /g; + #some transcripts are empty with -, nullify (and ignore) them + $text =~ s/^\-$//g; + $text =~ s/\s+\-$//; + # apply few exception for dashed phrases, Mm-Hmm, Uh-Huh, etc. those are frequent in AMI + # and will be added to dictionary + $text =~ s/MM HMM/MM\-HMM/g; + $text =~ s/UH HUH/UH\-HUH/g; + + return $text; +} + +while(<>) { + chomp; + print normalise_transcripts($_) . "\n"; +} + diff --git a/egs/ami/s5b/local/chain/run_decode.sh b/egs/ami/s5b/local/chain/run_decode.sh new file mode 100755 index 00000000000..545bdc7b157 --- /dev/null +++ b/egs/ami/s5b/local/chain/run_decode.sh @@ -0,0 +1,131 @@ +#!/bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +set -e +set -o pipefail +set -u + +stage=-1 +decode_stage=1 + +mic=ihm +use_ihm_ali=false +exp_name=tdnn + +nj=20 + +cleanup_affix= +graph_dir= + +decode_set=dev +decode_suffix= + +extractor= +use_ivectors=true +use_offline_ivectors=false +frames_per_chunk=50 + +scoring_opts= + +. path.sh +. cmd.sh + +. parse_options.sh + +new_mic=$mic +if [ $use_ihm_ali == "true" ]; then + new_mic=${mic}_cleanali +fi + +dir=exp/$new_mic/chain${cleanup_affix:+_$cleanup_affix}/${exp_name} + +if [ $stage -le -1 ]; then + mfccdir=mfcc_${mic} + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/ami-$mic-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + + steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc.conf \ + --cmd "$train_cmd" data/$mic/${decode_set} exp/make_${mic}/$decode_set $mfccdir || exit 1; + + steps/compute_cmvn_stats.sh data/$mic/${decode_set} exp/make_${mic}/$mic/$decode_set $mfccdir || exit 1; + + utils/fix_data_dir.sh data/$mic/${decode_set} +fi + +if [ $stage -le 0 ]; then + mfccdir=mfcc_${mic}_hires + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $mfccdir/storage ]; then + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/ami-$mic-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + + utils/copy_data_dir.sh data/$mic/$decode_set data/$mic/${decode_set}_hires + + steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/$mic/${decode_set}_hires exp/make_${mic}_hires/$decode_set $mfccdir || exit 1; + + steps/compute_cmvn_stats.sh data/$mic/${decode_set}_hires exp/make_${mic}_hires/$mic/$decode_set $mfccdir || exit 1; + + utils/fix_data_dir.sh data/$mic/${decode_set}_hires +fi + +if $use_ivectors && [ $stage -le 1 ]; then + if [ -z "$extractor" ]; then + echo "--extractor must be supplied when using ivectors" + exit 1 + fi + + if $use_offline_ivectors; then + steps/online/nnet2/extract_ivectors.sh \ + --cmd "$train_cmd" --nj 8 \ + data/$mic/${decode_set}_hires data/lang $extractor \ + exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_offline_${decode_set} || exit 1 + else + steps/online/nnet2/extract_ivectors_online.sh \ + --cmd "$train_cmd" --nj 8 \ + data/$mic/${decode_set}_hires $extractor \ + exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_${decode_set} || exit 1 + fi +fi + +final_lm=`cat data/local/lm/final_lm` +LM=$final_lm.pr1-7 + +if [ -z "$graph_dir" ]; then + graph_dir=$dir/graph_${LM} + if [ $stage -le 2 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. 
+ utils/mkgraph.sh --self-loop-scale 1.0 data/lang_${LM} $dir $graph_dir + fi +fi + +nj=`cat data/$mic/${decode_set}/utt2spk|cut -d' ' -f2|sort -u|wc -l` + +if [ $nj -gt 50 ]; then + nj=50 +fi + +if [ "$frames_per_chunk" -ne 50 ]; then + decode_suffix=${decode_suffix}_cs${frames_per_chunk} +fi + +if [ $stage -le 3 ]; then + ivector_opts= + if $use_ivectors; then + if $use_offline_ivectors; then + ivector_opts="--online-ivector-dir exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_offline_${decode_set}" + decode_suffix=${decode_suffix}_offline + else + ivector_opts="--online-ivector-dir exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_${decode_set}" + fi + fi + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --stage $decode_stage --frames-per-chunk $frames_per_chunk \ + --nj $nj --cmd "$decode_cmd" $ivector_opts \ + --scoring-opts "--min-lmwt 5 --decode-mbr false $scoring_opts" \ + $graph_dir data/$mic/${decode_set}_hires $dir/decode${decode_suffix}_${decode_set} || exit 1; +fi diff --git a/egs/ami/s5b/local/chain/run_decode_two_stage.sh b/egs/ami/s5b/local/chain/run_decode_two_stage.sh new file mode 100755 index 00000000000..0d354bfa574 --- /dev/null +++ b/egs/ami/s5b/local/chain/run_decode_two_stage.sh @@ -0,0 +1,135 @@ +#!/bin/bash + +set -e -u +set -o pipefail + +stage=-1 +decode_stage=1 + +mic=ihm +use_ihm_ali=false +exp_name=tdnn + +cleanup_affix= + +decode_set=dev +extractor= +use_ivectors=true +scoring_opts= +lmwt=8 +pad_frames=10 + +. path.sh +. cmd.sh + +. parse_options.sh + +new_mic=$mic +if [ $use_ihm_ali == "true" ]; then + new_mic=${mic}_cleanali +fi + +dir=exp/$new_mic/chain${cleanup_affix:+_$cleanup_affix}/${exp_name} + +nj=20 + +if [ $stage -le -1 ]; then + mfccdir=mfcc_${mic} + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $mfccdir/storage ]; then + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/ami-$mic-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + + steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc.conf \ + --cmd "$train_cmd" data/$mic/${decode_set} exp/make_${mic}/$decode_set $mfccdir || exit 1; + + steps/compute_cmvn_stats.sh data/$mic/${decode_set} exp/make_${mic}/$mic/$decode_set $mfccdir || exit 1; + + utils/fix_data_dir.sh data/$mic/${decode_set} +fi + +utils/data/get_utt2dur.sh data/$mic/${decode_set} + +if [ $stage -le 0 ]; then + mfccdir=mfcc_${mic}_hires + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/ami-$mic-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + + utils/copy_data_dir.sh data/$mic/$decode_set data/$mic/${decode_set}_hires + + steps/make_mfcc.sh --nj $nj --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/$mic/${decode_set}_hires exp/make_${mic}_hires/$decode_set $mfccdir || exit 1; + + steps/compute_cmvn_stats.sh data/$mic/${decode_set}_hires exp/make_${mic}_hires/$mic/$decode_set $mfccdir || exit 1; + + utils/fix_data_dir.sh data/$mic/${decode_set}_hires +fi + +if $use_ivectors && [ $stage -le 1 ]; then + if [ -z "$extractor" ]; then + echo "--extractor must be supplied when using ivectors"; exit 1 + fi + + steps/online/nnet2/extract_ivectors_online.sh \ + --cmd "$train_cmd" --nj 8 \ + data/$mic/${decode_set}_hires $extractor \ + exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_${decode_set} || exit 1 +fi + +final_lm=`cat data/local/lm/final_lm` +LM=$final_lm.pr1-7 +graph_dir=$dir/graph_${LM} +if [ $stage -le 2 ]; then + # Note: it might appear that this $lang directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. 
+ utils/mkgraph.sh --self-loop-scale 1.0 data/lang_${LM} $dir $graph_dir +fi + +nj=`cat data/$mic/${decode_set}/utt2spk|cut -d' ' -f2|sort -u|wc -l` + +if [ $nj -gt 50 ]; then + nj=50 +fi + +if [ $stage -le 3 ]; then + ivector_opts= + if $use_ivectors; then + ivector_opts="--online-ivector-dir exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_${decode_set}" + fi + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --stage $decode_stage \ + --nj $nj --cmd "$decode_cmd" $ivector_opts \ + --scoring-opts "--min-lmwt 5 $scoring_opts" \ + $graph_dir data/$mic/${decode_set}_hires $dir/decode_${decode_set} || exit 1; +fi + +ivector_weights=$dir/decode_${decode_set}/ascore_$lmwt/ivector_weights.gz + +if [ $stage -le 4 ]; then + cat $dir/decode_${decode_set}/ascore_$lmwt/${decode_set}_hires.utt.ctm | \ + grep -i -v -E '\[noise|laughter|vocalized-noise\]' | \ + local/get_ivector_weights_from_ctm_conf.pl \ + --pad-frames $pad_frames data/$mic/${decode_set}/utt2dur | \ + gzip -c > $ivector_weights +fi + +if [ $stage -le 5 ]; then + steps/online/nnet2/extract_ivectors_online.sh \ + --cmd "$train_cmd" --nj $nj --weights $ivector_weights \ + data/$mic/${decode_set}_hires $extractor \ + exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_${decode_set}_stage2 || exit 1 +fi + +if [ $stage -le 6 ]; then + ivector_opts= + if $use_ivectors; then + ivector_opts="--online-ivector-dir exp/$mic/nnet3${cleanup_affix:+_$cleanup_affix}/ivectors_${decode_set}_stage2" + fi + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --stage $decode_stage \ + --nj $nj --cmd "$decode_cmd" $ivector_opts \ + --scoring-opts "--min-lmwt 5 $scoring_opts" \ + $graph_dir data/$mic/${decode_set}_hires $dir/decode_${decode_set}_stage2 || exit 1; +fi + diff --git a/egs/ami/s5b/local/chain/run_tdnn_noivec.sh b/egs/ami/s5b/local/chain/run_tdnn_noivec.sh new file mode 100755 index 00000000000..d1329dc2bd1 --- /dev/null +++ b/egs/ami/s5b/local/chain/run_tdnn_noivec.sh @@ -0,0 +1,245 @@ 
+#!/bin/bash + +# This is a chain-training script with TDNN neural networks. +# Please see RESULTS_* for examples of command lines invoking this script. + + +# local/nnet3/run_tdnn.sh --stage 8 --use-ihm-ali true --mic sdm1 # rerunning with biphone +# local/nnet3/run_tdnn.sh --stage 8 --use-ihm-ali false --mic sdm1 + +# local/chain/run_tdnn.sh --use-ihm-ali true --mic sdm1 --train-set train --gmm tri3 --nnet3-affix "" --stage 12 & + +# local/chain/run_tdnn.sh --use-ihm-ali true --mic mdm8 --stage 12 & +# local/chain/run_tdnn.sh --use-ihm-ali true --mic mdm8 --train-set train --gmm tri3 --nnet3-affix "" --stage 12 & + +# local/chain/run_tdnn.sh --mic sdm1 --use-ihm-ali true --train-set train_cleaned --gmm tri3_cleaned& + + +set -e -o pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +mic=ihm +nj=30 +min_seg_len=1.55 +use_ihm_ali=false +train_set=train_cleaned +gmm=tri3_cleaned # the gmm for the target data +ihm_gmm=tri3 # the gmm for the IHM system (if --use-ihm-ali true). +num_threads_ubm=32 +nnet3_affix=_cleaned # cleanup affix for nnet3 and chain dirs, e.g. _cleaned + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +train_stage=-10 +tree_affix= # affix for tree directory, e.g. "a" or "b", in case we change the configuration. +tdnn_affix= #affix for TDNN directory, e.g. "a" or "b", in case we change the configuration. +common_egs_dir= # you can set this to use previously dumped egs. + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <data/lang_chain/topo + fi +fi + +if [ $stage -le 13 ]; then + # Get the alignments as lattices (gives the chain training more freedom). 
+ # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 14 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --leftmost-questions-truncate -1 \ + --cmd "$train_cmd" 4200 ${lores_train_data_dir} data/lang_chain $ali_dir $tree_dir +fi + +if [ $stage -le 15 ]; then + mkdir -p $dir + + echo "$0: creating neural net configs"; + + steps/nnet3/tdnn/make_configs.py \ + --self-repair-scale-nonlinearity 0.00001 \ + --feat-dir data/$mic/${train_set}_sp_hires_comb \ + --tree-dir $tree_dir \ + --relu-dim 450 \ + --splice-indexes "-1,0,1 -1,0,1,2 -3,0,3 -3,0,3 -3,0,3 -6,-3,0 0" \ + --use-presoftmax-prior-scale false \ + --xent-regularize 0.1 \ + --xent-separate-forward-affine true \ + --include-log-softmax false \ + --final-layer-normalize-target 1.0 \ + $dir/configs || exit 1; +fi + +if [ $stage -le 16 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{5,6,7,8}/$USER/kaldi-data/egs/ami-$(date +'%m_%d_%H_%M')/s5b/$dir/egs/storage $dir/egs/storage + fi + + touch $dir/egs/.nodelete # keep egs around when that run dies. 
+ + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$decode_cmd" \ + --feat.cmvn-opts "--norm-means=true --norm-vars=false" \ + --chain.xent-regularize 0.1 \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.00005 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --egs.dir "$common_egs_dir" \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width 150 \ + --trainer.num-chunk-per-minibatch 128 \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs 4 \ + --trainer.optimization.num-jobs-initial 2 \ + --trainer.optimization.num-jobs-final 12 \ + --trainer.optimization.initial-effective-lrate 0.001 \ + --trainer.optimization.final-effective-lrate 0.0001 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs true \ + --feat-dir $train_data_dir \ + --tree-dir $tree_dir \ + --lat-dir $lat_dir \ + --dir $dir +fi + + +graph_dir=$dir/graph_${LM} +if [ $stage -le 17 ]; then + # Note: it might appear that this data/lang_chain directory is mismatched, and it is as + # far as the 'topo' is concerned, but this script doesn't read the 'topo' from + # the lang directory. 
+ utils/mkgraph.sh --left-biphone --self-loop-scale 1.0 data/lang_${LM} $dir $graph_dir +fi + +if [ $stage -le 18 ]; then + rm $dir/.error 2>/dev/null || true + for decode_set in dev eval; do + ( + nj_dev=`cat data/$mic/${decode_set}_hires/spk2utt | wc -l` + if [ $nj_dev -gt 30 ]; then + nj_dev=30 + fi + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $nj_dev --cmd "$decode_cmd" \ + --scoring-opts "--min-lmwt 5 " \ + $graph_dir data/$mic/${decode_set}_hires $dir/decode_${decode_set} || exit 1; + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi +exit 0 + diff --git a/egs/ami/s5b/local/chain/tuning/run_tdnn_1a.sh b/egs/ami/s5b/local/chain/tuning/run_tdnn_1a.sh index 86587d6d830..88c09c2cb15 100755 --- a/egs/ami/s5b/local/chain/tuning/run_tdnn_1a.sh +++ b/egs/ami/s5b/local/chain/tuning/run_tdnn_1a.sh @@ -224,8 +224,12 @@ if [ $stage -le 18 ]; then rm $dir/.error 2>/dev/null || true for decode_set in dev eval; do ( + nj_dev=`cat data/$mic/${decode_set}_hires/spk2utt | wc -l` + if [ $nj_dev -gt $nj ]; then + nj_dev=$nj + fi steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ - --nj $nj --cmd "$decode_cmd" \ + --nj $nj_dev --cmd "$decode_cmd" \ --online-ivector-dir exp/$mic/nnet3${nnet3_affix}/ivectors_${decode_set}_hires \ --scoring-opts "--min-lmwt 5 " \ $graph_dir data/$mic/${decode_set}_hires $dir/decode_${decode_set} || exit 1; diff --git a/egs/ami/s5b/local/get_ivector_weights_from_ctm_conf.pl b/egs/ami/s5b/local/get_ivector_weights_from_ctm_conf.pl new file mode 100755 index 00000000000..96db9af3638 --- /dev/null +++ b/egs/ami/s5b/local/get_ivector_weights_from_ctm_conf.pl @@ -0,0 +1,77 @@ +#! 
/usr/bin/perl +use strict; +use warnings; +use Getopt::Long; + +my $pad_frames = 0; +my $silence_weight = 0.00001; +my $scale_weights_by_ctm_conf = "false"; +my $frame_shift = 0.01; + +GetOptions('pad-frames:i' => \$pad_frames, + 'silence-weight:f' => \$silence_weight, + 'scale-weights-by-ctm-conf:s' => \$scale_weights_by_ctm_conf, + 'frame-shift:f' => \$frame_shift); + +if (scalar @ARGV != 1) { + die "Usage: get_ivector_weights_from_ctm_conf.pl < > "; +} + +my $utt2dur = shift @ARGV; + +$pad_frames >= 0 || die "Bad pad-frames value $pad_frames; must be >= 0"; +($scale_weights_by_ctm_conf eq 'false') || ($scale_weights_by_ctm_conf eq 'true') || die "Bad scale-weights-by-ctm-conf $scale_weights_by_ctm_conf; must be true/false"; + +open(L, "<$utt2dur") || die "unable to open utt2dur file $utt2dur"; + +my @all_utts = (); +my %utt2weights; + +while () { + chomp; + my @A = split; + @A == 2 || die "Incorrect format of utt2dur file $_"; + my ($utt, $len) = @A; + + push @all_utts, $utt; + $len = int($len / $frame_shift); + + # Initialize weights for each utterance + my $weights = []; + for (my $n = 0; $n < $len; $n++) { + push @$weights, $silence_weight; + } + $utt2weights{$utt} = $weights; +} +close(L); + +while () { + chomp; + my @A = split; + @A == 6 || die "bad ctm line $_"; + + my $utt = $A[0]; + my $beg = $A[2]; + my $len = $A[3]; + my $beg_int = int($beg / $frame_shift) - $pad_frames; + my $len_int = int($len / $frame_shift) + 2*$pad_frames; + my $conf = $A[5]; + + my $array_ref = $utt2weights{$utt}; + defined $array_ref || die "No length info for utterance $utt"; + + for (my $t = $beg_int; $t < $beg_int + $len_int; $t++) { + if ($t >= 0 && $t < @$array_ref) { + if ($scale_weights_by_ctm_conf eq "false") { + ${$array_ref}[$t] = 1; + } else { + ${$array_ref}[$t] = $conf; + } + } + } +} + +foreach my $utt (keys %utt2weights) { + my $array_ref = $utt2weights{$utt}; + print ($utt, " [ ", join(" ", @$array_ref), " ]\n"); +} diff --git a/egs/ami/s5b/local/make_rt_2004_dev.pl 
b/egs/ami/s5b/local/make_rt_2004_dev.pl new file mode 120000 index 00000000000..a0d27619369 --- /dev/null +++ b/egs/ami/s5b/local/make_rt_2004_dev.pl @@ -0,0 +1 @@ +../../../rt/s5/local/make_rt_2004_dev.pl \ No newline at end of file diff --git a/egs/ami/s5b/local/make_rt_2004_eval.pl b/egs/ami/s5b/local/make_rt_2004_eval.pl new file mode 120000 index 00000000000..8b951f9c940 --- /dev/null +++ b/egs/ami/s5b/local/make_rt_2004_eval.pl @@ -0,0 +1 @@ +../../../rt/s5/local/make_rt_2004_eval.pl \ No newline at end of file diff --git a/egs/ami/s5b/local/make_rt_2005_eval.pl b/egs/ami/s5b/local/make_rt_2005_eval.pl new file mode 120000 index 00000000000..6185b83a5a3 --- /dev/null +++ b/egs/ami/s5b/local/make_rt_2005_eval.pl @@ -0,0 +1 @@ +../../../rt/s5/local/make_rt_2005_eval.pl \ No newline at end of file diff --git a/egs/ami/s5b/local/modify_stm.py b/egs/ami/s5b/local/modify_stm.py new file mode 100755 index 00000000000..52ab6fed1ef --- /dev/null +++ b/egs/ami/s5b/local/modify_stm.py @@ -0,0 +1,97 @@ +#! 
/usr/bin/env python + +import sys +import collections +import itertools +import argparse + +from collections import defaultdict + +def IgnoreWordList(stm_lines, wordlist): + for i in range(0, len(stm_lines)): + line = stm_lines[i] + splits = line.strip().split() + + line_changed = False + for j in range(5, len(splits)): + if str.lower(splits[j]) in wordlist: + splits[j] = "{{ {0} / @ }}".format(splits[j]) + line_changed = True + + + if line_changed: + stm_lines[i] = " ".join(splits) + +def IgnoreIsolatedWords(stm_lines): + for i in range(0, len(stm_lines)): + line = stm_lines[i] + splits = line.strip().split() + + assert( splits[5][0] != '<' ) + + if len(splits) == 6 and splits[5] != "IGNORE_TIME_SEGMENT_IN_SCORING": + splits.insert(5, "") + else: + splits.insert(5, "") + stm_lines[i] = " ".join(splits) + +def IgnoreBeginnings(stm_lines): + beg_times = defaultdict(itertools.repeat(float("inf")).next) + + lines_to_add = [] + for line in stm_lines: + splits = line.strip().split() + + beg_times[(splits[0],splits[1])] = min(beg_times[(splits[0],splits[1])], float(splits[3])) + + for t,v in beg_times.iteritems(): + lines_to_add.append("{0} {1} {0} 0.0 {2} IGNORE_TIME_SEGMENT_IN_SCORING".format(t[0], t[1], v)) + + stm_lines.extend(lines_to_add) + +def WriteStmLines(stm_lines): + for line in stm_lines: + print(line) + +def GetArgs(): + parser = argparse.ArgumentParser("This script modifies STM to remove certain words and segments from scoring. 
Use sort +0 -1 +1 -2 +3nb -4 while writing out.", + formatter_class = argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("--ignore-beginnings", + type = str, choices = ["true", "false"], + help = "Ignore beginnings of the recordings since " + "they are not transcribed") + parser.add_argument("--ignore-isolated-words", + type = str, choices = ["true", "false"], + help = "Remove isolated words from scoring " + "because they may be hard to recognize without " + "speaker diarization") + parser.add_argument("--ignore-word-list", + type = str, + help = "List of words to be ignored") + + args = parser.parse_args() + + return args + +def Main(): + args = GetArgs() + + stm_lines = [ x.strip() for x in sys.stdin.readlines() ] + + print (';; LABEL "NO_ISO", "No isolated words", "Ignoring isolated words"') + print (';; LABEL "ISO", "Isolated words", "isolated words"') + + #if args.ignore_word_list is not None: + # wordlist = {} + # for x in open(args.ignore_word_list).readlines(): + # wordlist[str.lower(x.strip())] = 1 + # IgnoreWordList(stm_lines, wordlist) + + IgnoreIsolatedWords(stm_lines) + IgnoreBeginnings(stm_lines) + + WriteStmLines(stm_lines) + +if __name__ == "__main__": + Main() diff --git a/egs/ami/s5b/local/nnet3/run_blstm.sh b/egs/ami/s5b/local/nnet3/run_blstm.sh index 776151fb5aa..e0e7bcfcdcf 100755 --- a/egs/ami/s5b/local/nnet3/run_blstm.sh +++ b/egs/ami/s5b/local/nnet3/run_blstm.sh @@ -7,6 +7,7 @@ remove_egs=true use_ihm_ali=false train_set=train_cleaned ihm_gmm=tri3 +gmm=tri3a_cleaned nnet3_affix=_cleaned # BLSTM params @@ -32,6 +33,7 @@ local/nnet3/run_lstm.sh --affix $affix \ --srand $srand \ --train-stage $train_stage \ --train-set $train_set \ + --gmm $gmm \ --ihm-gmm $ihm_gmm \ --nnet3-affix $nnet3_affix \ --lstm-delay " [-1,1] [-2,2] [-3,3] " \ @@ -49,4 +51,3 @@ local/nnet3/run_lstm.sh --affix $affix \ --num-epochs $num_epochs \ --use-ihm-ali $use_ihm_ali \ --remove-egs $remove_egs - diff --git a/egs/ami/s5b/local/nnet3/run_lstm.sh 
b/egs/ami/s5b/local/nnet3/run_lstm.sh index c5583e2d0ef..25254629933 100755 --- a/egs/ami/s5b/local/nnet3/run_lstm.sh +++ b/egs/ami/s5b/local/nnet3/run_lstm.sh @@ -225,9 +225,12 @@ if [ $stage -le 14 ]; then [ ! -z $decode_iter ] && model_opts=" --iter $decode_iter "; for decode_set in dev eval; do ( - num_jobs=`cat data/$mic/${decode_set}_hires/utt2spk|cut -d' ' -f2|sort -u|wc -l` + nj_dev=`cat data/$mic/${decode_set}_hires/spk2utt | wc -l` + if [ $nj_dev -gt $nj ]; then + nj_dev=$nj + fi decode_dir=${dir}/decode_${decode_set} - steps/nnet3/decode.sh --nj 250 --cmd "$decode_cmd" \ + steps/nnet3/decode.sh --nj $nj_dev --cmd "$decode_cmd" \ $model_opts \ --extra-left-context $extra_left_context \ --extra-right-context $extra_right_context \ diff --git a/egs/ami/s5b/local/nnet3/run_tdnn.sh b/egs/ami/s5b/local/nnet3/run_tdnn.sh index bbc6ed5c042..7b463f4ce57 100755 --- a/egs/ami/s5b/local/nnet3/run_tdnn.sh +++ b/egs/ami/s5b/local/nnet3/run_tdnn.sh @@ -45,10 +45,12 @@ tdnn_affix= #affix for TDNN directory e.g. "a" or "b", in case we change the co # Options which are not passed through to run_ivector_common.sh train_stage=-10 splice_indexes="-2,-1,0,1,2 -1,2 -3,3 -7,2 -3,3 0 0" -remove_egs=true +remove_egs=false relu_dim=850 num_epochs=3 +common_egs_dir= + . cmd.sh . ./path.sh . ./utils/parse_options.sh @@ -122,30 +124,55 @@ fi [ ! -f $ali_dir/ali.1.gz ] && echo "$0: expected $ali_dir/ali.1.gz to exist" && exit 1 if [ $stage -le 12 ]; then + steps/nnet3/tdnn/make_configs.py \ + --self-repair-scale-nonlinearity 0.00001 \ + --feat-dir $train_data_dir \ + --ivector-dir $train_ivector_dir \ + --ali-dir $ali_dir \ + --relu-dim $relu_dim \ + --splice-indexes "$splice_indexes" \ + --use-presoftmax-prior-scale true \ + --include-log-softmax true \ + --final-layer-normalize-target 1.0 \ + $dir/configs || exit 1; +fi + +if [ $stage -le 13 ]; then if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then utils/create_split_dir.pl \ /export/b0{3,4,5,6}/$USER/kaldi-data/egs/ami-$(date +'%m_%d_%H_%M')/s5b/$dir/egs/storage $dir/egs/storage fi - steps/nnet3/tdnn/train.sh --stage $train_stage \ - --num-epochs $num_epochs --num-jobs-initial 2 --num-jobs-final 12 \ - --splice-indexes "$splice_indexes" \ - --feat-type raw \ - --online-ivector-dir ${train_ivector_dir} \ - --cmvn-opts "--norm-means=false --norm-vars=false" \ - --initial-effective-lrate 0.0015 --final-effective-lrate 0.00015 \ + steps/nnet3/train_dnn.py --stage $train_stage \ --cmd "$decode_cmd" \ - --relu-dim "$relu_dim" \ - --remove-egs "$remove_egs" \ - $train_data_dir data/lang $ali_dir $dir + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --egs.dir "$common_egs_dir" \ + --trainer.samples-per-iter 400000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial 2 \ + --trainer.optimization.num-jobs-final 12 \ + --trainer.optimization.initial-effective-lrate 0.0015 \ + --trainer.optimization.final-effective-lrate 0.00015 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs "$remove_egs" \ + --cleanup true \ + --feat-dir $train_data_dir \ + --lang data/lang \ + --ali-dir $ali_dir \ + --dir $dir fi -if [ $stage -le 12 ]; then +if [ $stage -le 14 ]; then rm $dir/.error || true 2>/dev/null for decode_set in dev eval; do ( + nj_dev=`cat data/$mic/${decode_set}_hires/spk2utt | wc -l` + if [ $nj_dev -gt $nj ]; then + nj_dev=$nj + fi decode_dir=${dir}/decode_${decode_set} - steps/nnet3/decode.sh --nj $nj --cmd "$decode_cmd" \ + steps/nnet3/decode.sh --nj $nj_dev --cmd "$decode_cmd" \ --online-ivector-dir exp/$mic/nnet3${nnet3_affix}/ivectors_${decode_set}_hires \ $graph_dir data/$mic/${decode_set}_hires $decode_dir ) & diff --git a/egs/ami/s5b/local/prepare_parallel_train_data.sh b/egs/ami/s5b/local/prepare_parallel_train_data.sh index b049c906c3b..b551bacfb92 100755 --- 
a/egs/ami/s5b/local/prepare_parallel_train_data.sh +++ b/egs/ami/s5b/local/prepare_parallel_train_data.sh @@ -5,6 +5,10 @@ # but the wav data is copied from data/ihm. This is a little tricky because the # utterance ids are different between the different mics +train_set=train + +. utils/parse_options.sh + if [ $# != 1 ]; then echo "Usage: $0 [sdm1|mdm8]" @@ -18,12 +22,10 @@ if [ $mic == "ihm" ]; then exit 1; fi -train_set=train - . cmd.sh . ./path.sh -for f in data/ihm/train/utt2spk data/$mic/train/utt2spk; do +for f in data/ihm/${train_set}/utt2spk data/$mic/${train_set}/utt2spk; do if [ ! -f $f ]; then echo "$0: expected file $f to exist" exit 1 @@ -32,12 +34,12 @@ done set -e -o pipefail -mkdir -p data/$mic/train_ihmdata +mkdir -p data/$mic/${train_set}_ihmdata # the utterance-ids and speaker ids will be from the SDM or MDM data -cp data/$mic/train/{spk2utt,text,utt2spk} data/$mic/train_ihmdata/ +cp data/$mic/${train_set}/{spk2utt,text,utt2spk} data/$mic/${train_set}_ihmdata/ # the recording-ids will be from the IHM data. -cp data/ihm/train/{wav.scp,reco2file_and_channel} data/$mic/train_ihmdata/ +cp data/ihm/${train_set}/{wav.scp,reco2file_and_channel} data/$mic/${train_set}_ihmdata/ # map sdm/mdm segments to the ihm segments @@ -47,19 +49,17 @@ mic_base_upcase=$(echo $mic | sed 's/[0-9]//g' | tr 'a-z' 'A-Z') # It has lines like: # AMI_EN2001a_H02_FEO065_0021133_0021442 AMI_EN2001a_SDM_FEO065_0021133_0021442 -tmpdir=data/$mic/train_ihmdata/ +tmpdir=data/$mic/${train_set}_ihmdata/ -awk '{print $1, $1}' $tmpdir/ihmutt2utt # Map the 1st field of the segments file from the ihm data (the 1st field being # the utterance-id) to the corresponding SDM or MDM utterance-id. The other # fields remain the same (e.g. we want the recording-ids from the IHM data). 
-utils/apply_map.pl -f 1 $tmpdir/ihmutt2utt data/$mic/train_ihmdata/segments - -utils/fix_data_dir.sh data/$mic/train_ihmdata +utils/apply_map.pl -f 1 $tmpdir/ihmutt2utt data/$mic/${train_set}_ihmdata/segments -rm $tmpdir/ihmutt2utt +utils/fix_data_dir.sh data/$mic/${train_set}_ihmdata exit 0; diff --git a/egs/ami/s5b/local/run_cleanup_segmentation.sh b/egs/ami/s5b/local/run_cleanup_segmentation.sh index e2f0b0516ce..9a947ce1fce 100755 --- a/egs/ami/s5b/local/run_cleanup_segmentation.sh +++ b/egs/ami/s5b/local/run_cleanup_segmentation.sh @@ -129,7 +129,6 @@ fi final_lm=`cat data/local/lm/final_lm` LM=$final_lm.pr1-7 - if [ $stage -le 5 ]; then graph_dir=exp/$mic/${gmm}_${cleanup_affix}/graph_$LM nj_dev=$(cat data/$mic/dev/spk2utt | wc -l) @@ -137,9 +136,9 @@ if [ $stage -le 5 ]; then $decode_cmd $graph_dir/mkgraph.log \ utils/mkgraph.sh data/lang_$LM exp/$mic/${gmm}_${cleanup_affix} $graph_dir - steps/decode_fmllr.sh --nj $nj --cmd "$decode_cmd" --config conf/decode.conf \ + steps/decode_fmllr.sh --nj $nj_dev --cmd "$decode_cmd" --config conf/decode.conf \ $graph_dir data/$mic/dev exp/$mic/${gmm}_${cleanup_affix}/decode_dev_$LM - steps/decode_fmllr.sh --nj $nj --cmd "$decode_cmd" --config conf/decode.conf \ + steps/decode_fmllr.sh --nj $nj_eval --cmd "$decode_cmd" --config conf/decode.conf \ $graph_dir data/$mic/eval exp/$mic/${gmm}_${cleanup_affix}/decode_eval_$LM fi diff --git a/egs/ami/s5b/local/run_prepare_rt.sh b/egs/ami/s5b/local/run_prepare_rt.sh new file mode 120000 index 00000000000..e10f1d53a19 --- /dev/null +++ b/egs/ami/s5b/local/run_prepare_rt.sh @@ -0,0 +1 @@ +../../../rt/s5/local/run_prepare_rt.sh \ No newline at end of file diff --git a/egs/ami/s5b/local/run_train_raw_lstm.sh b/egs/ami/s5b/local/run_train_raw_lstm.sh new file mode 100755 index 00000000000..5c0431fe796 --- /dev/null +++ b/egs/ami/s5b/local/run_train_raw_lstm.sh @@ -0,0 +1,143 @@ +#!/bin/bash + +set -o pipefail +set -e +set -u + +. 
cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= + +# LSTM options +splice_indexes="-2,-1,0,1,2 0" +label_delay=0 +num_lstm_layers=2 +cell_dim=64 +hidden_dim=64 +recurrent_projection_dim=32 +non_recurrent_projection_dim=32 +chunk_width=40 +chunk_left_context=40 +lstm_delay="-1 -2" + +# training options +num_epochs=3 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +momentum=0.5 +num_chunk_per_minibatch=256 +samples_per_iter=20000 +remove_egs=false +max_param_change=1 + +num_utts_subset_valid=6 +num_utts_subset_train=6 + +use_dense_targets=false +extra_egs_copy_cmd="nnet3-copy-egs-overlap-detection ark:- ark:- |" + +# target options +train_data_dir=data/sdm1/train_whole_sp_hires_bp +targets_scp=exp/sdm1/overlap_speech_train_cleaned_sp/overlap_feats.scp +deriv_weights_scp=exp/sdm1/overlap_speech_train_cleaned_sp/deriv_weights.scp +egs_dir= +nj=40 +feat_type=raw +config_dir= +compute_objf_opts= + +mic=sdm1 +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_hidden_layers=`echo $splice_indexes | perl -ane 'print scalar @F'` || exit 1 +if [ -z "$dir" ]; then + dir=exp/$mic/nnet3_raw/nnet_lstm +fi + +dir=$dir${affix:+_$affix}_n${num_hidden_layers} +if [ $label_delay -gt 0 ]; then dir=${dir}_ld$label_delay; fi + + +if ! 
cuda-compiled; then + cat <' $dir/ascore_${LMWT}/${name}.JOB.ctm || touch $dir/.error; + if $decode_mbr; then + $cmd JOB=1:$nj $dir/ascoring/log/get_ctm.${LMWT}.JOB.log \ + mkdir -p $dir/ascore_${LMWT}/ '&&' \ + lattice-scale --inv-acoustic-scale=${LMWT} "ark:gunzip -c $dir/lat.JOB.gz|" ark:- \| \ + lattice-limit-depth ark:- ark:- \| \ + lattice-push --push-strings=false ark:- ark:- \| \ + lattice-align-words-lexicon --max-expand=10.0 \ + $lang/phones/align_lexicon.int $model ark:- ark:- \| \ + lattice-to-ctm-conf $frame_shift_opt --decode-mbr=$decode_mbr ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \ + '>' $dir/ascore_${LMWT}/${name}.JOB.ctm || touch $dir/.error; + else + $cmd JOB=1:$nj $dir/ascoring/log/get_ctm.${LMWT}.JOB.log \ + mkdir -p $dir/ascore_${LMWT}/ '&&' \ + lattice-scale --inv-acoustic-scale=${LMWT} "ark:gunzip -c $dir/lat.JOB.gz|" ark:- \| \ + lattice-limit-depth ark:- ark:- \| \ + lattice-1best ark:- ark:- \| \ + lattice-push --push-strings=false ark:- ark:- \| \ + lattice-align-words-lexicon --max-expand=10.0 \ + $lang/phones/align_lexicon.int $model ark:- ark:- \| \ + nbest-to-ctm $frame_shift_opt ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \ + '>' $dir/ascore_${LMWT}/${name}.JOB.ctm || touch $dir/.error; + fi + # Merge and clean, - for ((n=1; n<=nj; n++)); do cat $dir/ascore_${LMWT}/${name}.${n}.ctm; done > $dir/ascore_${LMWT}/${name}.ctm - rm -f $dir/ascore_${LMWT}/${name}.*.ctm + for ((n=1; n<=nj; n++)); do + cat $dir/ascore_${LMWT}/${name}.${n}.ctm; + rm -f $dir/ascore_${LMWT}/${name}.${n}.ctm + done > $dir/ascore_${LMWT}/${name}.utt.ctm )& done wait; [ -f $dir/.error ] && echo "$0: error during ctm generation. 
check $dir/ascoring/log/get_ctm.*.log" && exit 1; fi +if [ $stage -le 1 ]; then + for LMWT in $(seq $min_lmwt $max_lmwt); do + cat $dir/ascore_${LMWT}/${name}.utt.ctm | \ + $copy_ctm_script | utils/convert_ctm.pl $data/segments $data/reco2file_and_channel \ + > $dir/ascore_${LMWT}/${name}.ctm || exit 1 + done +fi + if [ $stage -le 1 ]; then # Remove some stuff we don't want to score, from the ctm. # - we remove hesitations here, otherwise the CTM would have a bug! # (confidences in place of the removed hesitations), - for x in $dir/ascore_*/${name}.ctm; do - cp $x $x.tmpf; + for LMWT in $(seq $min_lmwt $max_lmwt); do + x=$dir/ascore_${LMWT}/${name}.ctm + mv $x $x.tmpf; cat $x.tmpf | grep -i -v -E '\[noise|laughter|vocalized-noise\]' | \ grep -i -v -E ' (ACH|AH|EEE|EH|ER|EW|HA|HEE|HM|HMM|HUH|MM|OOF|UH|UM) ' | \ grep -i -v -E '<unk>' > $x; @@ -94,8 +126,9 @@ fi if [ $stage -le 2 ]; then if [ "$asclite" == "true" ]; then - oname=$name + oname=${name} [ ! -z $overlap_spk ] && oname=${name}_o$overlap_spk + oname=${oname}${stm_suffix} echo "asclite is starting" # Run scoring, meaning of hubscr.pl options: # -G .. produce alignment graphs, @@ -109,10 +142,10 @@ if [ $stage -le 2 ]; then # -V .. skip validation of input transcripts, # -h rt-stt ..
removes non-lexical items from CTM, $cmd LMWT=$min_lmwt:$max_lmwt $dir/ascoring/log/score.LMWT.log \ - cp $data/stm $dir/ascore_LMWT/ '&&' \ + cp $data/stm${stm_suffix} $dir/ascore_LMWT/ '&&' \ cp $dir/ascore_LMWT/${name}.ctm $dir/ascore_LMWT/${oname}.ctm '&&' \ $hubscr -G -v -m 1:2 -o$overlap_spk -a -C -B 8192 -p $hubdir -V -l english \ - -h rt-stt -g $data/glm -r $dir/ascore_LMWT/stm $dir/ascore_LMWT/${oname}.ctm || exit 1 + -h rt-stt -g $data/glm -r $dir/ascore_LMWT/stm${stm_suffix} $dir/ascore_LMWT/${oname}.ctm || exit 1 # Compress some scoring outputs : alignment info and graphs, echo -n "compressing asclite outputs " for LMWT in $(seq $min_lmwt $max_lmwt); do @@ -126,8 +159,8 @@ if [ $stage -le 2 ]; then echo done else $cmd LMWT=$min_lmwt:$max_lmwt $dir/ascoring/log/score.LMWT.log \ - cp $data/stm $dir/ascore_LMWT/ '&&' \ - $hubscr -p $hubdir -v -V -l english -h hub5 -g $data/glm -r $dir/ascore_LMWT/stm $dir/ascore_LMWT/${name}.ctm || exit 1 + cp $data/stm${stm_suffix} $dir/ascore_LMWT/ '&&' \ + $hubscr -p $hubdir -v -V -l english -h hub5 -g $data/glm -r $dir/ascore_LMWT/stm${stm_suffix} $dir/ascore_LMWT/${name}.ctm || exit 1 fi fi diff --git a/egs/ami/s5b/path.sh b/egs/ami/s5b/path.sh index ad2c93b309b..b4711d23926 100644 --- a/egs/ami/s5b/path.sh +++ b/egs/ami/s5b/path.sh @@ -9,5 +9,4 @@ LMBIN=$KALDI_ROOT/tools/irstlm/bin SRILM=$KALDI_ROOT/tools/srilm/bin/i686-m64 BEAMFORMIT=$KALDI_ROOT/tools/BeamformIt -export PATH=$PATH:$LMBIN:$BEAMFORMIT:$SRILM - +export PATH=$LMBIN:$BEAMFORMIT:$SRILM:$PATH diff --git a/egs/aspire/s5/conf/mfcc_hires_bp.conf b/egs/aspire/s5/conf/mfcc_hires_bp.conf new file mode 100644 index 00000000000..64292e8b489 --- /dev/null +++ b/egs/aspire/s5/conf/mfcc_hires_bp.conf @@ -0,0 +1,13 @@ +# config for high-resolution MFCC features, intended for neural network training.
+# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--sample-frequency=8000 # Switchboard is sampled at 8kHz +--num-mel-bins=28 +--num-ceps=28 +--cepstral-lifter=0 +--low-freq=330 # low cutoff frequency for mel bins +--high-freq=-1000 # high cutoff frequency, relative to Nyquist of 4000 (=3000) + + diff --git a/egs/aspire/s5/conf/segmentation_music.conf b/egs/aspire/s5/conf/segmentation_music.conf new file mode 100644 index 00000000000..28b5feaf5d5 --- /dev/null +++ b/egs/aspire/s5/conf/segmentation_music.conf @@ -0,0 +1,14 @@ +# General segmentation options +pad_length=-1 # Pad speech segments by this many frames on either side +max_blend_length=-1 # Maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +max_intersegment_length=0 # Merge nearby speech segments if the silence + # between them is less than this many frames. +post_pad_length=-1 # Pad speech segments by this many frames on either side + # after the merging process using max_intersegment_length +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=250 # Overlapping frames when segments are split. + # See the above option. +min_silence_length=100000 # Min silence length at which to split very long segments diff --git a/egs/aspire/s5/conf/segmentation_ovlp.conf b/egs/aspire/s5/conf/segmentation_ovlp.conf new file mode 100644 index 00000000000..28b5feaf5d5 --- /dev/null +++ b/egs/aspire/s5/conf/segmentation_ovlp.conf @@ -0,0 +1,14 @@ +# General segmentation options +pad_length=-1 # Pad speech segments by this many frames on either side +max_blend_length=-1 # Maximum duration of speech that will be removed as part + # of smoothing process.
This is only if there are no other + # speech segments nearby. +max_intersegment_length=0 # Merge nearby speech segments if the silence + # between them is less than this many frames. +post_pad_length=-1 # Pad speech segments by this many frames on either side + # after the merging process using max_intersegment_length +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=250 # Overlapping frames when segments are split. + # See the above option. +min_silence_length=100000 # Min silence length at which to split very long segments diff --git a/egs/aspire/s5/conf/segmentation_speech.conf b/egs/aspire/s5/conf/segmentation_speech.conf new file mode 100644 index 00000000000..c4c75b212fc --- /dev/null +++ b/egs/aspire/s5/conf/segmentation_speech.conf @@ -0,0 +1,14 @@ +# General segmentation options +pad_length=20 # Pad speech segments by this many frames on either side +max_relabel_length=10 # Maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +max_intersegment_length=30 # Merge nearby speech segments if the silence + # between them is less than this many frames. +post_pad_length=10 # Pad speech segments by this many frames on either side + # after the merging process using max_intersegment_length +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=250 # Overlapping frames when segments are split. + # See the above option. 
+min_silence_length=20 # Min silence length at which to split very long segments diff --git a/egs/aspire/s5/conf/segmentation_speech_simple.conf b/egs/aspire/s5/conf/segmentation_speech_simple.conf new file mode 100644 index 00000000000..56c178c8115 --- /dev/null +++ b/egs/aspire/s5/conf/segmentation_speech_simple.conf @@ -0,0 +1,14 @@ +# General segmentation options +pad_length=20 # Pad speech segments by this many frames on either side +max_relabel_length=-1 # Maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +max_intersegment_length=30 # Merge nearby speech segments if the silence + # between them is less than this many frames. +post_pad_length=-1 # Pad speech segments by this many frames on either side + # after the merging process using max_intersegment_length +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=250 # Overlapping frames when segments are split. + # See the above option. +min_silence_length=20 # Min silence length at which to split very long segments diff --git a/egs/aspire/s5/local/multi_condition/get_ctm.sh b/egs/aspire/s5/local/multi_condition/get_ctm.sh index f67a1191544..67c2c0bd87b 100755 --- a/egs/aspire/s5/local/multi_condition/get_ctm.sh +++ b/egs/aspire/s5/local/multi_condition/get_ctm.sh @@ -7,7 +7,7 @@ decode_mbr=true filter_ctm_command=cp glm= stm= -window=10 +resolve_overlaps=true overlap=5 [ -f ./path.sh ] && . ./path.sh . 
parse_options.sh || exit 1; @@ -62,7 +62,11 @@ lattice-align-words-lexicon --output-error-lats=true --output-if-empty=true --ma lattice-to-ctm-conf $frame_shift_opt --decode-mbr=$decode_mbr ark:- $decode_dir/score_$LMWT/penalty_$wip/ctm.overlapping || exit 1; # combine the segment-wise ctm files, while resolving overlaps -python local/multi_condition/resolve_ctm_overlaps.py --overlap $overlap --window-length $window $data_dir/utt2spk $decode_dir/score_$LMWT/penalty_$wip/ctm.overlapping $decode_dir/score_$LMWT/penalty_$wip/ctm.merged || exit 1; +if $resolve_overlaps; then + steps/resolve_ctm_overlaps.py $data_dir/segments $decode_dir/score_$LMWT/penalty_$wip/ctm.overlapping $decode_dir/score_$LMWT/penalty_$wip/ctm.merged || exit 1; +else + cp $decode_dir/score_$LMWT/penalty_$wip/ctm.overlapping $decode_dir/score_$LMWT/penalty_$wip/ctm.merged || exit 1; +fi merged_ctm=$decode_dir/score_$LMWT/penalty_$wip/ctm.merged cat $merged_ctm | utils/int2sym.pl -f 5 $lang/words.txt | \ diff --git a/egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh b/egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh new file mode 100755 index 00000000000..266781fc84d --- /dev/null +++ b/egs/aspire/s5/local/nnet3/prep_test_aspire_segmentation.sh @@ -0,0 +1,160 @@ +#!/bin/bash + +# Copyright Johns Hopkins University (Author: Daniel Povey, Vijayaditya Peddinti) 2016. Apache 2.0. +# This script generates the ctm files for dev_aspire, test_aspire and eval_aspire +# for scoring with ASpIRE scoring server. +# It also provides the WER for dev_aspire data. 
+ +set -e +set -o pipefail +set -u + +# general opts +iter=final +stage=0 +decode_num_jobs=30 +num_jobs=30 +affix= + +sad_iter=final + +# ivector opts +max_count=75 # parameter for extract_ivectors.sh +sub_speaker_frames=6000 +ivector_scale=0.75 +filter_ctm=true +weights_file= +silence_weight=0.00001 + +# decode opts +pass2_decode_opts="--min-active 1000" +lattice_beam=8 +extra_left_context=0 # change for (B)LSTM +extra_right_context=0 # change for BLSTM +frames_per_chunk=50 # change for (B)LSTM +acwt=0.1 # important to change this when using chain models +post_decode_acwt=1.0 # important to change this when using chain models + +. ./cmd.sh +[ -f ./path.sh ] && . ./path.sh +. utils/parse_options.sh || exit 1; + +if [ $# -ne 5 ]; then + echo "Usage: $0 [options] " + echo " Options:" + echo " --stage (0|1|2) # start scoring script from part-way through." + echo "e.g.:" + echo "$0 dev_aspire data/lang exp/tri5a/graph_pp exp/nnet3/tdnn" + exit 1; +fi + +data_set=$1 +sad_nnet_dir=$2 +lang=$3 # data/lang +graph=$4 #exp/tri5a/graph_pp +dir=$5 # exp/nnet3/tdnn + +model_affix=`basename $dir` +ivector_dir=exp/nnet3 +ivector_affix=${affix:+_$affix}_chain_${model_affix}_iter$iter +affix=_${affix}_iter${iter} +act_data_set=${data_set} # we will modify the data dir, when segmenting it + # so we will keep track of original data dirfor the glm and stm files + +if [[ "$data_set" =~ "test_aspire" ]]; then + out_file=single_dev_test${affix}_$model_affix.ctm +elif [[ "$data_set" =~ "eval_aspire" ]]; then + out_file=single_eval${affix}_$model_affix.ctm +elif [[ "$data_set" =~ "dev_aspire" ]]; then + # we will just decode the directory without oracle segments file + # as we would like to operate in the actual evaluation condition + out_file=single_dev${affix}_${model_affix}.ctm +else + exit 1 +fi + +if [ $stage -le 1 ]; then + steps/segmentation/do_segmentation_data_dir.sh --reco-nj $num_jobs \ + --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp --iter $sad_iter \ + 
--do-downsampling false --extra-left-context 100 --extra-right-context 20 \ + --output-name output-speech --frame-subsampling-factor 6 \ + data/${data_set} $sad_nnet_dir mfcc_hires_bp data/${data_set}${affix} + # Output will be in data/${data_set}_seg +fi + +# uniform segmentation script would have created this dataset +# so update that script if you plan to change this variable +segmented_data_set=${data_set}${affix}_seg + +if [ $stage -le 2 ]; then + mfccdir=mfcc_reverb + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + date=$(date +'%m_%d_%H_%M') + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/aspire-$date/s5/$mfccdir/storage $mfccdir/storage + fi + + utils/copy_data_dir.sh data/${segmented_data_set} data/${segmented_data_set}_hires + steps/make_mfcc.sh --nj 30 --cmd "$train_cmd" \ + --mfcc-config conf/mfcc_hires.conf data/${segmented_data_set}_hires \ + exp/make_reverb_hires/${segmented_data_set} $mfccdir + steps/compute_cmvn_stats.sh data/${segmented_data_set}_hires \ + exp/make_reverb_hires/${segmented_data_set} $mfccdir + utils/fix_data_dir.sh data/${segmented_data_set}_hires + utils/validate_data_dir.sh --no-text data/${segmented_data_set}_hires +fi + +decode_dir=$dir/decode_${segmented_data_set}_pp +if [ $stage -le 5 ]; then + echo "Extracting i-vectors, stage 2" + # this does offline decoding, except we estimate the iVectors per + # speaker, excluding silence (based on alignments from a DNN decoding), with a + # different script. This is just to demonstrate that script. + # the --sub-speaker-frames is optional; if provided, it will divide each speaker + # up into "sub-speakers" of at least that many frames... can be useful if + # acoustic conditions drift over time within the speaker's data. 
+ steps/online/nnet2/extract_ivectors.sh --cmd "$train_cmd" --nj 20 \ + --sub-speaker-frames $sub_speaker_frames --max-count $max_count \ + data/${segmented_data_set}_hires $lang $ivector_dir/extractor \ + $ivector_dir/ivectors_${segmented_data_set}${ivector_affix}; +fi + +if [ $stage -le 6 ]; then + echo "Generating lattices, stage 2 with --acwt $acwt" + rm -f ${decode_dir}_tg/.error + steps/nnet3/decode.sh --nj $decode_num_jobs --cmd "$decode_cmd" --config conf/decode.config $pass2_decode_opts \ + --acwt $acwt --post-decode-acwt $post_decode_acwt \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --frames-per-chunk "$frames_per_chunk" \ + --skip-scoring true --iter $iter --lattice-beam $lattice_beam \ + --online-ivector-dir $ivector_dir/ivectors_${segmented_data_set}${ivector_affix} \ + $graph data/${segmented_data_set}_hires ${decode_dir}_tg || touch ${decode_dir}_tg/.error + [ -f ${decode_dir}_tg/.error ] && echo "$0: Error decoding" && exit 1; +fi + +if [ $stage -le 7 ]; then + echo "Rescoring lattices" + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + --skip-scoring true \ + ${lang}_pp_test{,_fg} data/${segmented_data_set}_hires \ + ${decode_dir}_{tg,fg}; +fi + +decode_dir=${decode_dir}_fg + +if [ $stage -le 8 ]; then + local/score_aspire.sh --cmd "$decode_cmd" \ + --min-lmwt 1 --max-lmwt 20 \ + --word-ins-penalties "0.0,0.25,0.5,0.75,1.0" \ + --ctm-beam 6 \ + --iter $iter \ + --decode-mbr true \ + --resolve-overlaps false \ + --tune-hyper true \ + $lang $decode_dir $act_data_set $segmented_data_set $out_file +fi + +# Two-pass decoding baseline +# %WER 27.8 | 2120 27217 | 78.2 13.6 8.2 6.0 27.8 75.9 | -0.613 | exp/chain/tdnn_7b/decode_dev_aspire_whole_uniformsegmented_win10_over5_v6_200jobs_iterfinal_pp_fg/score_9/penalty_0.0/ctm.filt.filt.sys +# Using automatic segmentation +# %WER 28.2 | 2120 27214 | 76.5 12.4 11.1 4.7 28.2 75.2 | -0.522 | 
exp/chain/tdnn_7b/decode_dev_aspire_seg_v7_n_stddev_iterfinal_pp_fg/score_10/penalty_0.0/ctm.filt.filt.sys diff --git a/egs/aspire/s5/local/score_aspire.sh b/egs/aspire/s5/local/score_aspire.sh index 3e35b6d3dae..9c08a6c85d1 100755 --- a/egs/aspire/s5/local/score_aspire.sh +++ b/egs/aspire/s5/local/score_aspire.sh @@ -14,10 +14,9 @@ word_ins_penalties=0.0,0.25,0.5,0.75,1.0 default_wip=0.0 ctm_beam=6 decode_mbr=true -window=30 -overlap=5 cmd=run.pl stage=1 +resolve_overlaps=true tune_hyper=true # if true: # if the data set is "dev_aspire" we check for the # best lmwt and word_insertion_penalty, @@ -89,7 +88,7 @@ if $tune_hyper ; then # or use the default values if [ $stage -le 1 ]; then - if [ "$act_data_set" == "dev_aspire" ]; then + if [[ "$act_data_set" =~ "dev_aspire" ]]; then wip_string=$(echo $word_ins_penalties | sed 's/,/ /g') temp_wips=($wip_string) $cmd WIP=1:${#temp_wips[@]} $decode_dir/scoring/log/score.wip.WIP.log \ @@ -98,8 +97,8 @@ if $tune_hyper ; then echo \$wip \&\& \ $cmd LMWT=$min_lmwt:$max_lmwt $decode_dir/scoring/log/score.LMWT.\$wip.log \ local/multi_condition/get_ctm.sh --filter-ctm-command "$filter_ctm_command" \ - --window $window --overlap $overlap \ --beam $ctm_beam --decode-mbr $decode_mbr \ + --resolve-overlaps $resolve_overlaps \ --glm data/${act_data_set}/glm --stm data/${act_data_set}/stm \ LMWT \$wip $lang data/${segmented_data_set}_hires $model $decode_dir || exit 1; @@ -124,7 +123,7 @@ wipfile.close() fi - if [ "$act_data_set" == "test_aspire" ] || [ "$act_data_set" == "eval_aspire" ]; then + if [[ "$act_data_set" =~ "test_aspire" ]] || [[ "$act_data_set" =~ "eval_aspire" ]]; then # check for the best values from dev_aspire decodes dev_decode_dir=$(echo $decode_dir|sed "s/test_aspire/dev_aspire_whole/g; s/eval_aspire/dev_aspire_whole/g") if [ -f $dev_decode_dir/scoring/bestLMWT ]; then diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh new file mode 
100755 index 00000000000..45fdf6c1c5c --- /dev/null +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh @@ -0,0 +1,138 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e +set -u +set -o pipefail + +. path.sh + +stage=0 +corruption_stage=-10 +corrupt_only=false + +# Data options +data_dir=data/train_si284 # Expecting whole data directory. +speed_perturb=true +num_data_reps=5 # Number of corrupted versions +snrs="20:10:15:5:0:-5" +foreground_snrs="20:10:15:5:0:-5" +background_snrs="20:10:15:5:2:0:-2:-5" +base_rirs=simulated +speeds="0.9 1.0 1.1" + +# Parallel options +reco_nj=40 +cmd=queue.pl + +# Options for feature extraction +mfcc_config=conf/mfcc_hires_bp.conf +feat_suffix=hires_bp + +reco_vad_dir= # Output of prepare_unsad_data.sh. + # If provided, the speech labels and deriv weights will be + # copied into the output data directory. + +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + exit 1 +fi + +data_id=`basename ${data_dir}` + +rvb_opts=() +if [ "$base_rirs" == "simulated" ]; then + # This is the config for the system using simulated RIRs and point-source noises + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") + rvb_opts+=(--noise-set-parameters "0.1, RIRS_NOISES/pointsource_noises/background_noise_list") + rvb_opts+=(--noise-set-parameters "0.9, RIRS_NOISES/pointsource_noises/foreground_noise_list") +else + # This is the config for the JHU ASpIRE submission system + rvb_opts+=(--rir-set-parameters "1.0, RIRS_NOISES/real_rirs_isotropic_noises/rir_list") + rvb_opts+=(--noise-set-parameters RIRS_NOISES/real_rirs_isotropic_noises/noise_list) +fi + +corrupted_data_id=${data_id}_corrupted + +if [ $stage -le 1 ]; then + python steps/data/reverberate_data_dir.py \ + "${rvb_opts[@]}" \ + --prefix="rev" \ + --foreground-snrs=$foreground_snrs \ + --background-snrs=$background_snrs \ + 
--speech-rvb-probability=1 \ + --pointsource-noise-addition-probability=1 \ + --isotropic-noise-addition-probability=1 \ + --num-replications=$num_data_reps \ + --max-noises-per-minute=2 \ + data/${data_id} data/${corrupted_data_id} +fi + +corrupted_data_dir=data/${corrupted_data_id} + +if $speed_perturb; then + if [ $stage -le 2 ]; then + ## Assuming whole data directories + for x in $corrupted_data_dir; do + cp $x/reco2dur $x/utt2dur + utils/data/perturb_data_dir_speed_random.sh --speeds "$speeds" $x ${x}_spr + done + fi + + corrupted_data_dir=${corrupted_data_dir}_spr + corrupted_data_id=${corrupted_data_id}_spr + + if [ $stage -le 3 ]; then + utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 \ + ${corrupted_data_dir} + fi +fi + +if $corrupt_only; then + echo "$0: Got corrupted data directory in ${corrupted_data_dir}" + exit 0 +fi + +mfccdir=`basename $mfcc_config` +mfccdir=${mfccdir%%.conf} + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage +fi + +if [ $stage -le 4 ]; then + utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$cmd" --nj $reco_nj \ + $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir + steps/compute_cmvn_stats.sh --fake \ + $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir +else + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix +fi + +if [ $stage -le 8 ]; then + if [ ! -z "$reco_vad_dir" ]; then + if [ ! 
-f $reco_vad_dir/speech_labels.scp ]; then + echo "$0: Could not find file $reco_vad_dir/speech_labels.scp" + exit 1 + fi + + cat $reco_vad_dir/speech_labels.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps | \ + sort -k1,1 > ${corrupted_data_dir}/speech_labels.scp + + cat $reco_vad_dir/deriv_weights.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps | \ + sort -k1,1 > ${corrupted_data_dir}/deriv_weights.scp + fi +fi + +exit 0 diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh new file mode 100755 index 00000000000..8865e640674 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh @@ -0,0 +1,236 @@ +#!/bin/bash +set -e +set -u +set -o pipefail + +. path.sh +. cmd.sh + +num_data_reps=5 +data_dir=data/train_si284 + +nj=40 +reco_nj=40 + +stage=0 +corruption_stage=-10 + +pad_silence=false + +mfcc_config=conf/mfcc_hires_bp_vh.conf +feat_suffix=hires_bp_vh +mfcc_irm_config=conf/mfcc_hires_bp.conf + +dry_run=false +corrupt_only=false +speed_perturb=true +speeds="0.9 1.0 1.1" + +reco_vad_dir= + +max_jobs_run=20 + +foreground_snrs="5:2:1:0:-2:-5:-10:-20" +background_snrs="5:2:1:0:-2:-5:-10:-20" + +. 
utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + exit 1 +fi + +data_id=`basename ${data_dir}` + +rvb_opts=() +# This is the config for the system using simulated RIRs and point-source noises +rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") +rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") +rvb_opts+=(--noise-set-parameters RIRS_NOISES/music/music_list) + +music_utt2num_frames=RIRS_NOISES/music/split_utt2num_frames + +corrupted_data_id=${data_id}_music_corrupted +orig_corrupted_data_id=$corrupted_data_id + +if [ $stage -le 1 ]; then + python steps/data/reverberate_data_dir.py \ + "${rvb_opts[@]}" \ + --prefix="music" \ + --foreground-snrs=$foreground_snrs \ + --background-snrs=$background_snrs \ + --speech-rvb-probability=1 \ + --pointsource-noise-addition-probability=1 \ + --isotropic-noise-addition-probability=1 \ + --num-replications=$num_data_reps \ + --max-noises-per-minute=5 \ + data/${data_id} data/${corrupted_data_id} +fi + +if $dry_run; then + exit 0 +fi + +corrupted_data_dir=data/${corrupted_data_id} +# Data dir without speed perturbation +orig_corrupted_data_dir=$corrupted_data_dir + +if $speed_perturb; then + if [ $stage -le 2 ]; then + ## Assuming whole data directories + for x in $corrupted_data_dir; do + cp $x/reco2dur $x/utt2dur + utils/data/perturb_data_dir_speed_random.sh --speeds "$speeds" $x ${x}_spr + done + fi + + corrupted_data_dir=${corrupted_data_dir}_spr + corrupted_data_id=${corrupted_data_id}_spr + + if [ $stage -le 3 ]; then + utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 \ + ${corrupted_data_dir} + fi +fi + +if $corrupt_only; then + echo "$0: Got corrupted data directory in ${corrupted_data_dir}" + exit 0 +fi + +mfccdir=`basename $mfcc_config` +mfccdir=${mfccdir%%.conf} + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $mfccdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage +fi + +if [ $stage -le 4 ]; then + if [ ! -z $feat_suffix ]; then + utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix + fi + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$train_cmd" --nj $reco_nj \ + $corrupted_data_dir exp/make_${mfccdir}/${corrupted_data_id} $mfccdir + steps/compute_cmvn_stats.sh --fake \ + $corrupted_data_dir exp/make_${mfccdir}/${corrupted_data_id} $mfccdir +else + if [ ! -z $feat_suffix ]; then + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix + fi +fi + +if [ $stage -le 8 ]; then + if [ ! -z "$reco_vad_dir" ]; then + if [ ! -f $reco_vad_dir/speech_labels.scp ]; then + echo "$0: Could not find file $reco_vad_dir/speech_labels.scp" + exit 1 + fi + + cat $reco_vad_dir/speech_labels.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps "music" | \ + sort -k1,1 > ${corrupted_data_dir}/speech_labels.scp + + cat $reco_vad_dir/deriv_weights.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps "music" | \ + sort -k1,1 > ${corrupted_data_dir}/deriv_weights.scp + fi +fi + +# music_dir is without speed perturbation +music_dir=exp/make_music_labels/${orig_corrupted_data_id} +music_data_dir=$music_dir/music_data + +mkdir -p $music_data_dir + +if [ $stage -le 10 ]; then + utils/data/get_reco2num_frames.sh --nj $reco_nj $orig_corrupted_data_dir + utils/split_data.sh --per-reco ${orig_corrupted_data_dir} $reco_nj + + cp $orig_corrupted_data_dir/wav.scp $music_data_dir + + # The first rspecifier is a dummy required to get the recording-id as key. + # It has no segments in it as they are all removed by --remove-labels. 
+ $train_cmd JOB=1:$reco_nj $music_dir/log/get_music_seg.JOB.log \ + segmentation-init-from-additive-signals-info --lengths-rspecifier=ark,t:${orig_corrupted_data_dir}/reco2num_frames \ + --additive-signals-segmentation-rspecifier="ark:segmentation-init-from-lengths ark:$music_utt2num_frames ark:- |" \ + "ark,t:utils/filter_scp.pl ${orig_corrupted_data_dir}/split${reco_nj}reco/JOB/reco2utt $orig_corrupted_data_dir/additive_signals_info.txt |" \ + ark:- \| \ + segmentation-post-process --merge-adjacent-segments ark:- \ + ark:- \| \ + segmentation-to-segments ark:- ark:$music_data_dir/utt2spk.JOB \ + $music_data_dir/segments.JOB + + utils/data/get_reco2utt.sh $corrupted_data_dir + for n in `seq $reco_nj`; do cat $music_data_dir/utt2spk.$n; done > $music_data_dir/utt2spk + for n in `seq $reco_nj`; do cat $music_data_dir/segments.$n; done > $music_data_dir/segments + + utils/fix_data_dir.sh $music_data_dir + + if $speed_perturb; then + utils/data/perturb_data_dir_speed_4way.sh $music_data_dir ${music_data_dir}_spr + mv ${music_data_dir}_spr/segments{,.temp} + cat ${music_data_dir}_spr/segments.temp | \ + utils/filter_scp.pl -f 2 ${corrupted_data_dir}/reco2utt > ${music_data_dir}_spr/segments + utils/fix_data_dir.sh ${music_data_dir}_spr + rm ${music_data_dir}_spr/segments.temp + fi +fi + +if $speed_perturb; then + music_data_dir=${music_data_dir}_spr +fi + +label_dir=music_labels + +mkdir -p $label_dir +label_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $label_dir ${PWD}` + +if [ $stage -le 11 ]; then + utils/split_data.sh --per-reco ${corrupted_data_dir} $reco_nj + # TODO: Don't assume that its whole data directory. 
+ nj=$reco_nj + if [ $nj -gt 4 ]; then + nj=4 + fi + utils/data/get_utt2num_frames.sh --cmd "$train_cmd" --nj $nj ${corrupted_data_dir} + utils/data/get_reco2utt.sh $music_data_dir/ + + $train_cmd JOB=1:$reco_nj $music_dir/log/get_music_labels.JOB.log \ + segmentation-init-from-segments --shift-to-zero=false \ + "utils/filter_scp.pl -f 2 ${corrupted_data_dir}/split${reco_nj}reco/JOB/reco2utt ${music_data_dir}/segments |" ark:- \| \ + segmentation-combine-segments-to-recordings ark:- \ + "ark,t:utils/filter_scp.pl ${corrupted_data_dir}/split${reco_nj}reco/JOB/reco2utt ${music_data_dir}/reco2utt |" \ + ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ + ark,scp:$label_dir/music_labels_${corrupted_data_id}.JOB.ark,$label_dir/music_labels_${corrupted_data_id}.JOB.scp +fi + +for n in `seq $reco_nj`; do + cat $label_dir/music_labels_${corrupted_data_id}.$n.scp +done | utils/filter_scp.pl ${corrupted_data_dir}/utt2spk > ${corrupted_data_dir}/music_labels.scp + +if [ $stage -le 12 ]; then + utils/split_data.sh --per-reco ${corrupted_data_dir} $reco_nj + + cat < $music_dir/speech_music_map +0 0 0 +0 1 3 +1 0 1 +1 1 2 +EOF + + $train_cmd JOB=1:$reco_nj $music_dir/log/get_speech_music_labels.JOB.log \ + intersect-int-vectors --mapping-in=$music_dir/speech_music_map --length-tolerance=2 \ + "scp:utils/filter_scp.pl ${corrupted_data_dir}/split${reco_nj}reco/JOB/reco2utt ${corrupted_data_dir}/speech_labels.scp |" \ + "scp:utils/filter_scp.pl ${corrupted_data_dir}/split${reco_nj}reco/JOB/reco2utt ${corrupted_data_dir}/music_labels.scp |" \ + ark,scp:$label_dir/speech_music_labels_${corrupted_data_id}.JOB.ark,$label_dir/speech_music_labels_${corrupted_data_id}.JOB.scp + + for n in `seq $reco_nj`; do + cat $label_dir/speech_music_labels_${corrupted_data_id}.$n.scp + done > $corrupted_data_dir/speech_music_labels.scp +fi + +exit 0 diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh 
b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh new file mode 100755 index 00000000000..991bec96308 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_overlapped_speech.sh @@ -0,0 +1,209 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e +set -u +set -o pipefail + +. path.sh + +stage=0 +corruption_stage=-10 +corrupt_only=false + +# Data options +data_dir=data/train_si284 # Excpecting non-whole data directory +num_data_reps=5 # Number of corrupted versions +snrs="20:10:15:5:0:-5" +foreground_snrs="20:10:15:5:0:-5" +background_snrs="20:10:15:5:0:-5" +overlap_snrs="5:2:1:0:-1:-2" +overlap_labels_dir=overlap_labels + +# Parallel options +nj=40 +cmd=queue.pl + +# Options for feature extraction +mfcc_config=conf/mfcc_hires_bp.conf +feat_suffix=hires_bp +energy_config=conf/log_energy.conf + +utt_vad_dir= + +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + exit 1 +fi + +rvb_opts=() +# This is the config for the system using simulated RIRs and point-source noises +rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") +rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") +rvb_opts+=(--speech-segments-set-parameters="$data_dir/wav.scp,$data_dir/segments") + +if [ $stage -le 0 ]; then + steps/segmentation/get_data_dir_with_segmented_wav.py \ + $data_dir ${data_dir}_seg +fi + +data_dir=${data_dir}_seg + +data_id=`basename ${data_dir}` + +corrupted_data_id=${data_id}_ovlp_corrupted +clean_data_id=${data_id}_ovlp_clean +noise_data_id=${data_id}_ovlp_noise + +utils/data/get_reco2dur.sh --cmd $cmd --nj 40 $data_dir + +if [ $stage -le 1 ]; then + python steps/data/make_corrupted_data_dir.py \ + "${rvb_opts[@]}" \ + --prefix="ovlp" \ + --overlap-snrs=$overlap_snrs \ + --speech-rvb-probability=1 \ + --overlapping-speech-addition-probability=1 \ + --num-replications=$num_data_reps \ + 
--min-overlapping-segments-per-minute=1 \ + --max-overlapping-segments-per-minute=1 \ + --output-additive-noise-dir=data/${noise_data_id} \ + --output-reverb-dir=data/${clean_data_id} \ + ${data_dir} data/${corrupted_data_id} +fi + +clean_data_dir=data/${clean_data_id} +corrupted_data_dir=data/${corrupted_data_id} +noise_data_dir=data/${noise_data_id} +orig_corrupted_data_dir=data/${corrupted_data_id} + +if false; then + if [ $stage -le 2 ]; then + for x in $clean_data_dir $corrupted_data_dir $noise_data_dir; do + utils/data/perturb_data_dir_speed_3way.sh $x ${x}_sp + done + fi + + corrupted_data_dir=${corrupted_data_dir}_sp + clean_data_dir=${clean_data_dir}_sp + noise_data_dir=${noise_data_dir}_sp + + corrupted_data_id=${corrupted_data_id}_sp + clean_data_id=${clean_data_id}_sp + noise_data_id=${noise_data_id}_sp +fi + +if [ $stage -le 3 ]; then + utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 ${corrupted_data_dir} + utils/data/perturb_data_dir_volume.sh --reco2vol ${corrupted_data_dir}/reco2vol ${clean_data_dir} + utils/data/perturb_data_dir_volume.sh --reco2vol ${corrupted_data_dir}/reco2vol ${noise_data_dir} +fi + +if $corrupt_only; then + echo "$0: Got corrupted data directory in ${corrupted_data_dir}" + exit 0 +fi + +mfccdir=`basename $mfcc_config` +mfccdir=${mfccdir%%.conf} + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage +fi + +if [ $stage -le 4 ]; then + utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$cmd" --nj $nj \ + $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir +else + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix +fi + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d log_energy/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/log_energy/storage log_energy/storage +fi + +if [ $stage -le 5 ]; then + utils/copy_data_dir.sh $clean_data_dir ${clean_data_dir}_log_energy + steps/make_mfcc.sh --mfcc-config conf/log_energy.conf \ + --cmd "$cmd" --nj $nj ${clean_data_dir}_log_energy \ + exp/make_log_energy/${clean_data_id} log_energy +fi + +if [ $stage -le 6 ]; then + utils/copy_data_dir.sh $noise_data_dir ${noise_data_dir}_log_energy + steps/make_mfcc.sh --mfcc-config conf/log_energy.conf \ + --cmd "$cmd" --nj $nj ${noise_data_dir}_log_energy \ + exp/make_log_energy/${noise_data_id} log_energy +fi + +targets_dir=log_snr +if [ $stage -le 7 ]; then + mkdir -p exp/make_log_snr/${corrupted_data_id} + + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $targets_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$targets_dir/storage $targets_dir/storage + fi + + # Get log-SNR targets + steps/segmentation/make_snr_targets.sh \ + --nj $nj --cmd "$cmd" \ + --target-type Snr --compress false \ + ${clean_data_dir}_log_energy ${noise_data_dir}_log_energy ${corrupted_data_dir} \ + exp/make_log_snr/${corrupted_data_id} $targets_dir +fi + +exit 0 + +if [ $stage -le 5 ]; then + # clean here is the reverberated first-speaker signal + utils/copy_data_dir.sh $clean_data_dir ${clean_data_dir}_$feat_suffix + clean_data_dir=${clean_data_dir}_$feat_suffix + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$cmd" --nj $nj \ + $clean_data_dir exp/make_${feat_suffix}/${clean_data_id} $mfccdir +else + clean_data_dir=${clean_data_dir}_$feat_suffix +fi + +if [ $stage -le 6 ]; then + # noise here is the reverberated second-speaker signal + utils/copy_data_dir.sh $noise_data_dir ${noise_data_dir}_$feat_suffix + noise_data_dir=${noise_data_dir}_$feat_suffix + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + 
--cmd "$cmd" --nj $nj \ + $noise_data_dir exp/make_${feat_suffix}/${noise_data_id} $mfccdir +else + noise_data_dir=${noise_data_dir}_$feat_suffix +fi + +targets_dir=irm_targets +if [ $stage -le 8 ]; then + mkdir -p exp/make_irm_targets/${corrupted_data_id} + + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $targets_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$targets_dir/storage $targets_dir/storage + fi + + # Get SNR targets only for the overlapped speech labels. + steps/segmentation/make_snr_targets.sh \ + --nj $nj --cmd "$cmd --max-jobs-run $max_jobs_run" \ + --target-type Irm --compress false --apply-exp true \ + --ali-rspecifier "ark,s,cs:cat ${corrupted_data_dir}/sad_seg.scp | segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames scp:- ark:- |" \ + overlapped_speech_labels.scp \ + --silence-phones 0 \ + ${clean_data_dir} ${noise_data_dir} ${corrupted_data_dir} \ + exp/make_irm_targets/${corrupted_data_id} $targets_dir +fi + +exit 0 diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_snr.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_snr.sh new file mode 100755 index 00000000000..19b4036c9aa --- /dev/null +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_snr.sh @@ -0,0 +1,236 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e +set -u +set -o pipefail + +. path.sh + +stage=0 +corruption_stage=-10 +corrupt_only=false + +# Data options +data_dir=data/train_si284 # Expecting whole data directory. 
+speed_perturb=true +num_data_reps=5 # Number of corrupted versions +snrs="20:10:15:5:0:-5" +foreground_snrs="20:10:15:5:0:-5" +background_snrs="20:10:15:5:2:0:-2:-5" +base_rirs=simulated +speeds="0.9 1.0 1.1" +resample_data_dir=false + +# Parallel options +reco_nj=40 +cmd=queue.pl + +# Options for feature extraction +mfcc_config=conf/mfcc_hires_bp.conf +feat_suffix=hires_bp + +reco_vad_dir= # Output of prepare_unsad_data.sh. + # If provided, the speech labels and deriv weights will be + # copied into the output data directory. + +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + exit 1 +fi + +data_id=`basename ${data_dir}` + +rvb_opts=() +if [ "$base_rirs" == "simulated" ]; then + # This is the config for the system using simulated RIRs and point-source noises + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") + rvb_opts+=(--noise-set-parameters "0.1, RIRS_NOISES/pointsource_noises/background_noise_list") + rvb_opts+=(--noise-set-parameters "0.9, RIRS_NOISES/pointsource_noises/foreground_noise_list") +else + # This is the config for the JHU ASpIRE submission system + rvb_opts+=(--rir-set-parameters "1.0, RIRS_NOISES/real_rirs_isotropic_noises/rir_list") + rvb_opts+=(--noise-set-parameters RIRS_NOISES/real_rirs_isotropic_noises/noise_list) +fi + +if $resample_data_dir; then + sample_frequency=`cat $mfcc_config | perl -ne 'if (m/--sample-frequency=(\S+)/) { print $1; }'` + if [ -z "$sample_frequency" ]; then + sample_frequency=16000 + fi + + utils/data/resample_data_dir.sh $sample_frequency ${data_dir} || exit 1 + data_id=`basename ${data_dir}` + rvb_opts+=(--source-sampling-rate=$sample_frequency) +fi + +corrupted_data_id=${data_id}_corrupted +clean_data_id=${data_id}_clean +noise_data_id=${data_id}_noise + +if [ $stage -le 1 ]; then + python steps/data/reverberate_data_dir.py \ + "${rvb_opts[@]}" \ + --prefix="rev" \ + 
--foreground-snrs=$foreground_snrs \ + --background-snrs=$background_snrs \ + --speech-rvb-probability=1 \ + --pointsource-noise-addition-probability=1 \ + --isotropic-noise-addition-probability=1 \ + --num-replications=$num_data_reps \ + --max-noises-per-minute=2 \ + --output-additive-noise-dir=data/${noise_data_id} \ + --output-reverb-dir=data/${clean_data_id} \ + data/${data_id} data/${corrupted_data_id} +fi + +corrupted_data_dir=data/${corrupted_data_id} +clean_data_dir=data/${clean_data_id} +noise_data_dir=data/${noise_data_id} + +if $speed_perturb; then + if [ $stage -le 2 ]; then + ## Assuming whole data directories + for x in $corrupted_data_dir $clean_data_dir $noise_data_dir; do + cp $x/reco2dur $x/utt2dur + utils/data/perturb_data_dir_speed_random.sh --speeds "$speeds" $x ${x}_spr + done + fi + + corrupted_data_dir=${corrupted_data_dir}_spr + clean_data_dir=${clean_data_dir}_spr + noise_data_dir=${noise_data_dir}_spr + corrupted_data_id=${corrupted_data_id}_spr + clean_data_id=${clean_data_id}_spr + noise_data_id=${noise_data_id}_spr + + if [ $stage -le 3 ]; then + utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 \ + ${corrupted_data_dir} + utils/data/perturb_data_dir_volume.sh --reco2vol ${corrupted_data_dir}/reco2vol ${clean_data_dir} + utils/data/perturb_data_dir_volume.sh --reco2vol ${corrupted_data_dir}/reco2vol ${noise_data_dir} + fi +fi + +if $corrupt_only; then + echo "$0: Got corrupted data directory in ${corrupted_data_dir}" + exit 0 +fi + +mfccdir=`basename $mfcc_config` +mfccdir=${mfccdir%%.conf} + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $mfccdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage +fi + +if [ $stage -le 4 ]; then + utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$cmd" --nj $reco_nj \ + $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir + steps/compute_cmvn_stats.sh --fake \ + $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir +else + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix +fi + +if [ $stage -le 5 ]; then + utils/copy_data_dir.sh $clean_data_dir ${clean_data_dir}_$feat_suffix + clean_data_dir=${clean_data_dir}_$feat_suffix + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$cmd" --nj $reco_nj \ + $clean_data_dir exp/make_${feat_suffix}/${clean_data_id} $mfccdir + steps/compute_cmvn_stats.sh --fake \ + $clean_data_dir exp/make_${feat_suffix}/${clean_data_id} $mfccdir +else + clean_data_dir=${clean_data_dir}_$feat_suffix +fi + +if [ $stage -le 6 ]; then + utils/copy_data_dir.sh $noise_data_dir ${noise_data_dir}_$feat_suffix + noise_data_dir=${noise_data_dir}_$feat_suffix + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$cmd" --nj $reco_nj \ + $noise_data_dir exp/make_${feat_suffix}/${noise_data_id} $mfccdir + steps/compute_cmvn_stats.sh --fake \ + $noise_data_dir exp/make_${feat_suffix}/${noise_data_id} $mfccdir +else + noise_data_dir=${noise_data_dir}_$feat_suffix +fi + +targets_dir=irm_targets +if [ $stage -le 7 ]; then + mkdir -p exp/make_log_snr/${corrupted_data_id} + + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $targets_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$targets_dir/storage $targets_dir/storage + fi + + idct_params=`cat $mfcc_config | perl -e ' + $num_mel_bins = 23; $num_ceps = 13; $cepstral_lifter = 22.0; + while (<>) { + chomp; + s/#.+//g; + if (m/^\s*$/) { next; } + if (m/--num-mel-bins=(\S+)/) { + $num_mel_bins = $1; + } elsif (m/--num-ceps=(\S+)/) { + $num_ceps = $1; + } elsif (m/--cepstral-lifter=(\S+)/) { + $cepstral_lifter = $1; + } + } + print "$num_mel_bins $num_ceps $cepstral_lifter";'` + + num_filters=`echo $idct_params | awk '{print $1}'` + num_ceps=`echo $idct_params | awk '{print $2}'` + cepstral_lifter=`echo $idct_params | awk '{print $3}'` + echo "$num_filters $num_ceps $cepstral_lifter" + + mkdir -p exp/make_irm_targets/$corrupted_data_id + utils/data/get_dct_matrix.py --get-idct-matrix=true \ + --num-filters=$num_filters --num-ceps=$num_ceps \ + --cepstral-lifter=$cepstral_lifter \ + exp/make_irm_targets/$corrupted_data_id/idct_matrix + + # Get log-SNR targets + steps/segmentation/make_snr_targets.sh \ + --nj $reco_nj --cmd "$cmd" \ + --target-type Irm --compress false \ + --transform-matrix exp/make_irm_targets/$corrupted_data_id/idct_matrix \ + ${clean_data_dir} ${noise_data_dir} ${corrupted_data_dir} \ + exp/make_irm_targets/${corrupted_data_id} $targets_dir +fi + + +if [ $stage -le 8 ]; then + if [ ! -z "$reco_vad_dir" ]; then + if [ ! 
-f $reco_vad_dir/speech_labels.scp ]; then + echo "$0: Could not find file $reco_vad_dir/speech_labels.scp" + exit 1 + fi + + cat $reco_vad_dir/speech_labels.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps | \ + sort -k1,1 > ${corrupted_data_dir}/speech_labels.scp + + cat $reco_vad_dir/deriv_weights.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps | \ + sort -k1,1 > ${corrupted_data_dir}/deriv_weights.scp + + cat $reco_vad_dir/deriv_weights_manual_seg.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps | \ + sort -k1,1 > ${corrupted_data_dir}/deriv_weights_for_irm_targets.scp + fi +fi + +exit 0 diff --git a/egs/aspire/s5/local/segmentation/do_corruption_whole_data_dir_overlapped_speech.sh b/egs/aspire/s5/local/segmentation/do_corruption_whole_data_dir_overlapped_speech.sh new file mode 100755 index 00000000000..75dbce578b2 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/do_corruption_whole_data_dir_overlapped_speech.sh @@ -0,0 +1,284 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e +set -u +set -o pipefail + +. path.sh + +stage=0 +corruption_stage=-10 +corrupt_only=false + +# Data options +data_dir=data/train_si284 # Excpecting non-whole data directory +speed_perturb=true +num_data_reps=5 # Number of corrupted versions +snrs="20:10:15:5:0:-5" +foreground_snrs="20:10:15:5:0:-5" +background_snrs="20:10:15:5:0:-5" +overlap_snrs="5:2:1:0:-1:-2" +# Whole-data directory corresponding to data_dir +whole_data_dir=data/train_si284_whole +overlap_labels_dir=overlap_labels + +# Parallel options +reco_nj=40 +nj=40 +cmd=queue.pl + +# Options for feature extraction +mfcc_config=conf/mfcc_hires_bp.conf +feat_suffix=hires_bp +energy_config=conf/log_energy.conf + +reco_vad_dir= # Output of prepare_unsad_data.sh. + # If provided, the speech labels and deriv weights will be + # copied into the output data directory. +utt_vad_dir= + +. 
utils/parse_options.sh
+
+if [ $# -ne 0 ]; then
+  echo "Usage: $0"
+  exit 1
+fi
+
+rvb_opts=()
+# This is the config for the system using simulated RIRs and point-source noises
+rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list")
+rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list")
+rvb_opts+=(--speech-segments-set-parameters="$data_dir/wav.scp,$data_dir/segments")
+
+whole_data_id=`basename ${whole_data_dir}`
+
+corrupted_data_id=${whole_data_id}_ovlp_corrupted
+clean_data_id=${whole_data_id}_ovlp_clean
+noise_data_id=${whole_data_id}_ovlp_noise
+
+if [ $stage -le 1 ]; then
+  python steps/data/make_corrupted_data_dir.py \
+    "${rvb_opts[@]}" \
+    --prefix="ovlp" \
+    --overlap-snrs=$overlap_snrs \
+    --speech-rvb-probability=1 \
+    --overlapping-speech-addition-probability=1 \
+    --num-replications=$num_data_reps \
+    --min-overlapping-segments-per-minute=5 \
+    --max-overlapping-segments-per-minute=20 \
+    --output-additive-noise-dir=data/${noise_data_id} \
+    --output-reverb-dir=data/${clean_data_id} \
+    data/${whole_data_id} data/${corrupted_data_id}
+fi
+
+if ${dry_run:-false}; then
+  exit 0
+fi
+
+clean_data_dir=data/${clean_data_id}
+corrupted_data_dir=data/${corrupted_data_id}
+noise_data_dir=data/${noise_data_id}
+orig_corrupted_data_dir=$corrupted_data_dir
+
+if $speed_perturb; then
+  if [ $stage -le 2 ]; then
+    ## Assuming whole data directories
+    for x in $clean_data_dir $corrupted_data_dir $noise_data_dir; do
+      cp $x/reco2dur $x/utt2dur
+      utils/data/perturb_data_dir_speed_3way.sh $x ${x}_sp
+    done
+  fi
+
+  corrupted_data_dir=${corrupted_data_dir}_sp
+  clean_data_dir=${clean_data_dir}_sp
+  noise_data_dir=${noise_data_dir}_sp
+
+  corrupted_data_id=${corrupted_data_id}_sp
+  clean_data_id=${clean_data_id}_sp
+  noise_data_id=${noise_data_id}_sp
+
+  if [ $stage -le 3 ]; then
+    utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 --force true ${corrupted_data_dir}
+    
utils/data/perturb_data_dir_volume.sh --force true --reco2vol ${corrupted_data_dir}/reco2vol ${clean_data_dir} + utils/data/perturb_data_dir_volume.sh --force true --reco2vol ${corrupted_data_dir}/reco2vol ${noise_data_dir} + fi +fi + +if $corrupt_only; then + echo "$0: Got corrupted data directory in ${corrupted_data_dir}" + exit 0 +fi + +mfccdir=`basename $mfcc_config` +mfccdir=${mfccdir%%.conf} + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage +fi + +if [ $stage -le 4 ]; then + utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$train_cmd" --nj $reco_nj \ + $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir +fi + +if [ $stage -le 5 ]; then + steps/make_mfcc.sh --mfcc-config $energy_config \ + --cmd "$train_cmd" --nj $reco_nj \ + $clean_data_dir exp/make_log_energy/${clean_data_id} log_energy_feats +fi + +if [ $stage -le 6 ]; then + steps/make_mfcc.sh --mfcc-config $energy_config \ + --cmd "$train_cmd" --nj $reco_nj \ + $noise_data_dir exp/make_log_energy/${noise_data_id} log_energy_feats +fi + +if [ -z "$reco_vad_dir" ]; then + echo "reco-vad-dir must be provided" + exit 1 +fi + +targets_dir=irm_targets +if [ $stage -le 8 ]; then + mkdir -p exp/make_irm_targets/${corrupted_data_id} + + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $targets_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$targets_dir/storage $targets_dir/storage + fi + + steps/segmentation/make_snr_targets.sh \ + --nj $nj --cmd "$train_cmd --max-jobs-run $max_jobs_run" \ + --target-type Irm --compress true --apply-exp false \ + ${clean_data_dir} ${noise_data_dir} ${corrupted_data_dir} \ + exp/make_irm_targets/${corrupted_data_id} $targets_dir +fi + +# Combine the VAD from the base recording and the VAD from the overlapping segments +# to create per-frame labels of the number of overlapping speech segments +# Unreliable segments are regions where no VAD labels were available for the +# overlapping segments. These can be later removed by setting deriv weights to 0. + +# Data dirs without speed perturbation +overlap_dir=exp/make_overlap_labels/${corrupted_data_id} +unreliable_dir=exp/make_overlap_labels/unreliable_${corrupted_data_id} +overlap_data_dir=$overlap_dir/overlap_data +unreliable_data_dir=$overlap_dir/unreliable_data + +mkdir -p $unreliable_dir + +if [ $stage -le 8 ]; then + cat $reco_vad_dir/sad_seg.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps "ovlp" \ + | sort -k1,1 > ${corrupted_data_dir}/sad_seg.scp + utils/data/get_utt2num_frames.sh $corrupted_data_dir + utils/split_data.sh --per-reco ${orig_corrupted_data_dir} $reco_nj + + $train_cmd JOB=1:$reco_nj $overlap_dir/log/get_overlap_seg.JOB.log \ + segmentation-init-from-overlap-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + "scp:utils/filter_scp.pl ${orig_corrupted_data_dir}/split${reco_nj}reco/JOB/utt2spk $corrupted_data_dir/sad_seg.scp |" \ + ark,t:$orig_corrupted_data_dir/overlapped_segments_info.txt \ + scp:$utt_vad_dir/sad_seg.scp ark:- ark:$unreliable_dir/unreliable_seg_speed_unperturbed.JOB.ark \| \ + segmentation-copy --keep-label=1 ark:- ark:- \| \ + segmentation-get-stats --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames 
\ + ark:- ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:$overlap_dir/overlap_seg_speed_unperturbed.JOB.ark +fi + +if [ $stage -le 9 ]; then + mkdir -p $overlap_data_dir $unreliable_data_dir + cp $orig_corrupted_data_dir/wav.scp $overlap_data_dir + cp $orig_corrupted_data_dir/wav.scp $unreliable_data_dir + + # Create segments where there is definitely an overlap. + # Assume no more than 10 speakers overlap. + $train_cmd JOB=1:$reco_nj $overlap_dir/log/process_to_segments.JOB.log \ + segmentation-post-process --remove-labels=0:1 \ + ark:$overlap_dir/overlap_seg_speed_unperturbed.JOB.ark ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-to-segments ark:- ark:$overlap_data_dir/utt2spk.JOB $overlap_data_dir/segments.JOB + + $train_cmd JOB=1:$reco_nj $overlap_dir/log/get_unreliable_segments.JOB.log \ + segmentation-to-segments --single-speaker \ + ark:$unreliable_dir/unreliable_seg_speed_unperturbed.JOB.ark \ + ark:$unreliable_data_dir/utt2spk.JOB $unreliable_data_dir/segments.JOB + + for n in `seq $reco_nj`; do cat $overlap_data_dir/utt2spk.$n; done > $overlap_data_dir/utt2spk + for n in `seq $reco_nj`; do cat $overlap_data_dir/segments.$n; done > $overlap_data_dir/segments + for n in `seq $reco_nj`; do cat $unreliable_data_dir/utt2spk.$n; done > $unreliable_data_dir/utt2spk + for n in `seq $reco_nj`; do cat $unreliable_data_dir/segments.$n; done > $unreliable_data_dir/segments + + utils/fix_data_dir.sh $overlap_data_dir + utils/fix_data_dir.sh $unreliable_data_dir + + if $speed_perturb; then + utils/data/perturb_data_dir_speed_3way.sh $overlap_data_dir ${overlap_data_dir}_sp + utils/data/perturb_data_dir_speed_3way.sh $unreliable_data_dir ${unreliable_data_dir}_sp + fi +fi + +if $speed_perturb; then + overlap_data_dir=${overlap_data_dir}_sp + unreliable_data_dir=${unreliable_data_dir}_sp +fi + +# make $overlap_labels_dir an absolute pathname. 
+overlap_labels_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $overlap_labels_dir ${PWD}` + +if [ $stage -le 10 ]; then + utils/split_data.sh --per-reco ${overlap_data_dir} $reco_nj + + $train_cmd JOB=1:$reco_nj $overlap_dir/log/get_overlap_speech_labels.JOB.log \ + utils/data/get_reco2utt.sh ${overlap_data_dir}/split${reco_nj}reco/JOB '&&' \ + segmentation-init-from-segments --shift-to-zero=false \ + ${overlap_data_dir}/split${reco_nj}reco/JOB/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt \ + ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ + ark,scp:$overlap_labels_dir/overlapped_speech_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/overlapped_speech_${corrupted_data_id}.JOB.scp +fi + +for n in `seq $reco_nj`; do + cat $overlap_labels_dir/overlapped_speech_${corrupted_data_id}.$n.scp +done > ${corrupted_data_dir}/overlapped_speech_labels.scp + +if [ $stage -le 11 ]; then + utils/data/get_reco2utt.sh ${unreliable_data_dir} + + # First convert the unreliable segments into a recording-level segmentation. + # Initialize a segmentation from utt2num_frames and set to 0, the regions + # of unreliable segments. At this stage deriv weights is 1 for all but the + # unreliable segment regions. + # Initialize a segmentation from the VAD labels and retain only the speech segments. + # Intersect this with the deriv weights segmentation from above. At this stage + # deriv weights is 1 for only the regions where base VAD label is 1 and + # the overlapping segment is not unreliable. Convert this to deriv weights. 
+ $train_cmd JOB=1:$reco_nj $unreliable_dir/log/get_deriv_weights.JOB.log\ + segmentation-init-from-segments --shift-to-zero=false \ + "utils/filter_scp.pl -f 2 ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt ${unreliable_data_dir}/segments |" ark:- \| \ + segmentation-combine-segments-to-recordings ark:- "ark,t:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt ${unreliable_data_dir}/reco2utt |" \ + ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=0 --ignore-missing \ + "ark:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt $corrupted_data_dir/utt2num_frames | segmentation-init-from-lengths ark,t:- ark:- |" \ + ark:- ark:- \| \ + segmentation-intersect-segments --mismatch-label=0 \ + "ark:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt $corrupted_data_dir/sad_seg.scp | segmentation-post-process --remove-labels=0:2:3 scp:- ark:- |" \ + ark:- ark:- \| \ + segmentation-post-process --remove-labels=0 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + ark,scp:$overlap_labels_dir/deriv_weights_for_overlapped_speech.JOB.ark,$overlap_labels_dir/deriv_weights_for_overlapped_speech.JOB.scp + + for n in `seq $reco_nj`; do + cat $overlap_labels_dir/deriv_weights_for_overlapped_speech.${n}.scp + done > $corrupted_data_dir/deriv_weights_for_overlapped_speech.scp +fi + +exit 0 diff --git a/egs/aspire/s5/local/segmentation/make_musan_music.py b/egs/aspire/s5/local/segmentation/make_musan_music.py new file mode 100755 index 00000000000..5d13078de63 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/make_musan_music.py @@ -0,0 +1,69 @@ +#! 
/usr/bin/env python + +from __future__ import print_function +import argparse +import os + + +def _get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--use-vocals", type=str, default="false", + choices=["true", "false"], + help="If true, also add music with vocals in the " + "output music-set-parameters") + parser.add_argument("root_dir", type=str, + help="Root directory of MUSAN corpus") + parser.add_argument("music_list", type=argparse.FileType('w'), + help="Convert music list into noise-set-paramters " + "for steps/data/reverberate_data_dir.py") + + args = parser.parse_args() + + args.use_vocals = True if args.use_vocals == "true" else False + return args + + +def read_vocals(annotations): + vocals = {} + for line in open(annotations): + parts = line.strip().split() + if parts[2] == "Y": + vocals[parts[0]] = True + return vocals + + +def write_music(utt, file_path, music_list): + print ('{utt} {file_path}'.format( + utt=utt, file_path=file_path), file=music_list) + + +def prepare_music_set(root_dir, use_vocals, music_list): + vocals = {} + music_dir = os.path.join(root_dir, "music") + for root, dirs, files in os.walk(music_dir): + if os.path.exists(os.path.join(root, "ANNOTATIONS")): + vocals = read_vocals(os.path.join(root, "ANNOTATIONS")) + + for f in files: + file_path = os.path.join(root, f) + if f.endswith(".wav"): + utt = str(f).replace(".wav", "") + if not use_vocals and utt in vocals: + continue + write_music(utt, file_path, music_list) + music_list.close() + + +def main(): + args = _get_args() + + try: + prepare_music_set(args.root_dir, args.use_vocals, + args.music_list) + finally: + args.music_list.close() + + +if __name__ == '__main__': + main() diff --git a/egs/aspire/s5/local/segmentation/make_sad_tdnn_configs.py b/egs/aspire/s5/local/segmentation/make_sad_tdnn_configs.py new file mode 100755 index 00000000000..e859a3593ce --- /dev/null +++ b/egs/aspire/s5/local/segmentation/make_sad_tdnn_configs.py @@ -0,0 +1,616 @@ 
+#!/usr/bin/env python + +# we're using python 3.x style print but want it to work in python 2.x, +from __future__ import print_function +import os +import argparse +import shlex +import sys +import warnings +import copy +import imp +import ast + +nodes = imp.load_source('', 'steps/nnet3/components.py') +import libs.common as common_lib + +def GetArgs(): + # we add compulsary arguments as named arguments for readability + parser = argparse.ArgumentParser(description="Writes config files and variables " + "for TDNNs creation and training", + epilog="See steps/nnet3/tdnn/train.sh for example.") + + # Only one of these arguments can be specified, and one of them has to + # be compulsarily specified + feat_group = parser.add_mutually_exclusive_group(required = True) + feat_group.add_argument("--feat-dim", type=int, + help="Raw feature dimension, e.g. 13") + feat_group.add_argument("--feat-dir", type=str, + help="Feature directory, from which we derive the feat-dim") + + # only one of these arguments can be specified + ivector_group = parser.add_mutually_exclusive_group(required = False) + ivector_group.add_argument("--ivector-dim", type=int, + help="iVector dimension, e.g. 100", default=0) + ivector_group.add_argument("--ivector-dir", type=str, + help="iVector dir, which will be used to derive the ivector-dim ", default=None) + + num_target_group = parser.add_mutually_exclusive_group(required = True) + num_target_group.add_argument("--num-targets", type=int, + help="number of network targets (e.g. 
num-pdf-ids/num-leaves)") + num_target_group.add_argument("--ali-dir", type=str, + help="alignment directory, from which we derive the num-targets") + num_target_group.add_argument("--tree-dir", type=str, + help="directory with final.mdl, from which we derive the num-targets") + num_target_group.add_argument("--output-node-parameters", type=str, action='append', + dest='output_node_para_array', + help = "Define output nodes' and their parameters like output-suffix, dim, objective-type etc") + # CNN options + parser.add_argument('--cnn.layer', type=str, action='append', dest = "cnn_layer", + help="CNN parameters at each CNN layer, e.g. --filt-x-dim=3 --filt-y-dim=8 " + "--filt-x-step=1 --filt-y-step=1 --num-filters=256 --pool-x-size=1 --pool-y-size=3 " + "--pool-z-size=1 --pool-x-step=1 --pool-y-step=3 --pool-z-step=1, " + "when CNN layers are used, no LDA will be added", default = None) + parser.add_argument("--cnn.bottleneck-dim", type=int, dest = "cnn_bottleneck_dim", + help="Output dimension of the linear layer at the CNN output " + "for dimension reduction, e.g. 256." + "The default zero means this layer is not needed.", default=0) + + # General neural network options + parser.add_argument("--splice-indexes", type=str, required = True, + help="Splice indexes at each layer, e.g. '-3,-2,-1,0,1,2,3' " + "If CNN layers are used the first set of splice indexes will be used as input " + "to the first CNN layer and later splice indexes will be interpreted as indexes " + "for the TDNNs.") + parser.add_argument("--add-lda", type=str, action=common_lib.StrToBoolAction, + help="If \"true\" an LDA matrix computed from the input features " + "(spliced according to the first set of splice-indexes) will be used as " + "the first Affine layer. This affine layer's parameters are fixed during training. 
" + "This variable needs to be set to \"false\" when using dense-targets.\n" + "If --cnn.layer is specified this option will be forced to \"false\".", + default=True, choices = ["false", "true"]) + + parser.add_argument("--include-log-softmax", type=str, action=common_lib.StrToBoolAction, + help="add the final softmax layer ", default=True, choices = ["false", "true"]) + parser.add_argument("--add-final-sigmoid", type=str, action=common_lib.StrToBoolAction, + help="add a final sigmoid layer as alternate to log-softmax-layer. " + "Can only be used if include-log-softmax is false. " + "This is useful in cases where you want the output to be " + "like probabilities between 0 and 1. Typically the nnet " + "is trained with an objective such as quadratic", + default=False, choices = ["false", "true"]) + + parser.add_argument("--objective-type", type=str, + help = "the type of objective; i.e. quadratic or linear", + default="linear", choices = ["linear", "quadratic"]) + parser.add_argument("--xent-regularize", type=float, + help="For chain models, if nonzero, add a separate output for cross-entropy " + "regularization (with learning-rate-factor equal to the inverse of this)", + default=0.0) + parser.add_argument("--final-layer-normalize-target", type=float, + help="RMS target for final layer (set to <1 if final layer learns too fast", + default=1.0) + parser.add_argument("--subset-dim", type=int, default=0, + help="dimension of the subset of units to be sent to the central frame") + parser.add_argument("--pnorm-input-dim", type=int, + help="input dimension to p-norm nonlinearities") + parser.add_argument("--pnorm-output-dim", type=int, + help="output dimension of p-norm nonlinearities") + relu_dim_group = parser.add_mutually_exclusive_group(required = False) + relu_dim_group.add_argument("--relu-dim", type=int, + help="dimension of all ReLU nonlinearity layers") + relu_dim_group.add_argument("--relu-dim-final", type=int, + help="dimension of the last ReLU nonlinearity 
layer. Dimensions increase geometrically from the first through the last ReLU layer.", default=None) + parser.add_argument("--relu-dim-init", type=int, + help="dimension of the first ReLU nonlinearity layer. Dimensions increase geometrically from the first through the last ReLU layer.", default=None) + + parser.add_argument("--self-repair-scale-nonlinearity", type=float, + help="A non-zero value activates the self-repair mechanism in the sigmoid and tanh non-linearities of the LSTM", default=None) + + + parser.add_argument("--use-presoftmax-prior-scale", type=str, action=common_lib.StrToBoolAction, + help="if true, a presoftmax-prior-scale is added", + choices=['true', 'false'], default = True) + + # Options to convert input MFCC into Fbank features. This is useful when a + # LDA layer is not added (such as when using dense targets) + parser.add_argument("--cnn.cepstral-lifter", type=float, dest = "cepstral_lifter", + help="The factor used for determining the liftering vector in the production of MFCC. " + "User has to ensure that it matches the lifter used in MFCC generation, " + "e.g. 22.0", default=22.0) + + parser.add_argument("config_dir", + help="Directory to write config files and variables") + + print(' '.join(sys.argv)) + + args = parser.parse_args() + args = CheckArgs(args) + + return args + +def CheckArgs(args): + if not os.path.exists(args.config_dir): + os.makedirs(args.config_dir) + + ## Check arguments. 
+ if args.feat_dir is not None: + args.feat_dim = common_lib.get_feat_dim(args.feat_dir) + + if args.ivector_dir is not None: + args.ivector_dim = common_lib.get_ivector_dim(args.ivector_dir) + + if not args.feat_dim > 0: + raise Exception("feat-dim has to be postive") + + if len(args.output_node_para_array) == 0: + if args.ali_dir is not None: + args.num_targets = common_lib.get_number_of_leaves_from_tree(args.ali_dir) + elif args.tree_dir is not None: + args.num_targets = common_lib.get_number_of_leaves_from_tree(args.tree_dir) + if not args.num_targets > 0: + print(args.num_targets) + raise Exception("num_targets has to be positive") + args.output_node_para_array.append( + "--dim={0} --objective-type={1} --include-log-softmax={2} --add-final-sigmoid={3} --xent-regularize={4}".format( + args.num_targets, args.objective_type, + "true" if args.include_log_softmax else "false", + "true" if args.add_final_sigmoid else "false", + args.xent_regularize)) + + if not args.ivector_dim >= 0: + raise Exception("ivector-dim has to be non-negative") + + if (args.subset_dim < 0): + raise Exception("--subset-dim has to be non-negative") + + if not args.relu_dim is None: + if not args.pnorm_input_dim is None or not args.pnorm_output_dim is None or not args.relu_dim_init is None: + raise Exception("--relu-dim argument not compatible with " + "--pnorm-input-dim or --pnorm-output-dim or --relu-dim-init options"); + args.nonlin_input_dim = args.relu_dim + args.nonlin_output_dim = args.relu_dim + args.nonlin_output_dim_final = None + args.nonlin_output_dim_init = None + args.nonlin_type = 'relu' + + elif not args.relu_dim_final is None: + if not args.pnorm_input_dim is None or not args.pnorm_output_dim is None: + raise Exception("--relu-dim-final argument not compatible with " + "--pnorm-input-dim or --pnorm-output-dim options") + if args.relu_dim_init is None: + raise Exception("--relu-dim-init argument should also be provided with --relu-dim-final") + if args.relu_dim_init > 
args.relu_dim_final: + raise Exception("--relu-dim-init has to be no larger than --relu-dim-final") + args.nonlin_input_dim = None + args.nonlin_output_dim = None + args.nonlin_output_dim_final = args.relu_dim_final + args.nonlin_output_dim_init = args.relu_dim_init + args.nonlin_type = 'relu' + + else: + if not args.relu_dim_init is None: + raise Exception("--relu-dim-final argument not compatible with " + "--pnorm-input-dim or --pnorm-output-dim options") + if not args.pnorm_input_dim > 0 or not args.pnorm_output_dim > 0: + raise Exception("--relu-dim not set, so expected --pnorm-input-dim and " + "--pnorm-output-dim to be provided."); + args.nonlin_input_dim = args.pnorm_input_dim + args.nonlin_output_dim = args.pnorm_output_dim + if (args.nonlin_input_dim < args.nonlin_output_dim) or (args.nonlin_input_dim % args.nonlin_output_dim != 0): + raise Exception("Invalid --pnorm-input-dim {0} and --pnorm-output-dim {1}".format(args.nonlin_input_dim, args.nonlin_output_dim)) + args.nonlin_output_dim_final = None + args.nonlin_output_dim_init = None + args.nonlin_type = 'pnorm' + + if args.add_lda and args.cnn_layer is not None: + args.add_lda = False + warnings.warn("--add-lda is set to false as CNN layers are used.") + + return args + +def AddConvMaxpLayer(config_lines, name, input, args): + if '3d-dim' not in input: + raise Exception("The input to AddConvMaxpLayer() needs '3d-dim' parameters.") + + input = nodes.AddConvolutionLayer(config_lines, name, input, + input['3d-dim'][0], input['3d-dim'][1], input['3d-dim'][2], + args.filt_x_dim, args.filt_y_dim, + args.filt_x_step, args.filt_y_step, + args.num_filters, input['vectorization']) + + if args.pool_x_size > 1 or args.pool_y_size > 1 or args.pool_z_size > 1: + input = nodes.AddMaxpoolingLayer(config_lines, name, input, + input['3d-dim'][0], input['3d-dim'][1], input['3d-dim'][2], + args.pool_x_size, args.pool_y_size, args.pool_z_size, + args.pool_x_step, args.pool_y_step, args.pool_z_step) + + return input + +# The 
ivectors are processed through an affine layer parallel to the CNN layers, +# then concatenated with the CNN output and passed to the deeper part of the network. +def AddCnnLayers(config_lines, cnn_layer, cnn_bottleneck_dim, cepstral_lifter, config_dir, feat_dim, splice_indexes=[0], ivector_dim=0): + cnn_args = ParseCnnString(cnn_layer) + num_cnn_layers = len(cnn_args) + # We use an Idct layer here to convert MFCC to FBANK features + common_lib.write_idct_matrix(feat_dim, cepstral_lifter, config_dir.strip() + "/idct.mat") + prev_layer_output = {'descriptor': "input", + 'dimension': feat_dim} + prev_layer_output = nodes.AddFixedAffineLayer(config_lines, "Idct", prev_layer_output, config_dir.strip() + '/idct.mat') + + list = [('Offset({0}, {1})'.format(prev_layer_output['descriptor'],n) if n != 0 else prev_layer_output['descriptor']) for n in splice_indexes] + splice_descriptor = "Append({0})".format(", ".join(list)) + cnn_input_dim = len(splice_indexes) * feat_dim + prev_layer_output = {'descriptor': splice_descriptor, + 'dimension': cnn_input_dim, + '3d-dim': [len(splice_indexes), feat_dim, 1], + 'vectorization': 'yzx'} + + for cl in range(0, num_cnn_layers): + prev_layer_output = AddConvMaxpLayer(config_lines, "L{0}".format(cl), prev_layer_output, cnn_args[cl]) + + if cnn_bottleneck_dim > 0: + prev_layer_output = nodes.AddAffineLayer(config_lines, "cnn-bottleneck", prev_layer_output, cnn_bottleneck_dim, "") + + if ivector_dim > 0: + iv_layer_output = {'descriptor': 'ReplaceIndex(ivector, t, 0)', + 'dimension': ivector_dim} + iv_layer_output = nodes.AddAffineLayer(config_lines, "ivector", iv_layer_output, ivector_dim, "") + prev_layer_output['descriptor'] = 'Append({0}, {1})'.format(prev_layer_output['descriptor'], iv_layer_output['descriptor']) + prev_layer_output['dimension'] = prev_layer_output['dimension'] + iv_layer_output['dimension'] + + return prev_layer_output + +def PrintConfig(file_name, config_lines): + f = open(file_name, 'w') + 
f.write("\n".join(config_lines['components'])+"\n") + f.write("\n#Component nodes\n") + f.write("\n".join(config_lines['component-nodes'])+"\n") + f.close() + +def ParseCnnString(cnn_param_string_list): + cnn_parser = argparse.ArgumentParser(description="cnn argument parser") + + cnn_parser.add_argument("--filt-x-dim", required=True, type=int) + cnn_parser.add_argument("--filt-y-dim", required=True, type=int) + cnn_parser.add_argument("--filt-x-step", type=int, default = 1) + cnn_parser.add_argument("--filt-y-step", type=int, default = 1) + cnn_parser.add_argument("--num-filters", required=True, type=int) + cnn_parser.add_argument("--pool-x-size", type=int, default = 1) + cnn_parser.add_argument("--pool-y-size", type=int, default = 1) + cnn_parser.add_argument("--pool-z-size", type=int, default = 1) + cnn_parser.add_argument("--pool-x-step", type=int, default = 1) + cnn_parser.add_argument("--pool-y-step", type=int, default = 1) + cnn_parser.add_argument("--pool-z-step", type=int, default = 1) + + cnn_args = [] + for cl in range(0, len(cnn_param_string_list)): + cnn_args.append(cnn_parser.parse_args(shlex.split(cnn_param_string_list[cl]))) + + return cnn_args + +def ParseSpliceString(splice_indexes): + splice_array = [] + left_context = 0 + right_context = 0 + split_on_spaces = splice_indexes.split(); # we already checked the string is nonempty. + if len(split_on_spaces) < 1: + raise Exception("invalid splice-indexes argument, too short: " + + splice_indexes) + try: + for string in split_on_spaces: + this_splices = string.split(",") + if len(this_splices) < 1: + raise Exception("invalid splice-indexes argument, too-short element: " + + splice_indexes) + # the rest of this block updates left_context and right_context, and + # does some checking. 
+ leftmost_splice = 10000 + rightmost_splice = -10000 + + int_list = [] + for s in this_splices: + try: + n = int(s) + if n < leftmost_splice: + leftmost_splice = n + if n > rightmost_splice: + rightmost_splice = n + int_list.append(n) + except ValueError: + #if len(splice_array) == 0: + # raise Exception("First dimension of splicing array must not have averaging [yet]") + try: + x = nodes.StatisticsConfig(s, { 'dimension':100, + 'descriptor': 'foo'} ) + int_list.append(s) + except Exception as e: + raise Exception("The following element of the splicing array is not a valid specifier " + "of statistics: {0}\nGot {1}".format(s, str(e))) + splice_array.append(int_list) + + if leftmost_splice == 10000 or rightmost_splice == -10000: + raise Exception("invalid element of --splice-indexes: " + string) + left_context += -leftmost_splice + right_context += rightmost_splice + except ValueError as e: + raise Exception("invalid --splice-indexes argument " + args.splice_indexes + " " + str(e)) + + left_context = max(0, left_context) + right_context = max(0, right_context) + + return {'left_context':left_context, + 'right_context':right_context, + 'splice_indexes':splice_array, + 'num_hidden_layers':len(splice_array) + } + +def AddPriorsAccumulator(config_lines, name, input): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + components.append("component name={0}_softmax type=SoftmaxComponent dim={1}".format(name, input['dimension'])) + component_nodes.append("component-node name={0}_softmax component={0}_softmax input={1}".format(name, input['descriptor'])) + + return {'descriptor': '{0}_softmax'.format(name), + 'dimension': input['dimension']} + +def AddFinalLayer(config_lines, input, output_dim, + ng_affine_options = " param-stddev=0 bias-stddev=0 ", + label_delay=None, + use_presoftmax_prior_scale = False, + prior_scale_file = None, + include_log_softmax = True, + add_final_sigmoid = False, + name_affix = None, + objective_type 
= "linear", + objective_scale = 1.0, + objective_scales_vec = None): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + if name_affix is not None: + final_node_prefix = 'Final-' + str(name_affix) + else: + final_node_prefix = 'Final' + + prev_layer_output = nodes.AddAffineLayer(config_lines, + final_node_prefix , input, output_dim, + ng_affine_options) + if include_log_softmax: + if use_presoftmax_prior_scale : + components.append('component name={0}-fixed-scale type=FixedScaleComponent scales={1}'.format(final_node_prefix, prior_scale_file)) + component_nodes.append('component-node name={0}-fixed-scale component={0}-fixed-scale input={1}'.format(final_node_prefix, + prev_layer_output['descriptor'])) + prev_layer_output['descriptor'] = "{0}-fixed-scale".format(final_node_prefix) + prev_layer_output = nodes.AddSoftmaxLayer(config_lines, final_node_prefix, prev_layer_output) + + elif add_final_sigmoid: + # Useful when you need the final outputs to be probabilities + # between 0 and 1. 
+ # Usually used with an objective-type such as "quadratic" + prev_layer_output = nodes.AddSigmoidLayer(config_lines, final_node_prefix, prev_layer_output) + + # we use the same name_affix as a prefix in for affine/scale nodes but as a + # suffix for output node + if (objective_scale != 1.0 or objective_scales_vec is not None): + prev_layer_output = nodes.AddGradientScaleLayer(config_lines, final_node_prefix, prev_layer_output, objective_scale, objective_scales_vec) + + nodes.AddOutputLayer(config_lines, prev_layer_output, label_delay, suffix = name_affix, objective_type = objective_type) + +def AddOutputLayers(config_lines, prev_layer_output, output_nodes, + ng_affine_options = "", label_delay = 0): + + for o in output_nodes: + # make the intermediate config file for layerwise discriminative + # training + AddFinalLayer(config_lines, prev_layer_output, o.dim, + ng_affine_options, label_delay = label_delay, + include_log_softmax = o.include_log_softmax, + add_final_sigmoid = o.add_final_sigmoid, + objective_type = o.objective_type, + name_affix = o.output_suffix) + + if o.xent_regularize != 0.0: + nodes.AddFinalLayer(config_lines, prev_layer_output, o.dim, + include_log_softmax = True, + label_delay = label_delay, + name_affix = o.output_suffix + '_xent') + +# The function signature of MakeConfigs is changed frequently as it is intended for local use in this script. 
+def MakeConfigs(config_dir, splice_indexes_string, + cnn_layer, cnn_bottleneck_dim, cepstral_lifter, + feat_dim, ivector_dim, add_lda, + nonlin_type, nonlin_input_dim, nonlin_output_dim, subset_dim, + nonlin_output_dim_init, nonlin_output_dim_final, + use_presoftmax_prior_scale, final_layer_normalize_target, + output_nodes, self_repair_scale): + + parsed_splice_output = ParseSpliceString(splice_indexes_string.strip()) + + left_context = parsed_splice_output['left_context'] + right_context = parsed_splice_output['right_context'] + num_hidden_layers = parsed_splice_output['num_hidden_layers'] + splice_indexes = parsed_splice_output['splice_indexes'] + input_dim = len(parsed_splice_output['splice_indexes'][0]) + feat_dim + ivector_dim + + prior_scale_file = '{0}/presoftmax_prior_scale.vec'.format(config_dir) + + config_lines = {'components':[], 'component-nodes':[]} + + config_files={} + prev_layer_output = nodes.AddInputLayer(config_lines, feat_dim, splice_indexes[0], + ivector_dim) + + # Add the init config lines for estimating the preconditioning matrices + init_config_lines = copy.deepcopy(config_lines) + init_config_lines['components'].insert(0, '# Config file for initializing neural network prior to') + init_config_lines['components'].insert(0, '# preconditioning matrix computation') + + for o in output_nodes: + nodes.AddOutputLayer(init_config_lines, prev_layer_output, + objective_type = o.objective_type, suffix = o.output_suffix) + + config_files[config_dir + '/init.config'] = init_config_lines + + if cnn_layer is not None: + prev_layer_output = AddCnnLayers(config_lines, cnn_layer, cnn_bottleneck_dim, cepstral_lifter, config_dir, + feat_dim, splice_indexes[0], ivector_dim) + + # add_lda needs to be set "false" when using dense targets, + # or if the task is not a simple classification task + # (e.g. 
regression, multi-task) + if add_lda: + prev_layer_output = nodes.AddLdaLayer(config_lines, "L0", prev_layer_output, config_dir + '/lda.mat') + + left_context = 0 + right_context = 0 + # we moved the first splice layer to before the LDA.. + # so the input to the first affine layer is going to [0] index + splice_indexes[0] = [0] + + if not nonlin_output_dim is None: + nonlin_output_dims = [nonlin_output_dim] * num_hidden_layers + elif nonlin_output_dim_init < nonlin_output_dim_final and num_hidden_layers == 1: + raise Exception("num-hidden-layers has to be greater than 1 if relu-dim-init and relu-dim-final is different.") + else: + # computes relu-dim for each hidden layer. They increase geometrically across layers + factor = pow(float(nonlin_output_dim_final) / nonlin_output_dim_init, 1.0 / (num_hidden_layers - 1)) if num_hidden_layers > 1 else 1 + nonlin_output_dims = [int(round(nonlin_output_dim_init * pow(factor, i))) for i in range(0, num_hidden_layers)] + assert(nonlin_output_dims[-1] >= nonlin_output_dim_final - 1 and nonlin_output_dims[-1] <= nonlin_output_dim_final + 1) # due to rounding error + nonlin_output_dims[-1] = nonlin_output_dim_final # It ensures that the dim of the last hidden layer is exactly the same as what is specified + + for i in range(0, num_hidden_layers): + # make the intermediate config file for layerwise discriminative training + + # prepare the spliced input + if not (len(splice_indexes[i]) == 1 and splice_indexes[i][0] == 0): + try: + zero_index = splice_indexes[i].index(0) + except ValueError: + zero_index = None + # I just assume the prev_layer_output_descriptor is a simple forwarding descriptor + prev_layer_output_descriptor = prev_layer_output['descriptor'] + subset_output = prev_layer_output + if subset_dim > 0: + # if subset_dim is specified the script expects a zero in the splice indexes + assert(zero_index is not None) + subset_node_config = ("dim-range-node name=Tdnn_input_{0} " + "input-node={1} dim-offset={2} 
dim={3}".format( + i, prev_layer_output_descriptor, 0, subset_dim)) + subset_output = {'descriptor' : 'Tdnn_input_{0}'.format(i), + 'dimension' : subset_dim} + config_lines['component-nodes'].append(subset_node_config) + appended_descriptors = [] + appended_dimension = 0 + for j in range(len(splice_indexes[i])): + if j == zero_index: + appended_descriptors.append(prev_layer_output['descriptor']) + appended_dimension += prev_layer_output['dimension'] + continue + try: + offset = int(splice_indexes[i][j]) + # it's an integer offset. + appended_descriptors.append('Offset({0}, {1})'.format( + subset_output['descriptor'], splice_indexes[i][j])) + appended_dimension += subset_output['dimension'] + except ValueError: + # it's not an integer offset, so assume it specifies the + # statistics-extraction. + stats = nodes.StatisticsConfig(splice_indexes[i][j], prev_layer_output) + stats_layer = stats.AddLayer(config_lines, "Tdnn_stats_{0}".format(i)) + appended_descriptors.append(stats_layer['descriptor']) + appended_dimension += stats_layer['dimension'] + + prev_layer_output = {'descriptor' : "Append({0})".format(" , ".join(appended_descriptors)), + 'dimension' : appended_dimension} + else: + # this is a normal affine node + pass + + if nonlin_type == "relu": + prev_layer_output = nodes.AddAffRelNormLayer(config_lines, "Tdnn_{0}".format(i), + prev_layer_output, nonlin_output_dims[i], + self_repair_scale=self_repair_scale, + norm_target_rms=1.0 if i < num_hidden_layers -1 else final_layer_normalize_target) + elif nonlin_type == "pnorm": + prev_layer_output = nodes.AddAffPnormLayer(config_lines, "Tdnn_{0}".format(i), + prev_layer_output, nonlin_input_dim, nonlin_output_dim, + norm_target_rms=1.0 if i < num_hidden_layers -1 else final_layer_normalize_target) + else: + raise Exception("Unknown nonlinearity type") + # a final layer is added after each new layer as we are generating + # configs for layer-wise discriminative training + + AddOutputLayers(config_lines, 
prev_layer_output, output_nodes) + + config_files['{0}/layer{1}.config'.format(config_dir, i + 1)] = config_lines + config_lines = {'components':[], 'component-nodes':[]} + + left_context += int(parsed_splice_output['left_context']) + right_context += int(parsed_splice_output['right_context']) + + # write the files used by other scripts like steps/nnet3/get_egs.sh + f = open(config_dir + "/vars", "w") + print('model_left_context=' + str(left_context), file=f) + print('model_right_context=' + str(right_context), file=f) + print('num_hidden_layers=' + str(num_hidden_layers), file=f) + print('add_lda=' + ('true' if add_lda else 'false'), file=f) + f.close() + + # printing out the configs + # init.config used to train lda-mllt train + for key in config_files.keys(): + PrintConfig(key, config_files[key]) + +def ParseOutputNodesParameters(para_array): + output_parser = argparse.ArgumentParser() + output_parser.add_argument('--output-suffix', type=str, action=common_lib.NullstrToNoneAction, + help = "Name of the output node. e.g. output-xent") + output_parser.add_argument('--dim', type=int, required=True, + help = "Dimension of the output node") + output_parser.add_argument("--include-log-softmax", type=str, action=common_lib.StrToBoolAction, + help="add the final softmax layer ", + default=True, choices = ["false", "true"]) + output_parser.add_argument("--add-final-sigmoid", type=str, action=common_lib.StrToBoolAction, + help="add a sigmoid layer as the final layer. Applicable only if skip-final-softmax is true.", + choices=['true', 'false'], default = False) + output_parser.add_argument("--objective-type", type=str, default="linear", + choices = ["linear", "quadratic","xent-per-dim"], + help = "the type of objective; i.e. 
quadratic or linear") + output_parser.add_argument("--xent-regularize", type=float, + help="For chain models, if nonzero, add a separate output for cross-entropy " + "regularization (with learning-rate-factor equal to the inverse of this)", + default=0.0) + + output_nodes = [ output_parser.parse_args(shlex.split(x)) for x in para_array ] + + return output_nodes + +def Main(): + args = GetArgs() + + output_nodes = ParseOutputNodesParameters(args.output_node_para_array) + + MakeConfigs(config_dir = args.config_dir, + feat_dim = args.feat_dim, ivector_dim = args.ivector_dim, + add_lda = args.add_lda, + cepstral_lifter = args.cepstral_lifter, + splice_indexes_string = args.splice_indexes, + cnn_layer = args.cnn_layer, + cnn_bottleneck_dim = args.cnn_bottleneck_dim, + nonlin_type = args.nonlin_type, + nonlin_input_dim = args.nonlin_input_dim, + nonlin_output_dim = args.nonlin_output_dim, + subset_dim = args.subset_dim, + nonlin_output_dim_init = args.nonlin_output_dim_init, + nonlin_output_dim_final = args.nonlin_output_dim_final, + use_presoftmax_prior_scale = args.use_presoftmax_prior_scale, + final_layer_normalize_target = args.final_layer_normalize_target, + output_nodes = output_nodes, + self_repair_scale = args.self_repair_scale_nonlinearity) + +if __name__ == "__main__": + Main() + + diff --git a/egs/aspire/s5/local/segmentation/prepare_ami.sh b/egs/aspire/s5/local/segmentation/prepare_ami.sh new file mode 100755 index 00000000000..7147a3004cb --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_ami.sh @@ -0,0 +1,223 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +. cmd.sh +. path.sh + +set -e +set -o pipefail +set -u + +stage=-1 + +dataset=dev +nj=18 + +. 
utils/parse_options.sh + +export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH + +src_dir=/export/a09/vmanoha1/workspace_asr_diarization/egs/ami/s5b # AMI src_dir +dir=exp/sad_ami_sdm1_${dataset}/ref + +mkdir -p $dir + +# Expecting user to have done run.sh to run the AMI recipe in $src_dir for +# both sdm and ihm microphone conditions + +if [ $stage -le 1 ]; then + ( + cd $src_dir + local/prepare_parallel_train_data.sh --train-set ${dataset} sdm1 + + awk '{print $1" "$2}' $src_dir/data/ihm/${dataset}/segments > \ + $src_dir/data/ihm/${dataset}/utt2reco + awk '{print $1" "$2}' $src_dir/data/sdm1/${dataset}/segments > \ + $src_dir/data/sdm1/${dataset}/utt2reco + + cat $src_dir/data/sdm1/${dataset}_ihmdata/ihmutt2utt | \ + utils/filter_scp.pl -f 1 $src_dir/data/ihm/${dataset}/utt2reco | \ + utils/apply_map.pl -f 1 $src_dir/data/ihm/${dataset}/utt2reco | \ + utils/filter_scp.pl -f 2 $src_dir/data/sdm1/${dataset}/utt2reco | \ + utils/apply_map.pl -f 2 $src_dir/data/sdm1/${dataset}/utt2reco | \ + sort -u > $src_dir/data/sdm1/${dataset}_ihmdata/ihm2sdm_reco + ) +fi + +[ ! -s $src_dir/data/sdm1/${dataset}_ihmdata/ihm2sdm_reco ] && echo "Empty $src_dir/data/sdm1/${dataset}_ihmdata/ihm2sdm_reco!" 
&& exit 1 + +phone_map=$dir/phone_map +if [ $stage -le 2 ]; then + ( + cd $src_dir + utils/data/get_reco2utt.sh $src_dir/data/sdm1/${dataset} + + steps/make_mfcc.sh --nj $nj --cmd "$train_cmd" \ + data/sdm1/${dataset}_ihmdata exp/sdm1/make_mfcc mfcc_sdm1 + steps/compute_cmvn_stats.sh \ + data/sdm1/${dataset}_ihmdata exp/sdm1/make_mfcc mfcc_sdm1 + utils/fix_data_dir.sh data/sdm1/${dataset}_ihmdata + ) + + steps/segmentation/get_sad_map.py \ + $src_dir/data/lang | utils/sym2int.pl -f 1 $src_dir/data/lang/phones.txt > \ + $phone_map +fi + +if [ $stage -le 3 ]; then + # Expecting user to have run local/run_cleanup_segmentation.sh in $src_dir + ( + cd $src_dir + steps/align_fmllr.sh --nj $nj --cmd "$train_cmd" \ + data/sdm1/${dataset}_ihmdata data/lang \ + exp/ihm/tri3_cleaned \ + exp/sdm1/tri3_cleaned_${dataset}_ihmdata + ) +fi + +if [ $stage -le 4 ]; then + steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$train_cmd" \ + $src_dir/exp/sdm1/tri3_cleaned_${dataset}_ihmdata $phone_map $dir +fi + +echo "A 1" > $dir/channel_map +cat $src_dir/data/sdm1/${dataset}/reco2file_and_channel | \ + utils/apply_map.pl -f 3 $dir/channel_map > $dir/reco2file_and_channel + +utils/data/get_reco2utt.sh $src_dir/data/sdm1/${dataset}_ihmdata +cat $src_dir/data/sdm1/${dataset}_ihmdata/reco2utt | \ + awk 'BEGIN{i=1} {print $1" "i; i++;}' > \ + $src_dir/data/sdm1/${dataset}_ihmdata/reco.txt + +if [ $stage -le 5 ]; then + # Reference RTTM where SPEECH frames are obtainted by combining IHM VAD alignments + cat $src_dir/data/sdm1/${dataset}_ihmdata/reco.txt | \ + awk '{print $1" 1:"$2" 10000:10000 0:0"}' > $dir/ref_spk2label_map + + $train_cmd $dir/log/get_ref_spk_seg.log \ + segmentation-combine-segments --include-missing-utt-level-segmentations scp:$dir/sad_seg.scp \ + "ark:segmentation-init-from-segments --segment-label=10000 --shift-to-zero=false $src_dir/data/sdm1/${dataset}_ihmdata/segments ark:- |" \ + ark,t:$src_dir/data/sdm1/${dataset}_ihmdata/reco2utt ark:- \| \ + 
segmentation-copy --utt2label-map-rspecifier=ark,t:$dir/ref_spk2label_map \ + ark:- ark:- \| \ + segmentation-merge-recordings \ + "ark,t:utils/utt2spk_to_spk2utt.pl $src_dir/data/sdm1/${dataset}_ihmdata/ihm2sdm_reco |" \ + ark:- "ark:| gzip -c > $dir/ref_spk_seg.gz" +fi + +if [ $stage -le 6 ]; then + utils/data/get_reco2num_frames.sh --frame-shift 0.01 --frame-overlap 0.015 \ + --cmd queue.pl --nj $nj \ + $src_dir/data/sdm1/${dataset} + + ## Get a filter that selects only regions within the manual segments. + #$train_cmd $dir/log/get_manual_segments_regions.log \ + # segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/${dataset}/segments ark:- \| \ + # segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/${dataset}/reco2utt ark:- \| \ + # segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ + # "ark:segmentation-init-from-lengths --label=0 ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames ark:- |" ark:- ark,t:- \| \ + # perl -ane '$F[3] = 10000; $F[$#F-1] = 10000; print join(" ", @F) . 
"\n";' \| \ + # segmentation-create-subsegments --filter-label=10000 --subsegment-label=10000 \ + # ark,t:- "ark:gunzip -c $dir/ref_spk_seg.gz |" ark:- \| \ + # segmentation-post-process --merge-labels=0:1 --merge-dst-label=1 ark:- ark:- \| \ + # segmentation-post-process --merge-labels=10000 --merge-dst-label=0 --merge-adjacent-segments \ + # --max-intersegment-length=10000 ark,t:- \ + # "ark:| gzip -c > $dir/manual_segments_regions.seg.gz" +fi + +if [ $stage -le 7 ]; then + $train_cmd $dir/log/get_overlap_sad_seg.log \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz |" \ + ark:/dev/null ark:/dev/null ark:- \| \ + classes-per-frame-to-labels --junk-label=10000 ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + "ark:| gzip -c > $dir/overlap_sad_seg.gz" +fi + +if [ $stage -le 8 ]; then + # To get the actual RTTM, we need to add no-score + $train_cmd $dir/log/get_ref_rttm.log \ + gunzip -c $dir/overlap_sad_seg.gz \| \ + segmentation-post-process --merge-labels=1:2 --merge-dst-label=1 \ + ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 ark:- $dir/ref.rttm + + # Get RTTM for overlapped speech detection with 3 classes + # 0 -> SILENCE, 1 -> SINGLE_SPEAKER, 2 -> OVERLAP + $train_cmd $dir/log/get_ref_rttm.log \ + gunzip -c $dir/overlap_sad_seg.gz \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 --map-to-speech-and-sil=false ark:- $dir/overlapping_speech_ref.rttm +fi + + +#if [ $stage -le 8 ]; then +# # Get RTTM for overlapped speech detection with 3 classes +# # 0 -> SILENCE, 1 -> SINGLE_SPEAKER, 2 -> OVERLAP +# $train_cmd $dir/log/get_overlapping_rttm.log \ +# segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ +# "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0:10000 ark:- 
ark:- |" \ +# ark:/dev/null ark:- \| \ +# segmentation-init-from-ali ark:- ark:- \| \ +# segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \ +# --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ +# segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ +# ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ +# segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ +# segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \ +# --no-score-label=10000 ark:- $dir/overlapping_speech_ref.rttm +#fi + +# make $dir an absolute pathname. +dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $dir ${PWD}` + +if [ $stage -le 9 ]; then + # Get a filter that selects only regions of speech + $train_cmd $dir/log/get_speech_filter.log \ + gunzip -c $dir/overlap_sad_seg.gz \| \ + segmentation-post-process --merge-labels=1:2 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-post-process --remove-labels=10000 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ + ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| \ + copy-vector ark,t: ark,scp:$dir/deriv_weights_for_overlapping_sad.ark,$dir/deriv_weights_for_overlapping_sad.scp + + # Get deriv weights + $train_cmd $dir/log/get_speech_filter.log \ + gunzip -c $dir/overlap_sad_seg.gz \| \ + segmentation-post-process --merge-labels=0:1:2 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-post-process --remove-labels=10000 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ + ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| \ + copy-vector ark,t: ark,scp:$dir/deriv_weights.ark,$dir/deriv_weights.scp +fi + +if [ $stage -le 10 ]; then + $train_cmd 
$dir/log/get_overlapping_sad.log \ + gunzip -c $dir/overlap_sad_seg.gz \| \ + segmentation-post-process --remove-labels=10000 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:$src_dir/data/sdm1/${dataset}/reco2num_frames \ + ark:- ark,scp:$dir/overlapping_sad_labels.ark,$dir/overlapping_sad_labels.scp +fi + +if false && [ $stage -le 11 ]; then + utils/data/convert_data_dir_to_whole.sh \ + $src_dir/data/sdm1/${dataset} data/ami_sdm1_${dataset}_whole + utils/fix_data_dir.sh \ + data/ami_sdm1_${dataset}_whole + utils/copy_data_dir.sh \ + data/ami_sdm1_${dataset}_whole data/ami_sdm1_${dataset}_whole_hires_bp + utils/data/downsample_data_dir.sh 8000 data/ami_sdm1_${dataset}_whole_hires_bp + + steps/make_mfcc.sh --mfcc-config conf/mfcc_hires_bp.conf --nj $nj \ + data/ami_sdm1_${dataset}_whole_hires_bp exp/make_hires_bp mfcc_hires_bp + steps/compute_cmvn_stats.sh --fake \ + data/ami_sdm1_${dataset}_whole_hires_bp exp/make_hires_bp mfcc_hires_bp + utils/fix_data_dir.sh \ + data/ami_sdm1_${dataset}_whole_hires_bp +fi diff --git a/egs/aspire/s5/local/segmentation/prepare_babel_data.sh b/egs/aspire/s5/local/segmentation/prepare_babel_data.sh new file mode 100644 index 00000000000..e70dc216980 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_babel_data.sh @@ -0,0 +1,105 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +# This script prepares Babel data for training speech activity detection, +# music detection. + +. path.sh +. cmd.sh + +set -e +set -o pipefail +set -u + +lang_id=assamese +subset= # Number of recordings to keep before speed perturbation and corruption. + # In limitedLP, this is about 120. So subset, if specified, must be lower that that. + +# All the paths below can be modified to any absolute path. +ROOT_DIR=/home/vimal/workspace_waveform/egs/babel/s5c_assamese/ + +stage=-1 + +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + echo "This script is to serve as an example recipe." 
+ echo "Edit the script to change variables if needed." + exit 1 +fi + +dir=exp/unsad/make_unsad_babel_${lang_id}_train # Work dir + +model_dir=$ROOT_DIR/exp/tri4 # Model directory used for decoding +sat_model_dir=$ROOT_DIR/exp/tri5 # Model directory used for getting alignments +lang=$ROOT_DIR/data/lang # Language directory +lang_test=$ROOT_DIR/data/lang # Language directory used to build graph + +mkdir -p $dir + +# Hard code the mapping from phones to SAD labels +# 0 for silence, 1 for speech, 2 for noise, 3 for unk +cat < $dir/babel_sad.map + 3 +_B 3 +_E 3 +_I 3 +_S 3 + 2 +_B 2 +_E 2 +_I 2 +_S 2 + 2 +_B 2 +_E 2 +_I 2 +_S 2 +SIL 0 +SIL_B 0 +SIL_E 0 +SIL_I 0 +SIL_S 0 +EOF + +# The original data directory which will be converted to a whole (recording-level) directory. +utils/copy_data_dir.sh $ROOT_DIR/data/train data/babel_${lang_id}_train +train_data_dir=data/babel_${lang_id}_train + +# Expecting the user to have done run.sh to have $model_dir, +# $sat_model_dir, $lang, $lang_test, $train_data_dir +local/segmentation/prepare_unsad_data.sh \ + --sad-map $dir/babel_sad.map \ + --config-dir $ROOT_DIR/conf --feat-type plp --add-pitch true \ + --reco-nj 40 --nj 100 --cmd "$train_cmd" \ + --sat-model-dir $sat_model_dir \ + --lang-test $lang_test \ + $train_data_dir $lang $model_dir $dir + +orig_data_dir=${train_data_dir}_sp + +data_dir=${train_data_dir}_whole + +if [ ! 
-z $subset ]; then + # Work on a subset + utils/subset_data_dir.sh ${data_dir} $subset \ + ${data_dir}_$subset + data_dir=${data_dir}_$subset +fi + +reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp + +# Add noise from MUSAN corpus to data directory and create a new data directory +local/segmentation/do_corruption_data_dir_snr.sh \ + --data-dir $data_dir \ + --reco-vad-dir $reco_vad_dir \ + --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf + +# Add music from MUSAN corpus to data directory and create a new data directory +local/segmentation/do_corruption_data_dir_music.sh \ + --data-dir $data_dir \ + --reco-vad-dir $reco_vad_dir \ + --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf diff --git a/egs/aspire/s5/local/segmentation/prepare_babel_data_overlapped_speech.sh b/egs/aspire/s5/local/segmentation/prepare_babel_data_overlapped_speech.sh new file mode 100644 index 00000000000..a3e087d95ec --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_babel_data_overlapped_speech.sh @@ -0,0 +1,112 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +# This script prepares Babel data for training speech activity detection, +# music detection, and overlapped speech detection systems. + +. path.sh +. cmd.sh + +set -e +set -o pipefail +set -u + +lang_id=assamese +subset=150 # Number of recordings to keep before speed perturbation and corruption +utt_subset=30000 # Number of utterances to keep after speed perturbation for adding overlapped-speech + +# All the paths below can be modified to any absolute path. +ROOT_DIR=/home/vimal/workspace_waveform/egs/babel/s5c_assamese/ + +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + echo "This script is to serve as an example recipe." + echo "Edit the script to change variables if needed." 
+ exit 1 +fi + +dir=exp/unsad/make_unsad_babel_${lang_id}_train # Work dir + +# The original data directory which will be converted to a whole (recording-level) directory. +train_data_dir=$ROOT_DIR/data/train + +model_dir=$ROOT_DIR/exp/tri4 # Model directory used for decoding +sat_model_dir=$ROOT_DIR/exp/tri5 # Model directory used for getting alignments +lang=$ROOT_DIR/data/lang # Language directory +lang_test=$ROOT_DIR/data/lang # Language directory used to build graph + +# Hard code the mapping from phones to SAD labels +# 0 for silence, 1 for speech, 2 for noise, 3 for unk +cat < $dir/babel_sad.map + 3 +_B 3 +_E 3 +_I 3 +_S 3 + 2 +_B 2 +_E 2 +_I 2 +_S 2 + 2 +_B 2 +_E 2 +_I 2 +_S 2 +SIL 0 +SIL_B 0 +SIL_E 0 +SIL_I 0 +SIL_S 0 +EOF + +# Expecting the user to have done run.sh to have $model_dir, +# $sat_model_dir, $lang, $lang_test, $train_data_dir +local/segmentation/prepare_unsad_data.sh \ + --sad-map $dir/babel_sad.map \ + --config-dir $ROOT_DIR/conf \ + --reco-nj 40 --nj 100 --cmd "$train_cmd" \ + --sat-model $sat_model_dir \ + --lang-test $lang_test \ + $train_data_dir $lang $model_dir $dir + +orig_data_dir=${train_data_dir}_sp + +data_dir=${train_data_dir}_whole + +if [ ! -z $subset ]; then + # Work on a subset + utils/subset_data_dir.sh ${data_dir} $subset \ + ${data_dir}_$subset + data_dir=${data_dir}_$subset +fi + +reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp + +# Add noise from MUSAN corpus to data directory and create a new data directory +local/segmentation/do_corruption_data_dir.sh \ + --data-dir $data_dir \ + --reco-vad-dir $reco_vad_dir \ + --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf + +# Add music from MUSAN corpus to data directory and create a new data directory +local/segmentation/do_corruption_data_dir_music.sh \ + --data-dir $data_dir \ + --reco-vad-dir $reco_vad_dir \ + --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf + +if [ ! 
-z $utt_subset ]; then + utils/subset_data_dir.sh ${orig_data_dir} $utt_subset \ + ${orig_data_dir}_`echo $utt_subset | perl -pe 's/000$/k/'` + orig_data_dir=${orig_data_dir}_`echo $utt_subset | perl -pe 's/000$/k/'` +fi + +# Add overlapping speech from $orig_data_dir/segments and create a new data directory +utt_vad_dir=$dir/`basename $sat_model_dir`_ali_`basename $train_data_dir`_sp_vad_`basename $train_data_dir`_sp +local/segmentation/do_corruption_data_dir_overlapped_speech.sh \ + --data-dir ${orig_data_dir} \ + --utt-vad-dir $utt_vad_dir diff --git a/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh b/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh new file mode 100644 index 00000000000..4f55cc6929e --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh @@ -0,0 +1,101 @@ +#! /bin/bash + +# This script prepares Fisher data for training a speech activity detection +# and music detection system + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +. path.sh +. cmd.sh + +set -e -o pipefail + +if [ $# -ne 0 ]; then + echo "Usage: $0" + echo "This script is to serve as an example recipe." + echo "Edit the script to change variables if needed." + exit 1 +fi + +dir=exp/unsad/make_unsad_fisher_train_100k # Work dir +subset=900 + +# All the paths below can be modified to any absolute path. + +# The original data directory which will be converted to a whole (recording-level) directory. 
+train_data_dir=data/fisher_train_100k + +model_dir=exp/tri3a # Model directory used for decoding +sat_model_dir=exp/tri4a # Model directory used for getting alignments +lang=data/lang # Language directory +lang_test=data/lang_test # Language directory used to build graph + +# Hard code the mapping from phones to SAD labels +# 0 for silence, 1 for speech, 2 for noise, 3 for unk +cat < $dir/fisher_sad.map +sil 0 +sil_B 0 +sil_E 0 +sil_I 0 +sil_S 0 +laughter 2 +laughter_B 2 +laughter_E 2 +laughter_I 2 +laughter_S 2 +noise 2 +noise_B 2 +noise_E 2 +noise_I 2 +noise_S 2 +oov 3 +oov_B 3 +oov_E 3 +oov_I 3 +oov_S 3 +EOF + +if [ ! -d RIRS_NOISES/ ]; then + # Prepare MUSAN rirs and noises + wget --no-check-certificate http://www.openslr.org/resources/28/rirs_noises.zip + unzip rirs_noises.zip +fi + +if [ ! -d RIRS_NOISES/music ]; then + # Prepare MUSAN music + local/segmentation/prepare_musan_music.sh /export/corpora/JHU/musan RIRS_NOISES/music +fi + +# Expecting the user to have done run.sh to have $model_dir, +# $sat_model_dir, $lang, $lang_test, $train_data_dir +local/segmentation/prepare_unsad_data.sh \ + --sad-map $dir/fisher_sad.map \ + --config-dir conf \ + --reco-nj 40 --nj 100 --cmd "$train_cmd" \ + --sat-model-dir $sat_model_dir \ + --lang-test $lang_test \ + $train_data_dir $lang $model_dir $dir + +data_dir=${train_data_dir}_whole + +if [ ! 
-z $subset ]; then + # Work on a subset + utils/subset_data_dir.sh ${data_dir} $subset \ + ${data_dir}_$subset + data_dir=${data_dir}_$subset +fi + +reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp + +# Add noise from MUSAN corpus to data directory and create a new data directory +local/segmentation/do_corruption_data_dir_snr.sh \ + --data-dir $data_dir \ + --reco-vad-dir $reco_vad_dir \ + --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf + +# Add music from MUSAN corpus to data directory and create a new data directory +local/segmentation/do_corruption_data_dir_music.sh \ + --data-dir $data_dir \ + --reco-vad-dir $reco_vad_dir \ + --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf diff --git a/egs/aspire/s5/local/segmentation/prepare_fisher_data_overlapped_speech.sh b/egs/aspire/s5/local/segmentation/prepare_fisher_data_overlapped_speech.sh new file mode 100644 index 00000000000..79a03fa9e9d --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_fisher_data_overlapped_speech.sh @@ -0,0 +1,113 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +# This script prepares Fisher data for training speech activity detection, +# music detection, and overlapped speech detection systems. + +. path.sh +. cmd.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + echo "This script is to serve as an example recipe." + echo "Edit the script to change variables if needed." + exit 1 +fi + +dir=exp/unsad/make_unsad_fisher_train_100k # Work dir +subset=60 # Number of recordings to keep before speed perturbation and corruption +utt_subset=75000 # Number of utterances to keep after speed perturbation for adding overlapped-speech + +# All the paths below can be modified to any absolute path. + +# The original data directory which will be converted to a whole (recording-level) directory. 
+train_data_dir=data/fisher_train_100k + +model_dir=exp/tri3a # Model directory used for decoding +sat_model_dir=exp/tri4a # Model directory used for getting alignments +lang=data/lang # Language directory +lang_test=data/lang_test # Language directory used to build graph + +# Hard code the mapping from phones to SAD labels +# 0 for silence, 1 for speech, 2 for noise, 3 for unk +cat <<EOF > $dir/fisher_sad.map +sil 0 +sil_B 0 +sil_E 0 +sil_I 0 +sil_S 0 +laughter 2 +laughter_B 2 +laughter_E 2 +laughter_I 2 +laughter_S 2 +noise 2 +noise_B 2 +noise_E 2 +noise_I 2 +noise_S 2 +oov 3 +oov_B 3 +oov_E 3 +oov_I 3 +oov_S 3 +EOF + +# Expecting the user to have done run.sh to have $model_dir, +# $sat_model_dir, $lang, $lang_test, $train_data_dir +local/segmentation/prepare_unsad_data.sh \ + --sad-map $dir/fisher_sad.map \ + --config-dir conf \ + --reco-nj 40 --nj 100 --cmd "$train_cmd" \ + --sat-model-dir $sat_model_dir \ + --lang-test $lang_test \ + $train_data_dir $lang $model_dir $dir + +orig_data_dir=${train_data_dir}_sp + +data_dir=${train_data_dir}_whole + +if [ ! -z $subset ]; then + # Work on a subset + utils/subset_data_dir.sh ${data_dir} $subset \ + ${data_dir}_$subset + data_dir=${data_dir}_$subset +fi + +reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp + +# Add noise from MUSAN corpus to data directory and create a new data directory +local/segmentation/do_corruption_data_dir.sh \ + --num-data-reps 5 \ + --data-dir $data_dir \ + --reco-vad-dir $reco_vad_dir \ + --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf + +# Add music from MUSAN corpus to data directory and create a new data directory +local/segmentation/do_corruption_data_dir_music.sh \ + --num-data-reps 5 \ + --data-dir $data_dir \ + --reco-vad-dir $reco_vad_dir \ + --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf + +if [ ! 
-z $utt_subset ]; then + utils/subset_data_dir.sh ${orig_data_dir} $utt_subset \ + ${orig_data_dir}_`echo $utt_subset | perl -pe 's/000$/k/'` + orig_data_dir=${orig_data_dir}_`echo $utt_subset | perl -pe 's/000$/k/'` +fi + +# Add overlapping speech from $orig_data_dir/segments and create a new data directory +utt_vad_dir=$dir/`basename $sat_model_dir`_ali_`basename $train_data_dir`_sp_vad_`basename $train_data_dir`_sp +local/segmentation/do_corruption_data_dir_overlapped_speech.sh \ + --nj 40 --cmd queue.pl \ + --num-data-reps 1 \ + --data-dir ${orig_data_dir} \ + --utt-vad-dir $utt_vad_dir + +local/segmentation/prepare_unsad_overlapped_speech_labels.sh \ + --num-data-reps 1 --nj 40 --cmd queue.pl \ + ${orig_data_dir}_ovlp_corrupted_hires_bp \ + ${orig_data_dir}_ovlp_corrupted/overlapped_segments_info.txt \ + $utt_vad_dir exp/make_overlap_labels overlap_labels diff --git a/egs/aspire/s5/local/segmentation/prepare_musan_music.sh b/egs/aspire/s5/local/segmentation/prepare_musan_music.sh new file mode 100644 index 00000000000..16fb946b0c8 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_musan_music.sh @@ -0,0 +1,24 @@ +#! 
/bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +if [ $# -ne 2 ]; then + echo "Usage: $0 " + echo " e.g.: $0 /export/corpora/JHU/musan RIRS_NOISES/music" + exit 1 +fi + +SRC_DIR=$1 +dir=$2 + +mkdir -p $dir + +local/segmentation/make_musan_music.py $SRC_DIR $dir/wav.scp + +wav-to-duration scp:$dir/wav.scp ark,t:$dir/reco2dur +steps/data/split_wavs_randomly.py $dir/wav.scp $dir/reco2dur \ + $dir/split_utt2dur $dir/split_wav.scp + +awk '{print $1" "int($2*100)}' $dir/split_utt2dur > $dir/split_utt2num_frames +steps/data/wav_scp2noise_list.py $dir/split_wav.scp $dir/music_list diff --git a/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh b/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh new file mode 100755 index 00000000000..cccc7e2db84 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh @@ -0,0 +1,518 @@ +#!/bin/bash + +# This script prepares speech labels and deriv weights for +# training unsad network for speech activity detection and music detection. + +set -u +set -o pipefail +set -e + +. path.sh + +stage=-2 +cmd=queue.pl +reco_nj=40 +nj=100 + +# Options to be passed to get_sad_map.py +map_noise_to_sil=true # Map noise phones to silence label (0) +map_unk_to_speech=true # Map unk phones to speech label (1) +sad_map= # Initial mapping from phones to speech/non-speech labels. + # Overrides the default mapping using phones/silence.txt + # and phones/nonsilence.txt + +# Options for feature extraction +feat_type=mfcc # mfcc or plp +add_pitch=false # Add pitch features + +config_dir=conf +feat_config= +pitch_config= + +mfccdir=mfcc +plpdir=plp + +speed_perturb=true + +sat_model_dir= # Model directory used for getting alignments +lang_test= # Language directory used to build graph. + # If its not provided, $lang will be used instead. + +. 
utils/parse_options.sh + +if [ $# -ne 4 ]; then + echo "This script takes a data directory and creates a new data directory " + echo "and speech activity labels" + echo "for the purpose of training a Universal Speech Activity Detector." + echo "Usage: $0 [options] " + echo " e.g.: $0 data/train_100k data/lang exp/tri4a exp/vad_data_prep" + echo "" + echo "Main options (for others, see top of script file)" + echo " --config # config file containing options" + echo " --cmd (run.pl|/queue.pl ) # how to run jobs." + echo " --reco-nj <#njobs|4> # Split a whole data directory into these many pieces" + echo " --nj <#njobs|4> # Split a segmented data directory into these many pieces" + exit 1 +fi + +data_dir=$1 +lang=$2 +model_dir=$3 +dir=$4 + +if [ $feat_type != "plp" ] && [ $feat_type != "mfcc" ]; then + echo "$0: --feat-type must be plp or mfcc. Must match the model_dir used." + exit 1 +fi + +[ -z "$feat_config" ] && feat_config=$config_dir/$feat_type.conf +[ -z "$pitch_config" ] && pitch_config=$config_dir/pitch.conf + +extra_files= + +if $add_pitch; then + extra_files="$extra_files $pitch_config" +fi + +for f in $feat_config $extra_files; do + if [ ! 
-f $f ]; then + echo "$f could not be found" + exit 1 + fi +done + +mkdir -p $dir + +function make_mfcc { + local nj=$nj + local mfcc_config=$feat_config + local add_pitch=$add_pitch + local cmd=$cmd + local pitch_config=$pitch_config + + while [ $# -gt 0 ]; do + if [ $1 == "--nj" ]; then + nj=$2 + shift; shift; + elif [ $1 == "--mfcc-config" ]; then + mfcc_config=$2 + shift; shift; + elif [ $1 == "--add-pitch" ]; then + add_pitch=$2 + shift; shift; + elif [ $1 == "--cmd" ]; then + cmd=$2 + shift; shift; + elif [ $1 == "--pitch-config" ]; then + pitch_config=$2 + shift; shift; + else + break + fi + done + + if [ $# -ne 3 ]; then + echo "Usage: make_mfcc " + exit 1 + fi + + if $add_pitch; then + steps/make_mfcc_pitch.sh --cmd "$cmd" --nj $nj \ + --mfcc-config $mfcc_config --pitch-config $pitch_config $1 $2 $3 || exit 1 + else + steps/make_mfcc.sh --cmd "$cmd" --nj $nj \ + --mfcc-config $mfcc_config $1 $2 $3 || exit 1 + fi + +} + +function make_plp { + local nj=$nj + local mfcc_config=$feat_config + local add_pitch=$add_pitch + local cmd=$cmd + local pitch_config=$pitch_config + + while [ $# -gt 0 ]; do + if [ $1 == "--nj" ]; then + nj=$2 + shift; shift; + elif [ $1 == "--plp-config" ]; then + plp_config=$2 + shift; shift; + elif [ $1 == "--add-pitch" ]; then + add_pitch=$2 + shift; shift; + elif [ $1 == "--cmd" ]; then + cmd=$2 + shift; shift; + elif [ $1 == "--pitch-config" ]; then + pitch_config=$2 + shift; shift; + else + break + fi + done + + if [ $# -ne 3 ]; then + echo "Usage: make_plp " + exit 1 + fi + + if $add_pitch; then + steps/make_plp_pitch.sh --cmd "$cmd" --nj $nj \ + --plp-config $plp_config --pitch-config $pitch_config $1 $2 $3 || exit 1 + else + steps/make_plp.sh --cmd "$cmd" --nj $nj \ + --plp-config $plp_config $1 $2 $3 || exit 1 + fi +} + +frame_shift_info=`cat $feat_config | steps/segmentation/get_frame_shift_info_from_config.pl` || exit 1 + +frame_shift=`echo $frame_shift_info | awk '{print $1}'` +frame_overlap=`echo $frame_shift_info | awk 
'{print $2}'` + +data_id=$(basename $data_dir) +whole_data_dir=${data_dir}_whole +whole_data_id=${data_id}_whole + +if [ $stage -le -2 ]; then + steps/segmentation/get_sad_map.py \ + --init-sad-map="$sad_map" \ + --map-noise-to-sil=$map_noise_to_sil \ + --map-unk-to-speech=$map_unk_to_speech \ + $lang | utils/sym2int.pl -f 1 $lang/phones.txt > $dir/sad_map + + utils/data/convert_data_dir_to_whole.sh ${data_dir} ${whole_data_dir} + utils/data/get_utt2dur.sh ${whole_data_dir} +fi + +if $speed_perturb; then + plpdir=${plpdir}_sp + mfccdir=${mfccdir}_sp + + if [ $stage -le -1 ]; then + utils/data/perturb_data_dir_speed_3way.sh ${whole_data_dir} ${whole_data_dir}_sp + utils/data/perturb_data_dir_speed_3way.sh ${data_dir} ${data_dir}_sp + + if [ $feat_type == "mfcc" ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + make_mfcc --cmd "$cmd --max-jobs-run 40" --nj $nj \ + --mfcc-config $feat_config \ + --add-pitch $add_pitch --pitch-config $pitch_config \ + ${whole_data_dir}_sp exp/make_mfcc $mfccdir || exit 1 + steps/compute_cmvn_stats.sh \ + ${whole_data_dir}_sp exp/make_mfcc $mfccdir || exit 1 + elif [ $feat_type == "plp" ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $plpdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$plpdir/storage $plpdir/storage + fi + + make_plp --cmd "$cmd --max-jobs-run 40" --nj $nj \ + --plp-config $feat_config \ + --add-pitch $add_pitch --pitch-config $pitch_config \ + ${whole_data_dir}_sp exp/make_plp $plpdir || exit 1 + steps/compute_cmvn_stats.sh \ + ${whole_data_dir}_sp exp/make_plp $plpdir || exit 1 + else + echo "$0: Unknown feat-type $feat_type. Must be mfcc or plp." 
+ exit 1 + fi + + utils/fix_data_dir.sh ${whole_data_dir}_sp + fi + + data_dir=${data_dir}_sp + whole_data_dir=${whole_data_dir}_sp + data_id=${data_id}_sp +fi + + +############################################################################### +# Compute length of recording +############################################################################### + +if [ $stage -le 0 ]; then + utils/subsegment_data_dir.sh $whole_data_dir ${data_dir}/segments ${data_dir}/tmp + cp $data_dir/tmp/feats.scp $data_dir + + if [ $feat_type == mfcc ]; then + steps/compute_cmvn_stats.sh ${data_dir} exp/make_mfcc/${data_id} $mfccdir + else + steps/compute_cmvn_stats.sh ${data_dir} exp/make_plp/${data_id} $plpdir + fi + + utils/fix_data_dir.sh $data_dir +fi + +if [ -z "$sat_model_dir" ]; then + ali_dir=${model_dir}_ali_${data_id} + if [ $stage -le 2 ]; then + steps/align_si.sh --nj $nj --cmd "$cmd" \ + ${data_dir} ${lang} ${model_dir} ${model_dir}_ali_${data_id} || exit 1 + fi +else + ali_dir=${sat_model_dir}_ali_${data_id} + #obtain the alignment of the perturbed data + if [ $stage -le 2 ]; then + steps/align_fmllr.sh --nj $nj --cmd "$cmd" \ + ${data_dir} ${lang} ${sat_model_dir} ${sat_model_dir}_ali_${data_id} || exit 1 + fi +fi + + +# All the data from this point is speed perturbed. + +data_id=$(basename $data_dir) +utils/split_data.sh $data_dir $nj + +############################################################################### +# Convert alignment for the provided segments into +# initial SAD labels at utterance-level in segmentation format +############################################################################### + +vad_dir=$dir/`basename ${ali_dir}`_vad_${data_id} +if [ $stage -le 3 ]; then + steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$cmd" \ + $ali_dir $dir/sad_map $vad_dir +fi + +[ ! 
-s $vad_dir/sad_seg.scp ] && echo "$0: $vad_dir/vad.scp is empty" && exit 1 + +if [ $stage -le 4 ]; then + utils/copy_data_dir.sh $data_dir $dir/${data_id}_manual_segments + + awk '{print $1" "$2}' $dir/${data_id}_manual_segments/segments | sort -k1,1 > $dir/${data_id}_manual_segments/utt2spk + utils/utt2spk_to_spk2utt.pl $dir/${data_id}_manual_segments/utt2spk | sort -k1,1 > $dir/${data_id}_manual_segments/spk2utt + + if [ $feat_type == mfcc ]; then + steps/compute_cmvn_stats.sh $dir/${data_id}_manual_segments exp/make_mfcc/${data_id}_manual_segments $mfccdir + else + steps/compute_cmvn_stats.sh $dir/${data_id}_manual_segments exp/make_plp/${data_id}_manual_segments $plpdir + fi + + utils/fix_data_dir.sh $dir/${data_id}_manual_segments || true # Might fail because utt2spk will be not sorted on both utts and spks +fi + + +#utils/split_data.sh --per-reco $data_dir $reco_nj +#segmentation-combine-segments ark,s:$vad_dir/sad_seg.scp +# "ark,s:segmentation-init-from-segments --shift-to-zero=false --frame-shift=$ali_frame_shift --frame-overlap=$ali_frame_overlap ${data}/split${reco_nj}reco/JOB/segments ark:- |" \ +# "ark:cat ${data}/split${reco_nj}reco/JOB/segments | cut -d ' ' -f 1,2 | utils/utt2spk_to_spk2utt.pl | sort -k1,1 |" ark:- + +############################################################################### + + +# Create extended data directory that consists of the provided +# segments along with the segments outside it. +# This is basically dividing the whole recording into pieces +# consisting of pieces corresponding to the provided segments +# and outside the provided segments. 
+ +############################################################################### +# Create segments outside of the manual segments +############################################################################### + +outside_data_dir=$dir/${data_id}_outside +if [ $stage -le 5 ]; then + rm -rf $outside_data_dir + mkdir -p $outside_data_dir/split${reco_nj}reco + + for f in wav.scp reco2file_and_channel stm glm; do + [ -f ${data_dir}/$f ] && cp ${data_dir}/$f $outside_data_dir + done + + steps/segmentation/split_data_on_reco.sh $data_dir $whole_data_dir $reco_nj + + for n in `seq $reco_nj`; do + dsn=$whole_data_dir/split${reco_nj}reco/$n + awk '{print $2}' $dsn/segments | \ + utils/filter_scp.pl /dev/stdin $whole_data_dir/utt2num_frames > \ + $dsn/utt2num_frames + mkdir -p $outside_data_dir/split${reco_nj}reco/$n + done + + $cmd JOB=1:$reco_nj $outside_data_dir/log/get_empty_segments.JOB.log \ + segmentation-init-from-segments --frame-shift=$frame_shift \ + --frame-overlap=$frame_overlap --shift-to-zero=false \ + ${data_dir}/split${reco_nj}reco/JOB/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- \ + "ark,t:cut -d ' ' -f 1,2 ${data_dir}/split${reco_nj}reco/JOB/segments | utils/utt2spk_to_spk2utt.pl |" ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=0 \ + "ark:segmentation-init-from-lengths --label=1 ark,t:${whole_data_dir}/split${reco_nj}reco/JOB/utt2num_frames ark:- |" \ + ark:- ark:- \| \ + segmentation-post-process --remove-labels=0 --max-segment-length=1000 \ + --post-process-label=1 --overlap-length=50 \ + ark:- ark:- \| segmentation-to-segments --single-speaker=true \ + --frame-shift=$frame_shift --frame-overlap=$frame_overlap \ + ark:- ark,t:$outside_data_dir/split${reco_nj}reco/JOB/utt2spk \ + $outside_data_dir/split${reco_nj}reco/JOB/segments || exit 1 + + for n in `seq $reco_nj`; do + cat $outside_data_dir/split${reco_nj}reco/$n/utt2spk + done | sort -k1,1 > $outside_data_dir/utt2spk + + for n in `seq 
$reco_nj`; do + cat $outside_data_dir/split${reco_nj}reco/$n/segments + done | sort -k1,1 > $outside_data_dir/segments + + utils/fix_data_dir.sh $outside_data_dir + +fi + + +if [ $stage -le 6 ]; then + utils/data/subsegment_data_dir.sh $whole_data_dir $outside_data_dir/segments \ + $outside_data_dir/tmp + cp $outside_data_dir/tmp/feats.scp $outside_data_dir +fi + +extended_data_dir=$dir/${data_id}_extended +if [ $stage -le 7 ]; then + cp $dir/${data_id}_manual_segments/cmvn.scp ${outside_data_dir} || exit 1 + utils/fix_data_dir.sh $outside_data_dir + + utils/combine_data.sh $extended_data_dir $data_dir $outside_data_dir + + steps/segmentation/split_data_on_reco.sh $data_dir $extended_data_dir $reco_nj +fi + +############################################################################### +# Create graph for decoding +############################################################################### + +# TODO: By default, we use word LM. If required, we can think +# consider phone LM. +graph_dir=$model_dir/graph +if [ $stage -le 8 ]; then + if [ ! 
-d $graph_dir ]; then + utils/mkgraph.sh ${lang_test} $model_dir $graph_dir || exit 1 + fi +fi + +############################################################################### +# Decode extended data directory +############################################################################### + + +# Decode without lattice (get only best path) +if [ $stage -le 8 ]; then + steps/decode_nolats.sh --cmd "$cmd --mem 2G" --nj $nj \ + --max-active 1000 --beam 10.0 --write-words false \ + --write-alignments true \ + $graph_dir ${extended_data_dir} \ + ${model_dir}/decode_${data_id}_extended || exit 1 + cp ${model_dir}/final.mdl ${model_dir}/decode_${data_id}_extended +fi + +model_id=`basename $model_dir` + +# Get VAD based on the decoded best path +decode_vad_dir=$dir/${model_id}_decode_vad_${data_id} +if [ $stage -le 9 ]; then + steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$cmd" \ + ${model_dir}/decode_${data_id}_extended $dir/sad_map $decode_vad_dir +fi + +[ ! -s $decode_vad_dir/sad_seg.scp ] && echo "$0: $decode_vad_dir/vad.scp is empty" && exit 1 + +vad_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $vad_dir ${PWD}` + +if [ $stage -le 10 ]; then + segmentation-init-from-segments --frame-shift=$frame_shift \ + --frame-overlap=$frame_overlap --label=0 \ + $outside_data_dir/segments \ + ark,scp:$vad_dir/outside_sad_seg.ark,$vad_dir/outside_sad_seg.scp +fi + +reco_vad_dir=$dir/${model_id}_reco_vad_${data_id} +mkdir -p $reco_vad_dir +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $reco_vad_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$reco_vad_dir/storage $reco_vad_dir/storage +fi + +reco_vad_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $reco_vad_dir ${PWD}` + +echo $reco_nj > $reco_vad_dir/num_jobs + +if [ $stage -le 11 ]; then + $cmd JOB=1:$reco_nj $reco_vad_dir/log/intersect_vad.JOB.log \ + segmentation-intersect-segments --mismatch-label=10 \ + "scp:cat $vad_dir/sad_seg.scp $vad_dir/outside_sad_seg.scp | sort -k1,1 | utils/filter_scp.pl $extended_data_dir/split${reco_nj}reco/JOB/utt2spk |" \ + "scp:utils/filter_scp.pl $extended_data_dir/split${reco_nj}reco/JOB/utt2spk $decode_vad_dir/sad_seg.scp |" \ + ark:- \| segmentation-post-process --remove-labels=10 \ + --merge-adjacent-segments --max-intersegment-length=10 ark:- ark:- \| \ + segmentation-combine-segments ark:- "ark:segmentation-init-from-segments --shift-to-zero=false $extended_data_dir/split${reco_nj}reco/JOB/segments ark:- |" \ + ark,t:$extended_data_dir/split${reco_nj}reco/JOB/reco2utt \ + ark,scp:$reco_vad_dir/sad_seg.JOB.ark,$reco_vad_dir/sad_seg.JOB.scp + for n in `seq $reco_nj`; do + cat $reco_vad_dir/sad_seg.$n.scp + done > $reco_vad_dir/sad_seg.scp +fi + +set +e +for n in `seq $reco_nj`; do + utils/create_data_link.pl $reco_vad_dir/deriv_weights.$n.ark + utils/create_data_link.pl $reco_vad_dir/deriv_weights_for_uncorrupted.$n.ark + utils/create_data_link.pl $reco_vad_dir/speech_labels.$n.ark +done +set -e + +if [ $stage -le 12 ]; then + $cmd JOB=1:$reco_nj $reco_vad_dir/log/get_deriv_weights.JOB.log \ + segmentation-post-process --merge-labels=0:1:2:3 --merge-dst-label=1 \ + scp:$reco_vad_dir/sad_seg.JOB.scp ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${whole_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + 
ark,scp:$reco_vad_dir/deriv_weights.JOB.ark,$reco_vad_dir/deriv_weights.JOB.scp + + for n in `seq $reco_nj`; do + cat $reco_vad_dir/deriv_weights.$n.scp + done > $reco_vad_dir/deriv_weights.scp +fi + +if [ $stage -le 13 ]; then + $cmd JOB=1:$reco_nj $reco_vad_dir/log/get_deriv_weights_for_uncorrupted.JOB.log \ + segmentation-post-process --remove-labels=1:2:3 scp:$reco_vad_dir/sad_seg.JOB.scp \ + ark:- \| segmentation-post-process --merge-labels=0 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${whole_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + ark,scp:$reco_vad_dir/deriv_weights_for_uncorrupted.JOB.ark,$reco_vad_dir/deriv_weights_for_uncorrupted.JOB.scp + for n in `seq $reco_nj`; do + cat $reco_vad_dir/deriv_weights_for_uncorrupted.$n.scp + done > $reco_vad_dir/deriv_weights_for_uncorrupted.scp +fi + +if [ $stage -le 14 ]; then + $cmd JOB=1:$reco_nj $reco_vad_dir/log/get_speech_labels.JOB.log \ + segmentation-copy --keep-label=1 scp:$reco_vad_dir/sad_seg.JOB.scp ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${whole_data_dir}/utt2num_frames \ + ark:- ark,scp:$reco_vad_dir/speech_labels.JOB.ark,$reco_vad_dir/speech_labels.JOB.scp + for n in `seq $reco_nj`; do + cat $reco_vad_dir/speech_labels.$n.scp + done > $reco_vad_dir/speech_labels.scp +fi + +if [ $stage -le 15 ]; then + $cmd JOB=1:$reco_nj $reco_vad_dir/log/convert_manual_segments_to_deriv_weights.JOB.log \ + segmentation-init-from-segments --shift-to-zero=false \ + $data_dir/split${reco_nj}reco/JOB/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- \ + ark:$data_dir/split${reco_nj}reco/JOB/reco2utt ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${whole_data_dir}/utt2num_frames \ + ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + 
ark,scp:$reco_vad_dir/deriv_weights_manual_seg.JOB.ark,$reco_vad_dir/deriv_weights_manual_seg.JOB.scp + + for n in `seq $reco_nj`; do + cat $reco_vad_dir/deriv_weights_manual_seg.$n.scp + done > $reco_vad_dir/deriv_weights_manual_seg.scp +fi + +echo "$0: Finished creating corpus for training Universal SAD with data in $whole_data_dir and labels in $reco_vad_dir" diff --git a/egs/aspire/s5/local/segmentation/prepare_unsad_data_simple.sh b/egs/aspire/s5/local/segmentation/prepare_unsad_data_simple.sh new file mode 100755 index 00000000000..f3d1a7707e8 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_unsad_data_simple.sh @@ -0,0 +1,114 @@ +#!/bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +# This script prepares speech labels for +# training unsad network for speech activity detection and music detection. +# This is similar to the script prepare_unsad_data.sh, but directly +# uses existing alignments to create labels, instead of creating new alignments. + +set -e +set -o pipefail +set -u + +. path.sh + +stage=-2 +cmd=queue.pl + +# Options to be passed to get_sad_map.py +map_noise_to_sil=true # Map noise phones to silence label (0) +map_unk_to_speech=true # Map unk phones to speech label (1) +sad_map= # Initial mapping from phones to speech/non-speech labels. + # Overrides the default mapping using phones/silence.txt + # and phones/nonsilence.txt + +. utils/parse_options.sh + +if [ $# -ne 4 ]; then + echo "This script takes a data directory and alignment directory and " + echo "converts it into speech activity labels" + echo "for the purpose of training a Universal Speech Activity Detector.\n" + echo "Usage: $0 [options] " + echo " e.g.: $0 data/train_100k data/lang exp/tri4a_ali exp/vad_data_prep" + echo "" + echo "Main options (for others, see top of script file)" + echo " --config # config file containing options" + echo " --cmd (run.pl|/queue.pl ) # how to run jobs." 
+ exit 1 +fi + +data_dir=$1 +lang=$2 +ali_dir=$3 +dir=$4 + +extra_files= + +for f in $data_dir/feats.scp $lang/phones.txt $lang/phones/silence.txt $lang/phones/nonsilence.txt $sad_map $ali_dir/ali.1.gz $ali_dir/final.mdl $ali_dir/tree $extra_files; do + if [ ! -f $f ]; then + echo "$f could not be found" + exit 1 + fi +done + +mkdir -p $dir + +data_id=$(basename $data_dir) + +if [ $stage -le 0 ]; then + # Get a mapping from the phones to the speech / non-speech labels + steps/segmentation/get_sad_map.py \ + --init-sad-map="$sad_map" \ + --map-noise-to-sil=$map_noise_to_sil \ + --map-unk-to-speech=$map_unk_to_speech \ + $lang | utils/sym2int.pl -f 1 $lang/phones.txt > $dir/sad_map +fi + +############################################################################### +# Convert alignment into SAD labels at utterance-level in segmentation format +############################################################################### + +vad_dir=$dir/`basename ${ali_dir}`_vad_${data_id} + +# Convert relative path to full path +vad_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir;' $vad_dir ${PWD}` + +if [ $stage -le 1 ]; then + steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$cmd" \ + $ali_dir $dir/sad_map $vad_dir +fi + +[ ! -s $vad_dir/sad_seg.scp ] && echo "$0: $vad_dir/sad_seg.scp is empty" && exit 1 + +############################################################################### +# Post-process the segmentation and create frame-level alignments and +# per-frame deriv weights. +############################################################################### + +if [ $stage -le 2 ]; then + # Create per-frame speech / non-speech labels. 
+ nj=`cat $vad_dir/num_jobs` + + utils/data/get_utt2num_frames.sh --nj $nj --cmd "$cmd" $data_dir + + set +e + for n in `seq $nj`; do + utils/create_data_link.pl $vad_dir/speech_labels.$n.ark + done + set -e + + $cmd JOB=1:$nj $vad_dir/log/get_speech_labels.JOB.log \ + segmentation-copy --keep-label=1 scp:$vad_dir/sad_seg.JOB.scp ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:$data_dir/utt2num_frames \ + ark:- ark,scp:$vad_dir/speech_labels.JOB.ark,$vad_dir/speech_labels.JOB.scp + + for n in `seq $nj`; do + cat $vad_dir/speech_labels.$n.scp + done > $vad_dir/speech_labels.scp + + cp $vad_dir/speech_labels.scp $data_dir +fi + +echo "$0: Finished creating corpus for training Universal SAD with data in $data_dir and labels in $vad_dir" diff --git a/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data.sh b/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data.sh new file mode 100755 index 00000000000..6d21859d7fe --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data.sh @@ -0,0 +1,283 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e +set -u +set -o pipefail + +. path.sh + +num_data_reps=5 +nj=40 +cmd=queue.pl +snr_db_threshold=10 +stage=-1 + +. utils/parse_options.sh + +if [ $# -ne 5 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/fisher_train_100k_sp_75k_seg_ovlp_corrupted_hires_bp data/fisher_train_100k_sp_75k_seg_ovlp_corrupted exp/unsad/make_unsad_fisher_train_100k/tri4a_ali_fisher_train_100k_sp_vad_fisher_train_100k_sp exp/unsad overlap_labels" + exit 1 +fi + +corrupted_data_dir=$1 +orig_corrupted_data_dir=$2 +utt_vad_dir=$3 +tmpdir=$4 +overlap_labels_dir=$5 + +overlapped_segments_info=$orig_corrupted_data_dir/overlapped_segments_info.txt +corrupted_data_id=`basename $orig_corrupted_data_dir` + +for f in $corrupted_data_dir/feats.scp $overlapped_segments_info $utt_vad_dir/sad_seg.scp; do + [ ! 
-f $f ] && echo "Could not find file $f" && exit 1 +done + +overlap_dir=$tmpdir/make_overlap_labels_${corrupted_data_id} +unreliable_dir=$tmpdir/unreliable_${corrupted_data_id} + +mkdir -p $unreliable_dir + +# make $overlap_labels_dir an absolute pathname. +overlap_labels_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $overlap_labels_dir ${PWD}` + +# Combine the VAD from the base recording and the VAD from the overlapping segments +# to create per-frame labels of the number of overlapping speech segments +# Unreliable segments are regions where no VAD labels were available for the +# overlapping segments. These can be later removed by setting deriv weights to 0. + +if [ $stage -le 1 ]; then + for n in `seq $num_data_reps`; do + cat $utt_vad_dir/sad_seg.scp | \ + awk -v n=$n '{print "ovlp"n"_"$0}' + done | sort -k1,1 > ${corrupted_data_dir}/sad_seg.scp + utils/data/get_utt2num_frames.sh $corrupted_data_dir + utils/split_data.sh ${corrupted_data_dir} $nj + + # 1) segmentation-init-from-additive-signals-info converts the informtation + # written out but by steps/data/make_corrupted_data_dir.py in overlapped_segments_info.txt + # and converts it to segments. It then adds those segments to the + # segments already present ($corrupted_data_dir/sad_seg.scp) + # 2) Retain only the speech segments (label 1) from these. + # 3) Convert this to overlap stats using segmentation-get-stats, which + # writes for each frame the number of overlapping segments. + # 4) Convert this per-frame "alignment" information to segmentation + # ($overlap_dir/overlap_seg.*.gz). 
+ $cmd JOB=1:$nj $overlap_dir/log/get_overlap_seg.JOB.log \ + segmentation-init-from-additive-signals-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + --additive-signals-segmentation-rspecifier=scp:$utt_vad_dir/sad_seg.scp \ + --unreliable-segmentation-wspecifier="ark:| gzip -c > $unreliable_dir/unreliable_seg.JOB.gz" \ + "scp:utils/filter_scp.pl ${corrupted_data_dir}/split${nj}/JOB/utt2spk $corrupted_data_dir/sad_seg.scp |" \ + ark,t:$orig_corrupted_data_dir/overlapped_segments_info.txt ark:- \| \ + segmentation-copy --keep-label=1 ark:- ark:- \| \ + segmentation-get-stats --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + ark:- ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- "ark:| gzip -c > $overlap_dir/overlap_seg.JOB.gz" +fi + +if [ $stage -le 2 ]; then + # Retain labels >2, i.e. regions where more than 1 speaker overlap. + # Write this out in alignment format as "overlapped_speech_labels" + $cmd JOB=1:$nj $overlap_dir/log/get_overlapped_speech_labels.JOB.log \ + gunzip -c $overlap_dir/overlap_seg.JOB.gz \| \ + segmentation-post-process --remove-labels=0:1 ark:- ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ + ark,scp:$overlap_labels_dir/overlapped_speech_labels_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/overlapped_speech_labels_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/overlapped_speech_labels_${corrupted_data_id}.$n.scp + done > ${corrupted_data_dir}/overlapped_speech_labels.scp +fi + +if [ $stage -le 3 ]; then + # 1) Initialize a segmentation where all the frames have label 1 using + # segmentation-init-from-length. + # 2) Use the program segmentation-create-subsegments to set to 0 + # the regions of unreliable segments read from unreliable_seg.*.gz. + # This is the initial deriv weights. 
At this stage deriv weights is 1 for all + # but the unreliable segment regions. + # 3) Initialize a segmentation from the overlap labels (overlap_seg.*.gz) + # and retain regions where there is speech from at least one speaker. + # 4) Intersect this with the deriv weights segmentation from above. + # At this stage deriv weights is 1 for only the regions where there is + # at least one speaker and the the overlapping segment is not unreliable. + # Convert this to deriv weights. + $cmd JOB=1:$nj $unreliable_dir/log/get_deriv_weights.JOB.log \ + utils/filter_scp.pl $corrupted_data_dir/split$nj/JOB/utt2spk $corrupted_data_dir/utt2num_frames \| \ + segmentation-init-from-lengths ark,t:- ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=0 --ignore-missing \ + ark:- "ark,s,cs:gunzip -c $unreliable_dir/unreliable_seg.JOB.gz | segmentation-to-segments ark:- - | segmentation-init-from-segments - ark:- |" ark:- \| \ + segmentation-intersect-segments --mismatch-label=0 \ + "ark:gunzip -c $overlap_dir/overlap_seg.JOB.gz | segmentation-post-process --remove-labels=0 --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- |" \ + ark,s,cs:- ark:- \| segmentation-post-process --remove-labels=0 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + ark,scp:$overlap_labels_dir/deriv_weights_for_overlapped_speech_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/deriv_weights_for_overlapped_speech_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/deriv_weights_for_overlapped_speech_$corrupted_data_id.${n}.scp + done > $corrupted_data_dir/deriv_weights_for_overlapped_speech.scp +fi + +if [ $stage -le 4 ]; then + # Find regions where there is at least one speaker speaking. 
+  $cmd JOB=1:$nj $overlap_dir/log/get_speech_labels.JOB.log \
+    gunzip -c $overlap_dir/overlap_seg.JOB.gz \| \
+    segmentation-post-process --remove-labels=0 --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \
+    segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \
+    steps/segmentation/convert_ali_to_vec.pl \| \
+    vector-to-feat ark:- \
+    ark,scp:$overlap_labels_dir/speech_feat_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/speech_feat_${corrupted_data_id}.JOB.scp
+
+  for n in `seq $nj`; do
+    cat $overlap_labels_dir/speech_feat_${corrupted_data_id}.$n.scp
+  done > ${corrupted_data_dir}/speech_feat.scp
+fi
+
+if [ $stage -le 5 ]; then
+  # Deriv weights speech / non-speech labels is 1 everywhere but the
+  # unreliable regions.
+  $cmd JOB=1:$nj $unreliable_dir/log/get_deriv_weights.JOB.log \
+    utils/filter_scp.pl $corrupted_data_dir/split$nj/JOB/utt2spk $corrupted_data_dir/utt2num_frames \| \
+    segmentation-init-from-lengths ark,t:- ark:- \| \
+    segmentation-create-subsegments --filter-label=1 --subsegment-label=0 --ignore-missing \
+    ark:- "ark,s,cs:gunzip -c $unreliable_dir/unreliable_seg.JOB.gz | segmentation-to-segments ark:- - | segmentation-init-from-segments - ark:- |" ark:- \| \
+    segmentation-post-process --remove-labels=0 ark:- ark:- \| \
+    segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \
+    steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \
+    ark,scp:$overlap_labels_dir/deriv_weights_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/deriv_weights_${corrupted_data_id}.JOB.scp
+
+  for n in `seq $nj`; do
+    cat $overlap_labels_dir/deriv_weights_$corrupted_data_id.${n}.scp
+  done > $corrupted_data_dir/deriv_weights.scp
+fi
+
+snr_threshold=`perl -e "print $snr_db_threshold / 10.0 * log(10.0)"`
+
+cat <<EOF > $overlap_dir/invert_labels.map
+0 1
+1 0
+EOF
+
+if [ $stage -le 6 ]; then
+  if [ ! 
-f $corrupted_data_dir/log_snr.scp ]; then + echo "$0: Could not find $corrupted_data_dir/log_snr.scp. Run local/segmentation/do_corruption_data_dir_overlapped_speech.sh." + exit 1 + fi + + $cmd JOB=1:$nj $overlap_dir/log/fix_overlapped_speech_labels.JOB.log \ + copy-matrix --apply-power=1 \ + "scp:utils/filter_scp.pl $corrupted_data_dir/split$nj/JOB/utt2spk $corrupted_data_dir/log_snr.scp |" \ + ark:- \| extract-column ark:- ark,t:- \| \ + steps/segmentation/quantize_vector.pl $snr_threshold \| \ + segmentation-init-from-ali ark,t:- ark:- \| \ + segmentation-copy --label-map=$overlap_dir/invert_labels.map ark:- ark:- \| \ + segmentation-intersect-segments --mismatch-label=1000 \ + "ark:utils/filter_scp.pl $corrupted_data_dir/split$nj/JOB/utt2spk $corrupted_data_dir/overlapped_speech_labels.scp | segmentation-init-from-ali scp:- ark:- | segmentation-copy --keep-label=1 ark:- ark:- |" ark:- ark:- \| \ + segmentation-copy --keep-label=1 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + ark:- ark,scp:$overlap_labels_dir/overlapped_speech_labels_fixed_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/overlapped_speech_labels_fixed_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/overlapped_speech_labels_fixed_${corrupted_data_id}.$n.scp + done > $corrupted_data_dir/overlapped_speech_labels_fixed.scp +fi + +exit 0 + +####exit 1 +#### +####if [ $stage -le 9 ]; then +#### mkdir -p $overlap_data_dir $unreliable_data_dir +#### cp $orig_corrupted_data_dir/wav.scp $overlap_data_dir +#### cp $orig_corrupted_data_dir/wav.scp $unreliable_data_dir +#### +#### # Create segments where there is definitely an overlap. +#### # Assume no more than 10 speakers overlap. 
+#### $cmd JOB=1:$nj $overlap_dir/log/process_to_segments.JOB.log \ +#### segmentation-post-process --remove-labels=0:1 \ +#### ark:$overlap_dir/overlap_seg_speed_unperturbed.JOB.ark ark:- \| \ +#### segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ +#### segmentation-to-segments ark:- ark:$overlap_data_dir/utt2spk.JOB $overlap_data_dir/segments.JOB +#### +#### $cmd JOB=1:$nj $overlap_dir/log/get_unreliable_segments.JOB.log \ +#### segmentation-to-segments --single-speaker \ +#### ark:$unreliable_dir/unreliable_seg_speed_unperturbed.JOB.ark \ +#### ark:$unreliable_data_dir/utt2spk.JOB $unreliable_data_dir/segments.JOB +#### +#### for n in `seq $nj`; do cat $overlap_data_dir/utt2spk.$n; done > $overlap_data_dir/utt2spk +#### for n in `seq $nj`; do cat $overlap_data_dir/segments.$n; done > $overlap_data_dir/segments +#### for n in `seq $nj`; do cat $unreliable_data_dir/utt2spk.$n; done > $unreliable_data_dir/utt2spk +#### for n in `seq $nj`; do cat $unreliable_data_dir/segments.$n; done > $unreliable_data_dir/segments +#### +#### utils/fix_data_dir.sh $overlap_data_dir +#### utils/fix_data_dir.sh $unreliable_data_dir +#### +#### if $speed_perturb; then +#### utils/data/perturb_data_dir_speed_3way.sh $overlap_data_dir ${overlap_data_dir}_sp +#### utils/data/perturb_data_dir_speed_3way.sh $unreliable_data_dir ${unreliable_data_dir}_sp +#### fi +####fi +#### +####if $speed_perturb; then +#### overlap_data_dir=${overlap_data_dir}_sp +#### unreliable_data_dir=${unreliable_data_dir}_sp +####fi +#### +##### make $overlap_labels_dir an absolute pathname. 
+####overlap_labels_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $overlap_labels_dir ${PWD}` +#### +####if [ $stage -le 10 ]; then +#### utils/split_data.sh ${overlap_data_dir} $nj +#### +#### $cmd JOB=1:$nj $overlap_dir/log/get_overlap_speech_labels.JOB.log \ +#### utils/data/get_reco2utt.sh ${overlap_data_dir}/split${reco_nj}reco/JOB '&&' \ +#### segmentation-init-from-segments --shift-to-zero=false \ +#### ${overlap_data_dir}/split${reco_nj}reco/JOB/segments ark:- \| \ +#### segmentation-combine-segments-to-recordings ark:- ark,t:${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt \ +#### ark:- \| \ +#### segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ +#### ark,scp:$overlap_labels_dir/overlapped_speech_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/overlapped_speech_${corrupted_data_id}.JOB.scp +####fi +#### +####for n in `seq $reco_nj`; do +#### cat $overlap_labels_dir/overlapped_speech_${corrupted_data_id}.$n.scp +####done > ${corrupted_data_dir}/overlapped_speech_labels.scp +#### +####if [ $stage -le 11 ]; then +#### utils/data/get_reco2utt.sh ${unreliable_data_dir} +#### +#### # First convert the unreliable segments into a recording-level segmentation. +#### # Initialize a segmentation from utt2num_frames and set to 0, the regions +#### # of unreliable segments. At this stage deriv weights is 1 for all but the +#### # unreliable segment regions. +#### # Initialize a segmentation from the VAD labels and retain only the speech segments. +#### # Intersect this with the deriv weights segmentation from above. At this stage +#### # deriv weights is 1 for only the regions where base VAD label is 1 and +#### # the overlapping segment is not unreliable. Convert this to deriv weights. 
+#### $cmd JOB=1:$reco_nj $unreliable_dir/log/get_deriv_weights.JOB.log\ +#### segmentation-init-from-segments --shift-to-zero=false \ +#### "utils/filter_scp.pl -f 2 ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt ${unreliable_data_dir}/segments |" ark:- \| \ +#### segmentation-combine-segments-to-recordings ark:- "ark,t:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt ${unreliable_data_dir}/reco2utt |" \ +#### ark:- \| \ +#### segmentation-create-subsegments --filter-label=1 --subsegment-label=0 --ignore-missing \ +#### "ark:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt $corrupted_data_dir/utt2num_frames | segmentation-init-from-lengths ark,t:- ark:- |" \ +#### ark:- ark:- \| \ +#### segmentation-intersect-segments --mismatch-label=0 \ +#### "ark:utils/filter_scp.pl ${overlap_data_dir}/split${reco_nj}reco/JOB/reco2utt $corrupted_data_dir/sad_seg.scp | segmentation-post-process --remove-labels=0:2:3 scp:- ark:- |" \ +#### ark:- ark:- \| \ +#### segmentation-post-process --remove-labels=0 ark:- ark:- \| \ +#### segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ +#### steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ +#### ark,scp:$overlap_labels_dir/deriv_weights_for_overlapped_speech.JOB.ark,$overlap_labels_dir/deriv_weights_for_overlapped_speech.JOB.scp +#### +#### for n in `seq $reco_nj`; do +#### cat $overlap_labels_dir/deriv_weights_for_overlapped_speech.${n}.scp +#### done > $corrupted_data_dir/deriv_weights_for_overlapped_speech.scp +####fi +#### +####exit 0 diff --git a/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data_simple.sh b/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data_simple.sh new file mode 100755 index 00000000000..80810afd619 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_unsad_overlapped_speech_data_simple.sh @@ -0,0 +1,157 @@ +#! 
/bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e +set -u +set -o pipefail + +. path.sh + +num_data_reps=5 +nj=40 +cmd=queue.pl +snr_db_threshold=10 +stage=-1 + +. utils/parse_options.sh + +if [ $# -ne 5 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/fisher_train_100k_sp_75k_seg_ovlp_corrupted_hires_bp data/fisher_train_100k_sp_75k_seg_ovlp_corrupted exp/unsad/make_unsad_fisher_train_100k/tri4a_ali_fisher_train_100k_sp_vad_fisher_train_100k_sp exp/unsad overlapping_sad_labels" + exit 1 +fi + +corrupted_data_dir=$1 +orig_corrupted_data_dir=$2 +utt_vad_dir=$3 +tmpdir=$4 +overlap_labels_dir=$5 + +overlapped_segments_info=$orig_corrupted_data_dir/overlapped_segments_info.txt +corrupted_data_id=`basename $orig_corrupted_data_dir` + +for f in $corrupted_data_dir/feats.scp $overlapped_segments_info $utt_vad_dir/sad_seg.scp; do + [ ! -f $f ] && echo "Could not find file $f" && exit 1 +done + +overlap_dir=$tmpdir/make_overlapping_sad_labels_${corrupted_data_id} + +# make $overlap_labels_dir an absolute pathname. +overlap_labels_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $overlap_labels_dir ${PWD}` +mkdir -p $overlap_labels_dir + +# Combine the VAD from the base recording and the VAD from the overlapping segments +# to create per-frame labels of the number of overlapping speech segments +# Unreliable segments are regions where no VAD labels were available for the +# overlapping segments. These can be later removed by setting deriv weights to 0. 
+ +if [ $stage -le 1 ]; then + for n in `seq $num_data_reps`; do + cat $utt_vad_dir/sad_seg.scp | \ + awk -v n=$n '{print "ovlp"n"_"$0}' + done | sort -k1,1 > ${corrupted_data_dir}/sad_seg.scp + utils/data/get_utt2num_frames.sh $corrupted_data_dir + utils/split_data.sh ${corrupted_data_dir} $nj + + # 1) segmentation-init-from-additive-signals-info converts the informtation + # written out but by steps/data/make_corrupted_data_dir.py in overlapped_segments_info.txt + # and converts it to segments. It then adds those segments to the + # segments already present ($corrupted_data_dir/sad_seg.scp) + # 2) Retain only the speech segments (label 1) from these. + # 3) Convert this to overlap stats using segmentation-get-stats, which + # writes for each frame the number of overlapping segments. + # 4) Convert this per-frame "alignment" information to segmentation + # ($overlap_dir/overlap_seg.*.gz). + $cmd JOB=1:$nj $overlap_dir/log/get_overlapping_sad_seg.JOB.log \ + segmentation-init-from-additive-signals-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + --junk-label=10000 \ + --additive-signals-segmentation-rspecifier=scp:$utt_vad_dir/sad_seg.scp \ + "ark,t:utils/filter_scp.pl ${orig_corrupted_data_dir}/split${reco_nj}reco/JOB/reco2utt $orig_corrupted_data_dir/overlapped_segments_info.txt |" \ + ark:- \| \ + segmentation-merge "scp:utils/filter_scp.pl ${corrupted_data_dir}/split${nj}/JOB/utt2spk $corrupted_data_dir/sad_seg.scp |" ark:- ark:- \| \ + segmentation-get-stats --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + ark:- ark:/dev/null ark:/dev/null ark:- \| \ + classes-per-frame-to-labels --junk-label=10000 ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + "ark:| gzip -c > $overlap_dir/overlap_sad_seg.JOB.gz" +fi + +if [ $stage -le 2 ]; then + # Call labels >2, i.e. regions where more than 1 speaker overlap as overlapping speech. labels = 1 is single speaker and labels = 0 is silence. 
+ # Write this out in alignment format as "overlapping_sad_labels" + $cmd JOB=1:$nj $overlap_dir/log/get_overlapping_sad_labels.JOB.log \ + gunzip -c $overlap_dir/overlap_sad_seg.JOB.gz \| \ + segmentation-post-process --remove-labels=10000 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ + ark,scp:$overlap_labels_dir/overlapping_sad_labels_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/overlapping_sad_labels_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/overlapping_sad_labels_${corrupted_data_id}.$n.scp + done > ${corrupted_data_dir}/overlapping_sad_labels.scp +fi + +if [ $stage -le 3 ]; then + # Find regions where there is at least one speaker speaking. + $cmd JOB=1:$nj $overlap_dir/log/get_speech_feat.JOB.log \ + gunzip -c $overlap_dir/overlap_sad_seg.JOB.gz \| \ + segmentation-post-process --remove-labels=10000 ark:- ark:- \| \ + segmentation-post-process --merge-labels=1:2 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| \ + vector-to-feat ark:- \ + ark,scp:$overlap_labels_dir/speech_feat_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/speech_feat_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/speech_feat_${corrupted_data_id}.$n.scp + done > ${corrupted_data_dir}/speech_feat.scp +fi + +if [ $stage -le 4 ]; then + # Deriv weights is 1 everywhere but the + # unreliable regions. 
+  $cmd JOB=1:$nj $overlap_dir/log/get_deriv_weights.JOB.log \
+    gunzip -c $overlap_dir/overlap_sad_seg.JOB.gz \| \
+    segmentation-post-process --merge-labels=0:1:2 --merge-dst-label=1 ark:- ark:- \| \
+    segmentation-post-process --merge-labels=10000 --merge-dst-label=0 ark:- ark:- \| \
+    segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- ark,t:- \| \
+    steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \
+    ark,scp:$overlap_labels_dir/deriv_weights_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/deriv_weights_${corrupted_data_id}.JOB.scp
+
+  for n in `seq $nj`; do
+    cat $overlap_labels_dir/deriv_weights_$corrupted_data_id.${n}.scp
+  done > $corrupted_data_dir/deriv_weights.scp
+fi
+
+snr_threshold=`perl -e "print $snr_db_threshold / 10.0 * log(10.0)"`
+
+cat <<EOF > $overlap_dir/invert_labels.map
+0 2
+1 1
+EOF
+
+if [ $stage -le 5 ]; then
+  if [ ! -f $corrupted_data_dir/log_snr.scp ]; then
+    echo "$0: Could not find $corrupted_data_dir/log_snr.scp. Run local/segmentation/do_corruption_data_dir_overlapped_speech.sh." 
+ exit 1 + fi + + $cmd JOB=1:$nj $overlap_dir/log/fix_overlapping_sad_labels.JOB.log \ + copy-matrix --apply-power=1 \ + "scp:utils/filter_scp.pl $corrupted_data_dir/split$nj/JOB/utt2spk $corrupted_data_dir/log_snr.scp |" \ + ark:- \| extract-column ark:- ark,t:- \| \ + steps/segmentation/quantize_vector.pl $snr_threshold \| \ + segmentation-init-from-ali ark,t:- ark:- \| \ + segmentation-copy --label-map=$overlap_dir/invert_labels.map ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ + "ark:utils/filter_scp.pl $corrupted_data_dir/split$nj/JOB/utt2spk $corrupted_data_dir/overlapping_sad_labels.scp | segmentation-init-from-ali scp:- ark:- |" ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + ark:- ark,scp:$overlap_labels_dir/overlapping_sad_labels_fixed_${corrupted_data_id}.JOB.ark,$overlap_labels_dir/overlapping_sad_labels_fixed_${corrupted_data_id}.JOB.scp + + for n in `seq $nj`; do + cat $overlap_labels_dir/overlapping_sad_labels_fixed_${corrupted_data_id}.$n.scp + done > $corrupted_data_dir/overlapping_sad_labels_fixed.scp +fi + +exit 0 diff --git a/egs/aspire/s5/local/segmentation/run_fisher.sh b/egs/aspire/s5/local/segmentation/run_fisher.sh new file mode 100644 index 00000000000..e39ef5f3a91 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/run_fisher.sh @@ -0,0 +1,23 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. 
+
+local/segmentation/prepare_fisher_data.sh
+
+utils/combine_data.sh --extra-files "speech_feat.scp deriv_weights.scp deriv_weights_manual_seg.scp music_labels.scp" \
+  data/fisher_train_100k_whole_all_corrupted_sp_hires_bp \
+  data/fisher_train_100k_whole_corrupted_sp_hires_bp \
+  data/fisher_train_100k_whole_music_corrupted_sp_hires_bp
+
+local/segmentation/train_stats_sad_music.sh \
+  --train-data-dir data/fisher_train_100k_whole_all_corrupted_sp_hires_bp \
+  --speech-feat-scp data/fisher_train_100k_whole_corrupted_sp_hires_bp/speech_feat.scp \
+  --deriv-weights-scp data/fisher_train_100k_whole_corrupted_sp_hires_bp/deriv_weights.scp \
+  --music-labels-scp data/fisher_train_100k_whole_music_corrupted_sp_hires_bp/music_labels.scp \
+  --max-param-change 0.2 \
+  --num-epochs 2 --affix k \
+  --splice-indexes "-3,-2,-1,0,1,2,3 -6,0,mean+count(-99:3:9:99) -9,0,3 0"
+
+local/segmentation/run_segmentation_ami.sh \
+  --nnet-dir exp/nnet3_sad_snr/nnet_tdnn_k_n4
diff --git a/egs/aspire/s5/local/segmentation/run_fisher_babel.sh b/egs/aspire/s5/local/segmentation/run_fisher_babel.sh
new file mode 100644
index 00000000000..bdf6d3585f7
--- /dev/null
+++ b/egs/aspire/s5/local/segmentation/run_fisher_babel.sh
@@ -0,0 +1,2 @@
+
+utils/combine_data.sh
diff --git a/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh
new file mode 100755
index 00000000000..48677598728
--- /dev/null
+++ b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh
@@ -0,0 +1,452 @@
+#! /bin/bash
+
+# Copyright 2016 Vimal Manohar
+# Apache 2.0.
+
+. cmd.sh
+. 
path.sh + +set -e +set -o pipefail +set -u + +stage=-1 +nnet_dir=exp/nnet3_sad_snr/nnet_tdnn_k_n4 +extra_left_context=100 +extra_right_context=20 +task=SAD +iter=final + +segmentation_stage=-1 +sil_prior=0.7 +speech_prior=0.3 +min_silence_duration=30 +min_speech_duration=10 +frame_subsampling_factor=3 +ali_dir=/export/a09/vmanoha1/workspace_asr_diarization/egs/ami/s5b/exp/ihm/nnet3_cleaned/tdnn_sp_ali_dev_ihmdata_oraclespk + +. utils/parse_options.sh + +export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH + +src_dir=/export/a09/vmanoha1/workspace_asr_diarization/egs/ami/s5b # AMI src_dir +dir=exp/sad_ami_sdm1_dev/ref + +mkdir -p $dir + +# Expecting user to have done run.sh to run the AMI recipe in $src_dir for +# both sdm and ihm microphone conditions + +if [ $stage -le 1 ]; then + ( + cd $src_dir + local/prepare_parallel_train_data.sh --train-set dev sdm1 + + awk '{print $1" "$2}' $src_dir/data/ihm/dev/segments > \ + $src_dir/data/ihm/dev/utt2reco + awk '{print $1" "$2}' $src_dir/data/sdm1/dev/segments > \ + $src_dir/data/sdm1/dev/utt2reco + + cat $src_dir/data/sdm1/dev_ihmdata/ihmutt2utt | \ + utils/apply_map.pl -f 1 $src_dir/data/ihm/dev/utt2reco | \ + utils/apply_map.pl -f 2 $src_dir/data/sdm1/dev/utt2reco | \ + sort -u > $src_dir/data/sdm1/dev_ihmdata/ihm2sdm_reco + ) +fi + +if [ $stage -le 2 ]; then + ( + cd $src_dir + + utils/copy_data_dir.sh $src_dir/data/sdm1/dev_ihmdata \ + $src_dir/data/sdm1/dev_ihmdata_oraclespk + + cut -d ' ' -f 1,2 $src_dir/data/ihm/dev/segments | \ + utils/apply_map.pl -f 1 $src_dir/data/sdm1/dev_ihmdata/ihmutt2utt > \ + $src_dir/data/sdm1/dev_ihmdata_oraclespk/utt2spk.temp + + cat $src_dir/data/sdm1/dev_ihmdata_oraclespk/utt2spk.temp | \ + awk '{print $1" "$2"-"$1}' > \ + $src_dir/data/sdm1/dev_ihmdata_oraclespk/utt2newutt + + utils/apply_map.pl -f 1 $src_dir/data/sdm1/dev_ihmdata_oraclespk/utt2newutt \ + < $src_dir/data/sdm1/dev_ihmdata_oraclespk/utt2spk.temp > \ + $src_dir/data/sdm1/dev_ihmdata_oraclespk/utt2spk + + for f in feats.scp 
segments text; do + utils/apply_map.pl -f 1 $src_dir/data/sdm1/dev_ihmdata_oraclespk/utt2newutt \ + < $src_dir/data/sdm1/dev_ihmdata/$f > \ + $src_dir/data/sdm1/dev_ihmdata_oraclespk/$f + done + + rm $src_dir/data/sdm1/dev_ihmdata_oraclespk/{spk2utt,cmvn.scp} + utils/fix_data_dir.sh \ + $src_dir/data/sdm1/dev_ihmdata_oraclespk + + utils/data/get_reco2utt.sh $src_dir/data/sdm1/dev_ihmdata_oraclespk + ) +fi + +phone_map=$dir/phone_map +if [ $stage -le 2 ]; then + steps/segmentation/get_sad_map.py \ + $src_dir/data/lang | utils/sym2int.pl -f 1 $src_dir/data/lang/phones.txt > \ + $phone_map +fi + +if [ -z $ali_dir ]; then + if [ $stage -le 3 ]; then + # Expecting user to have run local/run_cleanup_segmentation.sh in $src_dir + ( + cd $src_dir + steps/align_fmllr.sh --nj 18 --cmd "$train_cmd" \ + data/sdm1/dev_ihmdata_oraclespk data/lang \ + exp/ihm/tri3_cleaned \ + exp/sdm1/tri3_cleaned_dev_ihmdata_oraclespk + ) + fi + ali_dir=exp/sdm1/tri3_cleaned_ali_dev_ihmdata_oraclespk +fi + +if [ $stage -le 4 ]; then + steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$train_cmd" \ + $ali_dir $phone_map $dir +fi + +echo "A 1" > $dir/channel_map +cat $src_dir/data/sdm1/dev/reco2file_and_channel | \ + utils/apply_map.pl -f 3 $dir/channel_map > $dir/reco2file_and_channel + +# Map each IHM recording to a unique integer id. +# This will be the "speaker label" as each recording is assumed to have a +# single speaker. +cat $src_dir/data/sdm1/dev_ihmdata_oraclespk/reco2utt | \ + awk 'BEGIN{i=1} {print $1" "1":"i" 100000:100000"; i++;}' > \ + $src_dir/data/sdm1/dev_ihmdata_oraclespk/reco.txt + +if [ $stage -le 5 ]; then + utils/data/get_reco2num_frames.sh --frame-shift 0.01 --frame-overlap 0.015 \ + --cmd "$train_cmd" --nj 18 \ + $src_dir/data/sdm1/dev + + # Get a filter that changes the first and the last segment region outside + # the manual segmentation (usually some preparation lines) that are not + # transcribed. 
+ $train_cmd $dir/log/interior_regions.log \ + segmentation-init-from-segments --shift-to-zero=false --frame-overlap=0.0 $src_dir/data/sdm1/dev/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/dev/reco2utt ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ + "ark:segmentation-init-from-lengths --label=0 ark,t:$src_dir/data/sdm1/dev/reco2num_frames ark:- |" ark:- ark,t:- \| \ + perl -ane '$F[3] = 100000; $F[$#F-1] = 100000; print join(" ", @F) . "\n";' \| \ + segmentation-post-process --merge-labels=0:1 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-post-process --merge-labels=100000 --merge-dst-label=0 --merge-adjacent-segments \ + --max-intersegment-length=1000000 ark,t:- \ + "ark:| gzip -c > $dir/interior_regions.seg.gz" + + $train_cmd $dir/log/get_manual_segments_regions.log \ + segmentation-init-from-segments --shift-to-zero=false --frame-overlap=0.0 $src_dir/data/sdm1/dev/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/dev/reco2utt ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ + "ark:segmentation-init-from-lengths --label=100000 ark,t:$src_dir/data/sdm1/dev/reco2num_frames ark:- |" ark:- ark:- \| \ + segmentation-post-process --merge-labels=100000 --merge-dst-label=0 --merge-adjacent-segments \ + --max-intersegment-length=1000000 ark,t:- \ + "ark:| gzip -c > $dir/manual_segments_regions.seg.gz" +fi + +if [ $stage -le 6 ]; then + # Reference RTTM where SPEECH frames are obtainted by combining IHM VAD alignments + $train_cmd $dir/log/get_ref_spk_seg.log \ + segmentation-combine-segments --include-missing-utt-level-segmentations scp:$dir/sad_seg.scp \ + "ark:segmentation-init-from-segments --shift-to-zero=false --frame-overlap=0.0 --label=100000 $src_dir/data/sdm1/dev_ihmdata_oraclespk/segments ark:- |" \ + ark,t:$src_dir/data/sdm1/dev_ihmdata_oraclespk/reco2utt ark:- \| \ + 
segmentation-post-process --remove-labels=0 ark:- ark:- \| \ + segmentation-copy --utt2label-map-rspecifier=ark,t:$src_dir/data/sdm1/dev_ihmdata_oraclespk/reco.txt \ + ark:- ark:- \| \ + segmentation-merge-recordings \ + "ark,t:utils/utt2spk_to_spk2utt.pl $src_dir/data/sdm1/dev_ihmdata/ihm2sdm_reco |" \ + ark:- "ark:| gzip -c > $dir/ref_spk_seg.gz" +fi + +if [ $stage -le 7 ]; then + # To get the actual RTTM, we need to add no-score + $train_cmd $dir/log/get_ref_spk_rttm_manual_seg.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-copy --keep-label=0 "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-labels=0 --merge-dst-label=100000 ark:- ark:- \| \ + segmentation-merge "ark:gunzip -c $dir/ref_spk_seg.gz |" ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --map-to-speech-and-sil=false --no-score-label=100000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_spk_manual_seg.rttm + + $train_cmd $dir/log/get_ref_spk_rttm_interior.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-copy --keep-label=0 "ark:gunzip -c $dir/interior_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-labels=0 --merge-dst-label=100000 ark:- ark:- \| \ + segmentation-merge "ark:gunzip -c $dir/ref_spk_seg.gz |" ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --map-to-speech-and-sil=false --no-score-label=100000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_spk_interior.rttm + + $train_cmd $dir/log/get_ref_rttm_manual_seg.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:- \| \ + 
segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=100000 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=100000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_manual_seg.rttm + + $train_cmd $dir/log/get_ref_rttm_interior.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=100000 \ + ark:- "ark:gunzip -c $dir/interior_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=100000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_interior.rttm + + # Get RTTM for overlapped speech detection with 3 classes + # 0 -> SILENCE, 1 -> SINGLE_SPEAKER, 2 -> OVERLAP + $train_cmd $dir/log/get_overlapping_rttm_manual_seg.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + 
 ark:/dev/null ark:- ark:/dev/null \| \
+ segmentation-init-from-ali ark:- ark:- \| \
+ segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \
+ --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \
+ segmentation-create-subsegments --filter-label=0 --subsegment-label=100000 \
+ ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \
+ segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \
+ segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \
+ --no-score-label=100000 ark:- - \| \
+ rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/overlapping_speech_ref_manual_seg.rttm
+
+ $train_cmd $dir/log/get_overlapping_rttm_interior.log \
+ export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \
+ segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \
+ "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \
+ ark:/dev/null ark:- ark:/dev/null \| \
+ segmentation-init-from-ali ark:- ark:- \| \
+ segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \
+ --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \
+ segmentation-create-subsegments --filter-label=0 --subsegment-label=100000 \
+ ark:- "ark:gunzip -c $dir/interior_regions.seg.gz |" ark:- \| \
+ segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \
+ segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \
+ --no-score-label=100000 ark:- - \| \
+ rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/overlapping_speech_ref_interior.rttm
+fi
+
+exit 0
+
+if [ $stage -le 8 ]; then
+ # Get a filter that selects only regions of speech
+ $train_cmd $dir/log/get_speech_filter.log \
+ segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \
+ 
"ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=0 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 \ + ark:- "ark:| gzip -c > $dir/manual_segments_speech_regions.seg.gz" +fi + +hyp_dir=${nnet_dir}/segmentation_ami_sdm1_dev_whole_bp/ami_sdm1_dev + +if [ $stage -le 9 ]; then + steps/segmentation/do_segmentation_data_dir.sh --reco-nj 18 \ + --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp --do-downsampling true \ + --extra-left-context $extra_left_context --extra-right-context $extra_right_context \ + --output-name output-speech --frame-subsampling-factor $frame_subsampling_factor --iter $iter \ + --stage $segmentation_stage \ + $src_dir/data/sdm1/dev $nnet_dir mfcc_hires_bp $hyp_dir +fi + +sad_dir=${nnet_dir}/sad_ami_sdm1_dev_whole_bp/ +hyp_dir=${hyp_dir}_seg + +if [ $stage -le 10 ]; then + utils/data/get_reco2utt.sh $src_dir/data/sdm1/dev_ihmdata_oraclespk + utils/data/get_reco2utt.sh $hyp_dir + + segmentation-init-from-segments --shift-to-zero=false --frame-overlap=0.0 $hyp_dir/segments ark:- | \ + segmentation-combine-segments-to-recordings ark:- ark,t:$hyp_dir/reco2utt ark:- | \ + segmentation-to-ali --length-tolerance=48 --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ + ark:- ark:- | \ + segmentation-init-from-ali ark:- ark:- | \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel ark:- $hyp_dir/sys.rttm + + #steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ + # $hyp_dir/utt2spk \ + # $hyp_dir/segments \ + # $dir/reco2file_and_channel \ + # /dev/stdout | spkr2sad.pl > $hyp_dir/sys.rttm +fi + 
+if [ $stage -le 11 ]; then + cat < $likes_dir/log_likes.JOB.gz" + cp $sad_dir/num_jobs $likes_dir + fi + else + if [ $stage -le 12 ]; then + steps/segmentation/do_segmentation_data_dir_generic.sh --reco-nj 18 \ + --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp --do-downsampling true \ + --extra-left-context $extra_left_context --extra-right-context $extra_right_context \ + --segmentation-config conf/segmentation_ovlp.conf \ + --output-name output-overlapping_sad \ + --min-durations 30:10:10 --priors 0.5:0.35:0.15 \ + --sad-name ovlp_sad --segmentation-name segmentation_ovlp_sad \ + --frame-subsampling-factor $frame_subsampling_factor --iter $iter \ + --stage $segmentation_stage \ + $src_dir/data/sdm1/dev $nnet_dir mfcc_hires_bp $hyp_dir + fi + + likes_dir=${nnet_dir}/ovlp_sad_ami_sdm1_dev_whole_bp/ + fi + + hyp_dir=${hyp_dir}_seg + mkdir -p $hyp_dir + + seg_dir=${nnet_dir}/segmentation_ovlp_sad_ami_sdm1_dev_whole_bp/ + lang=${seg_dir}/lang + + if [ $stage -le 14 ]; then + mkdir -p $lang + steps/segmentation/internal/prepare_sad_lang.py \ + --phone-transition-parameters="--phone-list=1 --min-duration=10 --end-transition-probability=0.1" \ + --phone-transition-parameters="--phone-list=2 --min-duration=3 --end-transition-probability=0.1" \ + --phone-transition-parameters="--phone-list=3 --min-duration=3 --end-transition-probability=0.1" $lang + cp $lang/phones.txt $lang/words.txt + + feat_dim=2 # dummy. We don't need this. 
+ $train_cmd $seg_dir/log/create_transition_model.log gmm-init-mono \ + $lang/topo $feat_dim - $seg_dir/tree \| \ + copy-transition-model --binary=false - $seg_dir/trans.mdl || exit 1 +fi + + if [ $stage -le 15 ]; then + + cat > $lang/word2prior < $lang/G.fst +fi + + if [ $stage -le 16 ]; then + $train_cmd $seg_dir/log/make_vad_graph.log \ + steps/segmentation/internal/make_sad_graph.sh --iter trans \ + $lang $seg_dir $seg_dir/graph_test || exit 1 + fi + + if [ $stage -le 17 ]; then + steps/segmentation/decode_sad.sh \ + --acwt 1 --beam 10 --max-active 7000 --iter trans \ + $seg_dir/graph_test $likes_dir $seg_dir + fi + + if [ $stage -le 18 ]; then + cat < $hyp_dir/labels_map +1 0 +2 1 +3 2 +EOF + gunzip -c $seg_dir/ali.*.gz | \ + segmentation-init-from-ali ark:- ark:- | \ + segmentation-copy --frame-subsampling-factor=$frame_subsampling_factor \ + --label-map=$hyp_dir/labels_map ark:- ark:- | \ + segmentation-to-rttm --map-to-speech-and-sil=false \ + --reco2file-and-channel=$dir/reco2file_and_channel ark:- $hyp_dir/sys.rttm + fi + # Get RTTM for overlapped speech detection with 3 classes + # 0 -> SILENCE, 1 -> SINGLE_SPEAKER, 2 -> OVERLAP + $train_cmd $dir/log/get_overlapping_rttm.log \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/dev/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 ark:- 
$dir/overlapping_speech_ref.rttm + + if [ $stage -le 19 ]; then + cat < \ + $src_dir/data/ihm/train/utt2reco + awk '{print $1" "$2}' $src_dir/data/sdm1/train/segments > \ + $src_dir/data/sdm1/train/utt2reco + + cat $src_dir/data/sdm1/train_ihmdata/ihmutt2utt | \ + utils/apply_map.pl -f 1 $src_dir/data/ihm/train/utt2reco | \ + utils/apply_map.pl -f 2 $src_dir/data/sdm1/train/utt2reco | \ + sort -u > $src_dir/data/sdm1/train_ihmdata/ihm2sdm_reco + ) +fi + +if [ $stage -le 2 ]; then + ( + cd $src_dir + + utils/copy_data_dir.sh $src_dir/data/sdm1/train_ihmdata \ + $src_dir/data/sdm1/train_ihmdata_oraclespk + + cat $src_dir/data/ihm/train/utt2spk | \ + utils/apply_map.pl -f 1 $src_dir/data/sdm1/train_ihmdata/ihmutt2utt > \ + $src_dir/data/sdm1/train_ihmdata_oraclespk/utt2spk.temp + + cat $src_dir/data/sdm1/train_ihmdata_oraclespk/utt2spk.temp | \ + awk '{print $1" "$2"-"$1}' > \ + $src_dir/data/sdm1/train_ihmdata_oraclespk/utt2newutt + + utils/apply_map.pl -f 1 $src_dir/data/sdm1/train_ihmdata_oraclespk/utt2newutt \ + < $src_dir/data/sdm1/train_ihmdata_oraclespk/utt2spk.temp > \ + $src_dir/data/sdm1/train_ihmdata_oraclespk/utt2spk + + for f in feats.scp segments text; do + utils/apply_map.pl -f 1 $src_dir/data/sdm1/train_ihmdata_oraclespk/utt2newutt \ + < $src_dir/data/sdm1/train_ihmdata/$f > \ + $src_dir/data/sdm1/train_ihmdata_oraclespk/$f + done + + rm $src_dir/data/sdm1/train_ihmdata_oraclespk/{spk2utt,cmvn.scp} + utils/fix_data_dir.sh \ + $src_dir/data/sdm1/train_ihmdata_oraclespk + + utils/data/get_reco2utt.sh $src_dir/data/sdm1/train_ihmdata_oraclespk + ) +fi + +phone_map=$dir/phone_map +if [ $stage -le 2 ]; then + steps/segmentation/get_sad_map.py \ + $src_dir/data/lang | utils/sym2int.pl -f 1 $src_dir/data/lang/phones.txt > \ + $phone_map +fi + +if [ -z $ali_dir ]; then + if [ $stage -le 3 ]; then + # Expecting user to have run local/run_cleanup_segmentation.sh in $src_dir + ( + cd $src_dir + steps/align_fmllr.sh --nj 18 --cmd "$train_cmd" \ + 
data/sdm1/train_ihmdata_oraclespk data/lang \ + exp/ihm/tri3_cleaned \ + exp/sdm1/tri3_cleaned_train_ihmdata_oraclespk + ) + fi + ali_dir=exp/sdm1/tri3_cleaned_ali_train_ihmdata_oraclespk +fi + +if [ $stage -le 4 ]; then + steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$train_cmd" \ + $ali_dir $phone_map $dir +fi + +echo "A 1" > $dir/channel_map +cat $src_dir/data/sdm1/train/reco2file_and_channel | \ + utils/apply_map.pl -f 3 $dir/channel_map > $dir/reco2file_and_channel + +# Map each IHM recording to a unique integer id. +# This will be the "speaker label" as each recording is assumed to have a +# single speaker. +cat $src_dir/data/sdm1/train_ihmdata_oraclespk/reco2utt | \ + awk 'BEGIN{i=1} {print $1" "1":"i; i++;}' > \ + $src_dir/data/sdm1/train_ihmdata_oraclespk/reco.txt +if [ $stage -le 5 ]; then + utils/data/get_reco2num_frames.sh --frame-shift 0.01 --frame-overlap 0.015 \ + --cmd "$train_cmd" --nj 18 \ + $src_dir/data/sdm1/train + + # Get a filter that changes the first and the last segment region outside + # the manual segmentation (usually some preparation lines) that are not + # transcribed. + $train_cmd $dir/log/interior_regions.log \ + segmentation-init-from-segments --shift-to-zero=false --frame-overlap=0.0 $src_dir/data/sdm1/train/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/train/reco2utt ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ + "ark:segmentation-init-from-lengths --label=0 ark,t:$src_dir/data/sdm1/train/reco2num_frames ark:- |" ark:- ark,t:- \| \ + perl -ane '$F[3] = 10000; $F[$#F-1] = 10000; print join(" ", @F) . 
"\n";' \| \ + segmentation-post-process --merge-labels=0:1 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-post-process --merge-labels=10000 --merge-dst-label=0 --merge-adjacent-segments \ + --max-intersegment-length=1000000 ark,t:- \ + "ark:| gzip -c > $dir/interior_regions.seg.gz" + + $train_cmd $dir/log/get_manual_segments_regions.log \ + segmentation-init-from-segments --shift-to-zero=false --frame-overlap=0.0 $src_dir/data/sdm1/train/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/train/reco2utt ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=1 \ + "ark:segmentation-init-from-lengths --label=0 ark,t:$src_dir/data/sdm1/train/reco2num_frames ark:- |" ark:- ark,t:- \| \ + perl -ane '$F[3] = 10000; $F[$#F-1] = 10000; print join(" ", @F) . "\n";' \| \ + segmentation-post-process --merge-labels=0:1 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-post-process --merge-labels=10000 --merge-dst-label=0 --merge-adjacent-segments \ + --max-intersegment-length=10000 ark,t:- \ + "ark:| gzip -c > $dir/manual_segments_regions.seg.gz" +fi + +if [ $stage -le 6 ]; then + # Reference RTTM where SPEECH frames are obtainted by combining IHM VAD alignments + $train_cmd $dir/log/get_ref_spk_seg.log \ + segmentation-combine-segments scp:$dir/sad_seg.scp \ + "ark:segmentation-init-from-segments --shift-to-zero=false --frame-overlap=0.0 $src_dir/data/sdm1/train_ihmdata_oraclespk/segments ark:- |" \ + ark,t:$src_dir/data/sdm1/train_ihmdata_oraclespk/reco2utt ark:- \| \ + segmentation-copy --keep-label=1 ark:- ark:- \| \ + segmentation-copy --utt2label-map-rspecifier=ark,t:$src_dir/data/sdm1/train_ihmdata/reco.txt \ + ark:- ark:- \| \ + segmentation-merge-recordings \ + "ark,t:utils/utt2spk_to_spk2utt.pl $src_dir/data/sdm1/train_ihmdata/ihm2sdm_reco |" \ + ark:- "ark:| gzip -c > $dir/ref_spk_seg.gz" +fi + +if [ $stage -le 7 ]; then + # To get the actual RTTM, we need to add no-score + $train_cmd 
$dir/log/get_ref_spk_rttm_manual_seg.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-copy --keep-label=0 "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-labels=0 --merge-dst-label=10000 ark:- ark:- \| \ + segmentation-merge "ark:gunzip -c $dir/ref_spk_seg.gz |" ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --map-to-speech-and-sil=false --no-score-label=10000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_spk_manual_seg.rttm + + $train_cmd $dir/log/get_ref_spk_rttm_interior.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-copy --keep-label=0 "ark:gunzip -c $dir/interior_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-labels=0 --merge-dst-label=10000 ark:- ark:- \| \ + segmentation-merge "ark:gunzip -c $dir/ref_spk_seg.gz |" ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --map-to-speech-and-sil=false --no-score-label=10000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_spk_interior.rttm + + $train_cmd $dir/log/get_ref_rttm_manual_seg.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/train/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm 
--reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_manual_seg.rttm + + $train_cmd $dir/log/get_ref_rttm_interior.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/train/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + ark:- "ark:gunzip -c $dir/interior_regions.seg.gz |" ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + --no-score-label=10000 ark:- - \| \ + rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/ref_interior.rttm + + # Get RTTM for overlapped speech detection with 3 classes + # 0 -> SILENCE, 1 -> SINGLE_SPEAKER, 2 -> OVERLAP + $train_cmd $dir/log/get_overlapping_rttm_manual_seg.log \ + export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \ + segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/train/reco2num_frames \ + "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \ + ark:/dev/null ark:- ark:/dev/null \| \ + segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \ + --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \ + ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \ + segmentation-post-process 
 --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \
+ segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \
+ --no-score-label=10000 ark:- - \| \
+ rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/overlapping_speech_ref_manual_seg.rttm
+
+ $train_cmd $dir/log/get_overlapping_rttm_interior.log \
+ export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH '&&' \
+ segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/train/reco2num_frames \
+ "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \
+ ark:/dev/null ark:- ark:/dev/null \| \
+ segmentation-init-from-ali ark:- ark:- \| \
+ segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 \
+ --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \
+ segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \
+ ark:- "ark:gunzip -c $dir/interior_regions.seg.gz |" ark:- \| \
+ segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \
+ segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \
+ --no-score-label=10000 ark:- - \| \
+ rttmSmooth.pl -s 0 \| rttmSort.pl '>' $dir/overlapping_speech_ref_interior.rttm
+fi
+
+exit 0
+
+if [ $stage -le 8 ]; then
+ # Get a filter that selects only regions of speech
+ $train_cmd $dir/log/get_speech_filter.log \
+ segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/train/reco2num_frames \
+ "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \
+ ark:/dev/null ark:- ark:/dev/null \| \
+ segmentation-init-from-ali ark:- ark:- \| \
+ segmentation-post-process --merge-labels=1:2:3:4:5:6:7:8:9:10 --merge-dst-label=1 ark:- ark:- \| \
+ segmentation-create-subsegments --filter-label=0 --subsegment-label=0 \
+ ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" 
ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 \ + ark:- "ark:| gzip -c > $dir/manual_segments_speech_regions.seg.gz" +fi + +hyp_dir=${nnet_dir}/segmentation_ovlp_ami_sdm1_train_whole_bp/ami_sdm1_train + +if [ $stage -le 12 ]; then + steps/segmentation/do_segmentation_data_dir_generic.sh --reco-nj 18 \ + --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp --do-downsampling true \ + --extra-left-context $extra_left_context --extra-right-context $extra_right_context \ + --segmentation-config conf/segmentation_ovlp.conf \ + --output-name output-overlapping_sad \ + --min-durations 30:10:10 --priors 0.5:0.35:0.15 \ + --sad-name ovlp_sad --segmentation-name segmentation_ovlp_sad \ + --frame-subsampling-factor $frame_subsampling_factor --iter $iter \ + --stage $segmentation_stage \ + $src_dir/data/sdm1/train $nnet_dir mfcc_hires_bp $hyp_dir +fi + +likes_dir=${nnet_dir}/ovlp_sad_ami_sdm1_train_whole_bp/ + +hyp_dir=${hyp_dir}_seg +mkdir -p $hyp_dir + +seg_dir=${nnet_dir}/segmentation_ovlp_sad_ami_sdm1_train_whole_bp/ +lang=${seg_dir}/lang + +if [ $stage -le 14 ]; then +mkdir -p $lang +steps/segmentation/internal/prepare_sad_lang.py \ + --phone-transition-parameters="--phone-list=1 --min-duration=10 --end-transition-probability=0.1" \ + --phone-transition-parameters="--phone-list=2 --min-duration=3 --end-transition-probability=0.1" \ + --phone-transition-parameters="--phone-list=3 --min-duration=3 --end-transition-probability=0.1" $lang +cp $lang/phones.txt $lang/words.txt + +feat_dim=2 # dummy. We don't need this. 
+$train_cmd $seg_dir/log/create_transition_model.log gmm-init-mono \
+ $lang/topo $feat_dim - $seg_dir/tree \| \
+ copy-transition-model --binary=false - $seg_dir/trans.mdl || exit 1
+fi
+
+if [ $stage -le 15 ]; then
+
+cat > $lang/word2prior < $lang/G.fst
+fi
+
+if [ $stage -le 16 ]; then
+ $train_cmd $seg_dir/log/make_vad_graph.log \
+ steps/segmentation/internal/make_sad_graph.sh --iter trans \
+ $lang $seg_dir $seg_dir/graph_test || exit 1
+fi
+
+if [ $stage -le 17 ]; then
+ steps/segmentation/decode_sad.sh \
+ --acwt 1 --beam 10 --max-active 7000 --iter trans \
+ $seg_dir/graph_test $likes_dir $seg_dir
+fi
+
+if [ $stage -le 18 ]; then
+ cat <<EOF > $hyp_dir/labels_map
+1 0
+2 1
+3 2
+EOF
+ gunzip -c $seg_dir/ali.*.gz | \
+ segmentation-init-from-ali ark:- ark:- | \
+ segmentation-copy --frame-subsampling-factor=$frame_subsampling_factor \
+ --label-map=$hyp_dir/labels_map ark:- ark:- | \
+ segmentation-to-rttm --map-to-speech-and-sil=false \
+ --reco2file-and-channel=$dir/reco2file_and_channel ark:- $hyp_dir/sys.rttm
+fi
+# Get RTTM for overlapped speech detection with 3 classes
+# 0 -> SILENCE, 1 -> SINGLE_SPEAKER, 2 -> OVERLAP
+$train_cmd $dir/log/get_overlapping_rttm.log \
+ segmentation-get-stats --lengths-rspecifier=ark,t:$src_dir/data/sdm1/train/reco2num_frames \
+ "ark:gunzip -c $dir/ref_spk_seg.gz | segmentation-post-process --remove-labels=0 ark:- ark:- |" \
+ ark:/dev/null ark:- ark:/dev/null \| \
+ segmentation-init-from-ali ark:- ark:- \| \
+ segmentation-post-process --merge-labels=2:3:4:5:6:7:8:9:10 --merge-dst-label=2 ark:- ark:- \| \
+ segmentation-create-subsegments --filter-label=0 --subsegment-label=10000 \
+ ark:- "ark:gunzip -c $dir/manual_segments_regions.seg.gz |" ark:- \| \
+ segmentation-post-process --merge-adjacent-segments --max-intersegment-length=10000 ark:- ark:- \| \
+ segmentation-to-rttm --map-to-speech-and-sil=false --reco2file-and-channel=$dir/reco2file_and_channel \
+ --no-score-label=10000 ark:- $dir/overlapping_speech_ref.rttm
+
+if [ $stage -le 19 ]; 
then + cat < 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_ovlp_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-2,-1,0,1,2) + + relu-renorm-layer name=tdnn1 input=Append(input@-2, input@-1, input, input@1, input@2) dim=256 + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=256 + fast-lstmp-layer name=lstm1 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-3 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) dim=256 + fast-lstmp-layer name=lstm2 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-6 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=$speech_scale input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.05 + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print $speech_scale / $num_snr_bins"` input=lstm2 max-change=0.75 learning-rate-factor=0.5 + + output-layer name=output-overlapping_sad include-log-softmax=true dim=3 objective-scale=$ovlp_scale input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapping_sad.txt max-change=0.75 learning-rate-factor=0.02 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ 
$(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{01,02,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{01,02,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_sad_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapping_sad --target-type=sparse --dim=3 --targets-scp=$ovlp_sad_data_dir/overlapping_sad_labels_fixed.scp --deriv-weights-scp=$ovlp_sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + 
${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1a.sh new file mode 100644 index 00000000000..4f0754d8355 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1a.sh @@ -0,0 +1,267 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1e, but removes the stats component in the 3rd layer. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=80 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +music_data_dir=data/train_aztec_unsad_whole_music_corrupted_sp_hires_bp + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-3 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) dim=$relu_dim + fast-lstmp-layer name=lstm2 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn5 input=Append(-12,0,12,24) dim=$relu_dim + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print (($num_frames_music / $num_frames_sad) ** 0.25) / $num_snr_bins"` input=tdnn5 + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=`perl -e "print (($num_frames_music / $num_frames_sad) ** 0.25)"` input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn5 + + output name=output-temp input=Append(input@-2,input@-1,input,input@1,input@2) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + 
--trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1b.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1b.sh new file mode 100644 index 00000000000..cbbb016607a --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1b.sh @@ -0,0 +1,265 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1e, but removes the stats component in the 3rd layer. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=80 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +music_data_dir=data/train_aztec_unsad_whole_music_corrupted_sp_hires_bp + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1b + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn3 input=Append(-12,0,12) dim=$relu_dim + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print (($num_frames_music / $num_frames_sad) ** 0.25) / $num_snr_bins"` input=tdnn3 + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=`perl -e "print (($num_frames_music / $num_frames_sad) ** 0.25)"` input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn3 + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + 
--trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1c.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1c.sh new file mode 100644 index 00000000000..53c2a7a47ac --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1c.sh @@ -0,0 +1,265 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1e, but removes the stats component in the 3rd layer. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=80 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +music_data_dir=data/train_aztec_unsad_whole_music_corrupted_sp_hires_bp + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1c + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn3 input=Append(-12,0,12) dim=$relu_dim + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=`perl -e "print $speech_scale / $num_snr_bins"` input=tdnn3 + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn3 + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + 
--trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1e.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1e.sh new file mode 100644 index 00000000000..dfb1297c895 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1e.sh @@ -0,0 +1,269 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1c, but uses larger amount of data. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1e + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn3 input=Append(-12,0,12) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn3 + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + #--targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + 
--trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1f.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1f.sh new file mode 100644 index 00000000000..782a31132c6 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1f.sh @@ -0,0 +1,291 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1c, but uses larger amount of data. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1b + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + +cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn3 input=Append(-12,0,12) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn3 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn3 + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + #--targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + 
--egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1g.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1g.sh new file mode 100644 index 00000000000..eea5956e005 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1g.sh @@ -0,0 +1,291 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1c, but uses larger amount of data. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1g + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + +cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn3 input=Append(-12,0,12) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn3 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn3 + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + #--targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + 
--egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1h.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1h.sh new file mode 100644 index 00000000000..d9e1966bf6a --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1h.sh @@ -0,0 +1,291 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1c, but uses larger amount of data. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1h + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + +cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + +utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn3 input=Append(-12,0,12) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn3 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn3 + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + #--targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + 
--egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1i.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1i.sh new file mode 100644 index 00000000000..be568eefd97 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_1i.sh @@ -0,0 +1,308 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1h, but has more layers. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1i + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) dim=$relu_dim + relu-renorm-layer name=tdnn5 input=Append(-12,0,12,24) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn5 + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node 
old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + 
--egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1h.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1h.sh new file mode 100644 index 00000000000..ae85a93a7fc --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1h.sh @@ -0,0 +1,306 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1h + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn3 input=Append(-12,0,12) dim=$relu_dim + relu-renorm-layer name=tdnn3-snr input=Append(lstm1@-12,lstm1@0,lstm1@12,tdnn3) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn3 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn3 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn3-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; 
then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + 
--egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1i.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1i.sh new file mode 100644 index 00000000000..b6c43a92992 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1i.sh @@ -0,0 +1,315 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1h, but has more layers. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1i + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + snr_scale=`echo $scales | awk '{print $4}'` + + num_snr_bins=`feat-to-dim scp:$sad_data_dir/irm_targets.scp -` + snr_scale=`perl -e "print $snr_scale / $num_snr_bins"` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) dim=$relu_dim + relu-renorm-layer name=tdnn5 input=Append(-12,0,12,24) dim=$relu_dim + relu-renorm-layer name=tdnn5-snr input=Append(lstm1@-6,lstm1@0,lstm1@6,lstm1@12,tdnn5) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 
objective-scale=$speech_music_scale input=tdnn5 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn5-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + 
--targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd 
"$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1j.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1j.sh new file mode 100644 index 00000000000..bf397565148 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1j.sh @@ -0,0 +1,312 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. 
+# This is same as 1i, but removes the speech-music output. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1j + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + snr_scale=`echo $scales | awk '{print $4}'` + + num_snr_bins=`feat-to-dim scp:$sad_data_dir/irm_targets.scp -` + snr_scale=`perl -e "print $snr_scale / $num_snr_bins"` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) dim=$relu_dim + relu-renorm-layer name=tdnn5 input=Append(-12,0,12,24) dim=$relu_dim + relu-renorm-layer name=tdnn5-snr input=Append(lstm1@-6,lstm1@0,lstm1@6,lstm1@12,tdnn5) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn5-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py 
--xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + 
--egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1k.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1k.sh new file mode 100644 index 00000000000..cb585523f74 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1k.sh @@ -0,0 +1,316 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1h, but has more layers. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1k + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + snr_scale=`echo $scales | awk '{print $4}'` + + num_snr_bins=`feat-to-dim scp:$sad_data_dir/irm_targets.scp -` + snr_scale=`perl -e "print $snr_scale / $num_snr_bins"` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) dim=$relu_dim + fast-lstmp-layer name=lstm2 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn5 input=Append(-12,0,12,24) dim=$relu_dim + relu-renorm-layer name=tdnn5-snr input=Append(lstm2@-6,lstm2@0,lstm2@6,lstm2@12,tdnn5) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-speech_music 
include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn5 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn5-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + 
--egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1l.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1l.sh new file mode 100644 index 00000000000..d8910053e61 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_sad_music_snr_1l.sh @@ -0,0 +1,316 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1h, but has more layers. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=40 +extra_right_context=0 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1l + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_lstm_sad_music_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + snr_scale=`echo $scales | awk '{print $4}'` + + num_snr_bins=`feat-to-dim scp:$sad_data_dir/irm_targets.scp -` + snr_scale=`perl -e "print $snr_scale / $num_snr_bins"` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-3 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) dim=$relu_dim + fast-lstmp-layer name=lstm2 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim delay=-6 + relu-renorm-layer name=tdnn5 input=Append(-12,0,12,24) dim=$relu_dim + relu-renorm-layer name=tdnn5-snr input=Append(lstm2@-6,lstm2@0,lstm2@6,lstm2@12,tdnn5) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-speech_music 
include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn5 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn5-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + 
--egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_overlap_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_overlap_1a.sh new file mode 100755 index 00000000000..adc4fc81c08 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_overlap_1a.sh @@ -0,0 +1,202 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +# This scripts is similar to 1f but adds max-change=0.75 and learning-rate-factor=0.02 to the final affine. +# And changed relu-dim to 512. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-overlapped_speech ark:- ark:- |" # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=f + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_ovlp/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$ovlp_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=512 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=512 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=512 + relu-renorm-layer name=tdnn4 dim=512 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.02 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapped_speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_ovlp + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + 
--trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$ovlp_data_dir \ + --targets-scp="$ovlp_data_dir/overlapped_speech_labels.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_1a.sh new file mode 100755 index 00000000000..52a15686d28 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_1a.sh @@ -0,0 +1,259 @@ +#!/bin/bash + +# This is a script to train a LSTM for overlapped speech activity detection +# and SAD. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 # Maximum left context in egs apart from TDNN's left context +extra_right_context=0 # Maximum right context in egs apart from TDNN's right context + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-2,-1,0,1,2) + + relu-renorm-layer name=tdnn1 input=Append(input@-2, input@-1, input, input@1, input@2) dim=512 + lstmp-layer name=lstm1 cell-dim=512 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=512 + lstmp-layer name=lstm2 cell-dim=512 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-6 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=$speech_scale input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.05 + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print $speech_scale / $num_snr_bins"` input=lstm2 max-change=0.75 learning-rate-factor=0.5 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 objective-scale=$ovlp_scale input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.02 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{01,02,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{01,02,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + 
${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_ami_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_ami_1a.sh new file mode 100644 index 00000000000..d003f746c4b --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_ami_1a.sh @@ -0,0 +1,192 @@ +#!/bin/bash + +# This is a script to train a LSTM for overlapped speech activity detection +# and SAD. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=128 + +extra_left_context=40 # Maximum left context in egs apart from TDNN's left context +extra_right_context=0 # Maximum right context in egs apart from TDNN's right context + +# training options +num_epochs=8 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +data_dir=data/ami_sdm1_train_whole_hires_bp +labels_scp=exp/sad_ami_sdm1_train/ref/overlapping_sad_labels.scp +deriv_weights_scp=exp/sad_ami_sdm1_train/ref/deriv_weights_for_overlapping_sad.scp + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); $n = ($n > 4000 ? 4000 : $n); print ($n < 6 ? 6 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); $n = ($n > 4000 ? 4000 : $n); print ($n < 6 ? 6 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_ovlp_sad_ami/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$data_dir/feats.scp -` name=input + output name=output-temp input=Append(-2,-1,0,1,2) + + relu-renorm-layer name=tdnn1 input=Append(input@-2, input@-1, input, input@1, input@2) dim=256 + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=256 + lstmp-layer name=lstm1 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-3 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6) dim=256 + lstmp-layer name=lstm2 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-6 + + output-layer name=output-overlapping_sad include-log-softmax=true dim=3 input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapping_sad.txt learning-rate-factor=0.05 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapping_sad new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_overlapping_sad + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_overlapping_sad/storage ]; then + utils/create_split_dir.pl \ + /export/b{01,02,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_overlapping_sad/storage $dir/egs_overlapping_sad/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-overlapping_sad --target-type=sparse --dim=3 --targets-scp=$labels_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_overlapping_sad + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=false --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + 
--use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$data_dir \ + --targets-scp="$labels_scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_ami_1b.sh b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_ami_1b.sh new file mode 100644 index 00000000000..3aa4f28f99a --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_lstm_stats_sad_overlap_ami_1b.sh @@ -0,0 +1,192 @@ +#!/bin/bash + +# This is a script to train a LSTM for overlapped speech activity detection +# and SAD. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=128 + +extra_left_context=40 # Maximum left context in egs apart from TDNN's left context +extra_right_context=0 # Maximum right context in egs apart from TDNN's right context + +# training options +num_epochs=8 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +data_dir=data/ami_sdm1_train_whole_hires_bp +labels_scp=exp/sad_ami_sdm1_train/ref/overlapping_sad_labels.scp +deriv_weights_scp=exp/sad_ami_sdm1_train/ref/deriv_weights_for_overlapping_sad.scp + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); $n = ($n > 4000 ? 4000 : $n); print ($n < 6 ? 
6 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); $n = ($n > 4000 ? 4000 : $n); print ($n < 6 ? 6 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_ovlp_sad_ami/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$data_dir/feats.scp -` name=input + output name=output-temp input=Append(-2,-1,0,1,2) + + relu-renorm-layer name=tdnn1 input=Append(input@-2, input@-1, input, input@1, input@2) dim=256 + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=256 + lstmp-layer name=lstm1 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-3 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6) dim=256 + lstmp-layer name=lstm2 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-6 + + output-layer name=output-overlapping_sad include-log-softmax=true dim=3 input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapping_sad.txt learning-rate-factor=0.05 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapping_sad new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_overlapping_sad + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_overlapping_sad/storage ]; then + utils/create_split_dir.pl \ + /export/b{01,02,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_overlapping_sad/storage $dir/egs_overlapping_sad/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-overlapping_sad --target-type=sparse --dim=3 --targets-scp=$labels_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_overlapping_sad + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=false --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + 
--use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$data_dir \ + --targets-scp="$labels_scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_rnn_overlap_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_rnn_overlap_1a.sh new file mode 100755 index 00000000000..e63c5d8a063 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_rnn_overlap_1a.sh @@ -0,0 +1,184 @@ +#!/bin/bash + +# This is a script to train a lstm for overlapped speech activity detection. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 # Maximum left context in egs apart from TDNN's left context +extra_right_context=0 # Maximum right context in egs apart from TDNN's right context + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-overlapped_speech ark:- ark:- |" # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=f + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_ovlp/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$ovlp_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-2,-1,0,1,2) + + relu-renorm-layer name=tdnn1 dim=256 input=Append(input@-2, input@-1, input, input@1, input@2) + lstmp-layer name=lstm1 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-3 + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=256 + lstmp-layer name=lstm2 cell-dim=256 recurrent-projection-dim=128 non-recurrent-projection-dim=128 delay=-6 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.02 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapped_speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_ovlp + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + 
--trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$ovlp_data_dir \ + --targets-scp="$ovlp_data_dir/overlapped_speech_labels.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_rnn_overlap_1b.sh b/egs/aspire/s5/local/segmentation/tuning/train_rnn_overlap_1b.sh new file mode 100755 index 00000000000..15235882f90 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_rnn_overlap_1b.sh @@ -0,0 +1,184 @@ +#!/bin/bash + +# This is a script to train an LSTM for overlapped speech activity detection. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=40 +num_chunk_per_minibatch=64 + +extra_left_context=40 # Maximum left context in egs apart from TDNN's left context +extra_right_context=0 # Maximum right context in egs apart from TDNN's right context + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-overlapped_speech ark:- ark:- |" # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=b + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +num_utts=`cat $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_ovlp/nnet_lstm +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$ovlp_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-2,-1,0,1,2) + + relu-renorm-layer name=tdnn1 dim=512 input=Append(input@-2, input@-1, input, input@1, input@2) + lstmp-layer name=lstm1 cell-dim=512 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-3 + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=512 + lstmp-layer name=lstm2 cell-dim=512 recurrent-projection-dim=256 non-recurrent-projection-dim=256 delay=-6 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=lstm2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.02 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapped_speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_ovlp + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + 
--trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$ovlp_data_dir \ + --targets-scp="$ovlp_data_dir/overlapped_speech_labels.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1f.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1f.sh new file mode 100755 index 00000000000..2201f9fd8d1 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1f.sh @@ -0,0 +1,200 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. 
+min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-overlapped_speech ark:- ark:- |" # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=f + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_ovlp/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$ovlp_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapped_speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_ovlp + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + 
--trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$ovlp_data_dir \ + --targets-scp="$ovlp_data_dir/overlapped_spech_labels.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1g.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1g.sh new file mode 100755 index 00000000000..81febb5fa09 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1g.sh @@ -0,0 +1,202 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +# This scripts is similar to 1f but adds max-change=0.75 and learning-rate-factor=0.1 to the final affine. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-overlapped_speech ark:- ark:- |" # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=f + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_ovlp/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$ovlp_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.1 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapped_speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_ovlp + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + 
--trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$ovlp_data_dir \ + --targets-scp="$ovlp_data_dir/overlapped_spech_labels.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1h.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1h.sh new file mode 100755 index 00000000000..adc4fc81c08 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1h.sh @@ -0,0 +1,202 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +# This scripts is similar to 1f but adds max-change=0.75 and learning-rate-factor=0.02 to the final affine. +# And changed relu-dim to 512. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-overlapped_speech ark:- ark:- |" # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=f + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_ovlp/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$ovlp_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=512 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=512 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=512 + relu-renorm-layer name=tdnn4 dim=512 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.02 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapped_speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_ovlp + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + 
--trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$ovlp_data_dir \ + --targets-scp="$ovlp_data_dir/overlapped_spech_labels.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1i.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1i.sh new file mode 100755 index 00000000000..dcd11ad2aa6 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_overlap_1i.sh @@ -0,0 +1,202 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +# This scripts is similar to 1f but adds max-change=0.75 and learning-rate-factor=0.02 to the final affine. +# Similar to 1g but moved stats pooling to higher layer. Changed splicing to -12 from -9. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=90 # Maximum left context in egs apart from TDNN's left context +extra_right_context=15 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-overlapped_speech ark:- ark:- |" # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=f + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_ovlp/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$ovlp_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=512 + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1) dim=512 + stats-layer name=tdnn3_stats config=mean+count(-96:6:12:96) + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-12,tdnn2,tdnn2@6, tdnn3_stats) dim=512 + relu-renorm-layer name=tdnn4 dim=512 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.02 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-overlapped_speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_ovlp + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + 
--trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$ovlp_data_dir \ + --targets-scp="$ovlp_data_dir/overlapped_spech_labels.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1a.sh new file mode 100755 index 00000000000..8242b83c747 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1a.sh @@ -0,0 +1,172 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +splice_indexes="-3,-2,-1,0,1,2,3 -6,0,mean+count(-99:3:9:99) -9,0,3 0" +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. 
+min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +num_utts_subset_valid=50 # "utts" is actually recording. So this is prettly small. +num_utts_subset_train=50 + +# target options +train_data_dir=data/train_azteec_whole_sp_corrupted_hires + +speech_feat_scp= +music_labels_scp= + +deriv_weights_scp= + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_hidden_layers=`echo $splice_indexes | perl -ane 'print scalar @F'` || exit 1 +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music/nnet_tdnn +fi + +dir=$dir${affix:+_$affix}_n${num_hidden_layers} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$train_data_dir/feats.scp -` name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + # This is disabled for now. 
+ # fixed-affine-layer name=lda input=Append(-3,-2,-1,0,1,2,3) affine-transform-file=$dir/configs/lda.mat + # the first splicing is moved before the lda layer, so no splicing here + # relu-renorm-layer name=tdnn1 dim=625 + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=256 + stats-layer name=tdnn2.stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(Offset(tdnn1, -6), tdnn1, tdnn2.stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 input=tdnn4 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn4 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ +fi + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$train_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=20000 \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_labels_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --dir=$dir/egs + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=20 \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=64 \ + --trainer.deriv-truncate-margin=8 \ + 
--trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$train_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 6 ]; then + $train_cmd JOB=1:100 $dir/log/compute_post_output-speech.JOB.log \ + extract-column "scp:utils/split_scp.pl -j 100 \$[JOB-1] $speech_feat_scp |" ark,t:- \| \ + steps/segmentation/quantize_vector.pl \| \ + ali-to-post ark,t:- ark:- \| \ + weight-post ark:- scp:$deriv_weights_scp ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-speech.vec.JOB + eval vector-sum $dir/post_output-speech.vec.{`seq -s, 100`} $dir/post_output-speech.vec + + $train_cmd JOB=1:100 $dir/log/compute_post_output-music.JOB.log \ + ali-to-post "scp:utils/split_scp.pl -j 100 \$[JOB-1] $music_labels_scp |" ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-music.vec.JOB + eval vector-sum $dir/post_output-music.vec.{`seq -s, 100`} $dir/post_output-music.vec +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1c.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1c.sh new file mode 100755 index 00000000000..163ea6df14d --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1c.sh @@ -0,0 +1,185 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. 
+# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +num_utts_subset_valid=50 # "utts" is actually recording. So this is prettly small. +num_utts_subset_train=50 + +# target options +train_data_dir=data/train_azteec_whole_sp_corrupted_hires + +speech_feat_scp= +music_labels_scp= + +deriv_weights_scp= + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$train_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-9, tdnn1@-3, tdnn1, tdnn1@3, tdnn2_stats) dim=256 + stats-layer name=tdnn3_stats config=mean+count(-108:9:27:108) + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-27, tdnn2@-9, tdnn2, tdnn2@9, tdnn3_stats) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 input=tdnn4 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn4 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$train_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=20000 \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_labels_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --dir=$dir/egs + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=20 \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=64 \ + --trainer.deriv-truncate-margin=8 \ + 
--trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$train_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 6 ]; then + $train_cmd JOB=1:100 $dir/log/compute_post_output-speech.JOB.log \ + extract-column "scp:utils/split_scp.pl -j 100 \$[JOB-1] $speech_feat_scp |" ark,t:- \| \ + steps/segmentation/quantize_vector.pl \| \ + ali-to-post ark,t:- ark:- \| \ + weight-post ark:- scp:$deriv_weights_scp ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-speech.vec.JOB + eval vector-sum $dir/post_output-speech.vec.{`seq -s, 100`} $dir/post_output-speech.vec + + $train_cmd JOB=1:100 $dir/log/compute_post_output-music.JOB.log \ + ali-to-post "scp:utils/split_scp.pl -j 100 \$[JOB-1] $music_labels_scp |" ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-music.vec.JOB + eval vector-sum $dir/post_output-music.vec.{`seq -s, 100`} $dir/post_output-music.vec +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1d.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1d.sh new file mode 100755 index 00000000000..a013fcc49a7 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1d.sh @@ -0,0 +1,184 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. 
+# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +num_utts_subset_valid=50 # "utts" is actually recording. So this is prettly small. +num_utts_subset_train=50 + +# target options +train_data_dir=data/train_azteec_whole_sp_corrupted_hires + +speech_feat_scp= +music_labels_scp= + +deriv_weights_scp= + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$train_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-9, tdnn1@-3, tdnn1, tdnn1@3, tdnn2_stats) dim=256 + stats-layer name=tdnn3_stats config=mean+count(-108:9:27:108) + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-27, tdnn2@-9, tdnn2, tdnn2@9, tdnn3_stats) dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn3 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$train_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=20000 \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_labels_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --dir=$dir/egs + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=20 \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=64 \ + --trainer.deriv-truncate-margin=8 \ + 
--trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$train_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 6 ]; then + $train_cmd JOB=1:100 $dir/log/compute_post_output-speech.JOB.log \ + extract-column "scp:utils/split_scp.pl -j 100 \$[JOB-1] $speech_feat_scp |" ark,t:- \| \ + steps/segmentation/quantize_vector.pl \| \ + ali-to-post ark,t:- ark:- \| \ + weight-post ark:- scp:$deriv_weights_scp ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-speech.vec.JOB + eval vector-sum $dir/post_output-speech.vec.{`seq -s, 100`} $dir/post_output-speech.vec + + $train_cmd JOB=1:100 $dir/log/compute_post_output-music.JOB.log \ + ali-to-post "scp:utils/split_scp.pl -j 100 \$[JOB-1] $music_labels_scp |" ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-music.vec.JOB + eval vector-sum $dir/post_output-music.vec.{`seq -s, 100`} $dir/post_output-music.vec +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1e.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1e.sh new file mode 100755 index 00000000000..703865b8ad5 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1e.sh @@ -0,0 +1,229 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1d, but add add-log-stddev to norm layers. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. 
+# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=79 # Maximum left context in egs apart from TDNN's left context +extra_right_context=11 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=79 +min_extra_right_context=11 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +num_utts_subset_valid=50 # "utts" is actually recording. So this is prettly small. +num_utts_subset_train=50 + +# target options +train_data_dir=data/train_aztec_small_unsad_whole_all_corrupted_sp_hires_bp + +speech_feat_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/speech_feat.scp +deriv_weights_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/deriv_weights.scp +music_labels_scp=data/train_aztec_small_unsad_whole_music_corrupted_sp_hires_bp/music_labels.scp + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$train_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 add-log-stddev=true + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-9, tdnn1@-3, tdnn1, tdnn1@3, tdnn2_stats) dim=256 add-log-stddev=true + stats-layer name=tdnn3_stats config=mean+count(-108:9:27:108) + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-27, tdnn2@-9, tdnn2, tdnn2@9, tdnn3_stats) dim=256 add-log-stddev=true + + output-layer name=output-speech include-log-softmax=true dim=2 input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn3 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` +speech_data_dir=$dir/`basename $train_data_dir`_speech +music_data_dir=$dir/`basename $train_data_dir`_music + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + + . 
$dir/configs/vars + + utils/subset_data_dir.sh --utt-list $speech_feat_scp ${train_data_dir} $dir/`basename ${train_data_dir}`_speech + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$speech_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + utils/subset_data_dir.sh --utt-list $music_labels_scp ${train_data_dir} $dir/`basename ${train_data_dir}`_music + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_labels_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + steps/nnet3/multilingual/get_egs.sh \ + --minibatch-size $[chunk_width * num_chunk_per_minibatch] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + 
--trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$train_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 6 ]; then + $train_cmd JOB=1:100 $dir/log/compute_post_output-speech.JOB.log \ + extract-column "scp:utils/split_scp.pl -j 100 \$[JOB-1] $speech_feat_scp |" ark,t:- \| \ + steps/segmentation/quantize_vector.pl \| \ + ali-to-post ark,t:- ark:- \| \ + weight-post ark:- scp:$deriv_weights_scp ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-speech.vec.JOB + eval vector-sum $dir/post_output-speech.vec.{`seq -s, 100`} $dir/post_output-speech.vec + + $train_cmd JOB=1:100 $dir/log/compute_post_output-music.JOB.log \ + ali-to-post "scp:utils/split_scp.pl -j 100 \$[JOB-1] $music_labels_scp |" ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-music.vec.JOB + eval vector-sum $dir/post_output-music.vec.{`seq -s, 100`} $dir/post_output-music.vec +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1f.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1f.sh new file mode 100755 index 00000000000..0afdd0072ac --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1f.sh @@ -0,0 +1,227 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. 
+# This script is the same as 1e, but removes the stats component. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=79 # Maximum left context in egs apart from TDNN's left context +extra_right_context=11 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +num_utts_subset_valid=50 # "utts" are actually recordings. So this is pretty small. +num_utts_subset_train=50 + +# target options +train_data_dir=data/train_aztec_small_unsad_whole_all_corrupted_sp_hires_bp + +speech_feat_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/speech_feat.scp +deriv_weights_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/deriv_weights.scp +music_labels_scp=data/train_aztec_small_unsad_whole_music_corrupted_sp_hires_bp/music_labels.scp + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. 
./utils/parse_options.sh + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$train_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-9, tdnn1@-3, tdnn1, tdnn1@3) dim=256 add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-27, tdnn2@-9, tdnn2, tdnn2@9) dim=256 add-log-stddev=true + + output-layer name=output-speech include-log-softmax=true dim=2 input=tdnn3 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn3 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` +speech_data_dir=$dir/`basename $train_data_dir`_speech +music_data_dir=$dir/`basename $train_data_dir`_music + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + + . 
$dir/configs/vars + + utils/subset_data_dir.sh --utt-list $speech_feat_scp ${train_data_dir} $dir/`basename ${train_data_dir}`_speech + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$speech_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + utils/subset_data_dir.sh --utt-list $music_labels_scp ${train_data_dir} $dir/`basename ${train_data_dir}`_music + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_labels_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + steps/nnet3/multilingual/get_egs.sh \ + --minibatch-size $[chunk_width * num_chunk_per_minibatch] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + 
--trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$train_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 6 ]; then + $train_cmd JOB=1:100 $dir/log/compute_post_output-speech.JOB.log \ + extract-column "scp:utils/split_scp.pl -j 100 \$[JOB-1] $speech_feat_scp |" ark,t:- \| \ + steps/segmentation/quantize_vector.pl \| \ + ali-to-post ark,t:- ark:- \| \ + weight-post ark:- scp:$deriv_weights_scp ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-speech.vec.JOB + eval vector-sum $dir/post_output-speech.vec.{`seq -s, 100`} $dir/post_output-speech.vec + + $train_cmd JOB=1:100 $dir/log/compute_post_output-music.JOB.log \ + ali-to-post "scp:utils/split_scp.pl -j 100 \$[JOB-1] $music_labels_scp |" ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-music.vec.JOB + eval vector-sum $dir/post_output-music.vec.{`seq -s, 100`} $dir/post_output-music.vec +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1g.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1g.sh new file mode 100755 index 00000000000..e411b94c893 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1g.sh @@ -0,0 +1,234 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. 
+# This script is same as 1e, but removes the stats component in the 3rd layer. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=79 # Maximum left context in egs apart from TDNN's left context +extra_right_context=11 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +music_data_dir=data/train_aztec_unsad_whole_music_corrupted_sp_hires_bp + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 add-log-stddev=true + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-9, tdnn1@-3, tdnn1, tdnn1@3, tdnn2_stats) dim=256 add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-27, tdnn2@-9, tdnn2, tdnn2@9) dim=256 add-log-stddev=true + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print (($num_frames_music / $num_frames_sad) ** 0.25) / $num_snr_bins"` input=tdnn3 + output-layer name=output-speech include-log-softmax=true dim=2 input=tdnn3 objective-scale=`perl -e "print (($num_frames_music / $num_frames_sad) ** 0.25)"` + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn3 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_labels_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + 
--trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1h.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1h.sh new file mode 100644 index 00000000000..e585f27e5fd --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1h.sh @@ -0,0 +1,310 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1c, but uses larger amount of data. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=79 +extra_right_context=11 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1h + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music_snr/nnet_tdnn_stats +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + stats-layer name=tdnn2_stats config=mean+count(-108:6:18:108) + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-12,tdnn2@0,tdnn2@12,tdnn2_stats) dim=$relu_dim + stats-layer name=tdnn3_stats config=mean+count(-108:12:36:108) + relu-renorm-layer name=tdnn4 input=Append(tdnn3@-12,tdnn3@0,tdnn3@12,tdnn3_stats) dim=$relu_dim + relu-renorm-layer name=tdnn4-snr input=Append(tdnn3@-12,tdnn3@0,tdnn3@12,tdnn3_stats) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn4 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn4 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn4 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn4-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + 
+samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + 
--egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1i.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1i.sh new file mode 100644 index 00000000000..3ddcdd795db --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1i.sh @@ -0,0 +1,310 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. +# This script is same as 1c, but uses larger amount of data. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=79 +extra_right_context=11 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1i + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil,amharic}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil,amharic}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music_snr/nnet_tdnn_stats +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-6,0,6) dim=$relu_dim + stats-layer name=tdnn2_stats config=mean+count(-108:6:18:108) + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-12,tdnn2@0,tdnn2@12,tdnn2_stats) dim=$relu_dim + stats-layer name=tdnn3_stats config=mean+count(-108:12:36:108) + relu-renorm-layer name=tdnn4 input=Append(tdnn3@-12,tdnn3@0,tdnn3@12,tdnn3_stats) dim=$relu_dim + relu-renorm-layer name=tdnn4-snr input=Append(tdnn3@-12,tdnn3@0,tdnn3@12,tdnn3_stats) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn4 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn4 + output-layer name=output-speech_music include-log-softmax=true dim=4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn4 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn4-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + 
+samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + 
--egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1j.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1j.sh new file mode 100644 index 00000000000..059fbf7b1a9 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1j.sh @@ -0,0 +1,316 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1h, but has more layers. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=79 +extra_right_context=11 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1j + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music_snr/nnet_tdnn_stats +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + snr_scale=`echo $scales | awk '{print $4}'` + + num_snr_bins=`feat-to-dim scp:$sad_data_dir/irm_targets.scp -` + snr_scale=`perl -e "print $snr_scale / $num_snr_bins"` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + stats-layer name=tdnn3_stats config=mean+stddev+count(-99:3:9:99) + relu-renorm-layer name=tdnn4 input=Append(tdnn3@-6,tdnn3@0,tdnn3@6,tdnn3@12,tdnn3_stats) add-log-stddev=true dim=$relu_dim + stats-layer name=tdnn4_stats config=mean+stddev+count(-108:6:18:108) + relu-renorm-layer name=tdnn5 input=Append(tdnn4@-12,tdnn4@0,tdnn4@12,tdnn4@24,tdnn4_stats) dim=$relu_dim + relu-renorm-layer name=tdnn5-snr input=Append(tdnn3@-6,tdnn3@0,tdnn3@6,tdnn3@12,tdnn5) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-speech_music include-log-softmax=true dim=4 
presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn5 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn5-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + 
--egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1k.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1k.sh new file mode 100644 index 00000000000..48425e50386 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1k.sh @@ -0,0 +1,317 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1h, but has more layers. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=79 +extra_right_context=11 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1k + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music_snr/nnet_tdnn_stats +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + snr_scale=`echo $scales | awk '{print $4}'` + + num_snr_bins=`feat-to-dim scp:$sad_data_dir/irm_targets.scp -` + snr_scale=`perl -e "print $snr_scale / $num_snr_bins"` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + stats-layer name=tdnn3_stats config=mean+stddev+count(-99:3:9:99) + relu-renorm-layer name=tdnn4 input=Append(tdnn3@-6,tdnn3@0,tdnn3@6,tdnn3@12,tdnn3_stats) add-log-stddev=true dim=$relu_dim + stats-layer name=tdnn4_stats config=mean+stddev+count(-108:6:18:108) + relu-renorm-layer name=tdnn5 input=Append(tdnn4@-12,tdnn4@0,tdnn4@12,tdnn4@24,tdnn4_stats) dim=$relu_dim + relu-renorm-layer name=tdnn5-snr input=Append(tdnn3@-6,tdnn3@0,tdnn3@6,tdnn3@12,tdnn5) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-speech_music include-log-softmax=true dim=4 
presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn5 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn5-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + 
--egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1l.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1l.sh new file mode 100644 index 00000000000..689c31e623a --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_snr_1l.sh @@ -0,0 +1,318 @@ +#!/bin/bash + +# This is a script to train a TDNN-LSTM for speech activity detection (SAD) and +# music-id using LSTM for long-context information. +# This is same as 1h, but has more layers. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +chunk_width=20 +num_chunk_per_minibatch=64 + +extra_left_context=79 +extra_right_context=11 + +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +extra_egs_copy_cmd="nnet3-copy-egs --keep-outputs=output-speech,output-music,output-speech_music,output-snr ark:- ark:- |" + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=1l + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $stage -le -1 ]; then + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp irm_targets.scp deriv_weights_for_irm_targets.scp" \ + data/train_tztec_whole_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_corrupted_spr_hires_bp/ + + cp data/train_tztec_whole_corrupted_spr_hires_bp/{speech_labels.scp,speech_music_labels.scp} + + utils/combine_data.sh --extra-files "deriv_weights.scp speech_labels.scp music_labels.scp speech_music_labels.scp" \ + data/train_tztec_whole_music_corrupted_spr_hires_bp data/fisher_train_100k_whole_900_music_corrupted_spr_hires_bp/ \ + data/babel_{turkish,zulu,cantonese,tamil}_train_whole_music_corrupted_spr_hires_bp/ +fi + +sad_data_dir=data/train_tztec_whole_corrupted_spr_hires_bp +music_data_dir=data/train_tztec_whole_music_corrupted_spr_hires_bp + +num_utts=`cat $sad_data_dir/utt2spk $music_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music_snr/nnet_tdnn_stats +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/scales +fi + +if [ $stage -le 2 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + scales=`cat $dir/scales` + + speech_scale=`echo $scales | awk '{print $1}'` + music_scale=`echo $scales | awk '{print $2}'` + speech_music_scale=`echo $scales | awk '{print $3}'` + snr_scale=`echo $scales | awk '{print $4}'` + + num_snr_bins=`feat-to-dim scp:$sad_data_dir/irm_targets.scp -` + snr_scale=`perl -e "print $snr_scale / $num_snr_bins"` + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + + relu-renorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + stats-layer name=tdnn3_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn4 input=Append(tdnn3@-6,tdnn3@0,tdnn3@6,tdnn3@12,tdnn3_stats) add-log-stddev=true dim=$relu_dim + stats-layer name=tdnn4_stats config=mean+count(-108:6:18:108) + relu-renorm-layer name=tdnn5 input=Append(tdnn4@-12,tdnn4@0,tdnn4@12,tdnn4@24,tdnn4_stats) dim=$relu_dim + relu-renorm-layer name=tdnn5-snr input=Append(tdnn3@-6,tdnn3@0,tdnn3@6,tdnn3@12,tdnn5) dim=$relu_dim + + output-layer name=output-speech include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 objective-scale=$speech_scale input=tdnn5 + output-layer name=output-music include-log-softmax=true dim=2 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-music.txt learning-rate-factor=0.1 objective-scale=$music_scale input=tdnn5 + output-layer name=output-speech_music include-log-softmax=true dim=4 
presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech_music.txt learning-rate-factor=0.1 objective-scale=$speech_music_scale input=tdnn5 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic learning-rate-factor=0.1 objective-scale=$snr_scale input=tdnn5-snr + + output name=output-temp input=Append(input@-3,input@-2,input@-1,input,input@1,input@2, input@3) +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$sad_data_dir/speech_music_labels.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_for_irm_targets.scp" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_music/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$music_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_data_dir/music_labels.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech_music --target-type=sparse --dim=4 --targets-scp=$music_data_dir/speech_music_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$music_data_dir/speech_labels.scp --deriv-weights-scp=$music_data_dir/deriv_weights.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_music + fi + + if [ $stage -le 5 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $num_chunk_per_minibatch \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_music $dir/egs_multi + fi +fi + +if [ $stage -le 6 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + 
--egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_labels.scp" \ + --dir=$dir || exit 1 +fi + + + + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh new file mode 100755 index 00000000000..c8a7c887fef --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1a.sh @@ -0,0 +1,206 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=1 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +# target options +train_data_dir=data/train_aztec_small_unsad_a +speech_feat_scp=data/train_aztec_small_unsad_a/speech_feat.scp +deriv_weights_scp=data/train_aztec_small_unsad_a/deriv_weights.scp + +#train_data_dir=data/train_aztec_small_unsad_whole_sad_ovlp_corrupted_sp +#speech_feat_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp/speech_feat.scp +#deriv_weights_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400/deriv_weights.scp +#data/train_aztec_small_unsad_whole_all_corrupted_sp_hires_bp + +# Only for SAD +snr_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp/irm_targets.scp +deriv_weights_for_irm_scp=data/train_aztec_unsad_whole_corrupted_sp_hires_bp/deriv_weights_manual_seg.scp + +# Only for overlapped speech detection +deriv_weights_for_overlapped_speech_scp=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp/deriv_weights_for_overlapped_speech.scp +overlapped_speech_labels_scp=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp/overlapped_speech_labels.scp + +#extra_left_context=79 
+#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $train_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$train_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + + relu-renorm-layer name=pre-final-speech dim=256 input=tdnn3 + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=`perl -e 'print (1.0/6)'` + + relu-renorm-layer name=pre-final-snr dim=256 input=tdnn3 + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print 1.0/$num_snr_bins"` + + relu-renorm-layer name=pre-final-overlapped_speech dim=256 input=tdnn3 + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$train_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=20000 \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$snr_scp --deriv-weights-scp=$deriv_weights_for_irm_scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$overlapped_speech_labels_scp --deriv-weights-scp=$deriv_weights_for_overlapped_speech_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --dir=$dir/egs + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + 
--trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=128 \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$train_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 6 ]; then + $train_cmd JOB=1:100 $dir/log/compute_post_output-speech.JOB.log \ + extract-column "scp:utils/split_scp.pl -j 100 \$[JOB-1] $speech_feat_scp |" ark,t:- \| \ + steps/segmentation/quantize_vector.pl \| \ + ali-to-post ark,t:- ark:- \| \ + weight-post ark:- scp:$deriv_weights_scp ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-speech.vec.JOB + eval vector-sum $dir/post_output-speech.vec.{`seq -s, 100`} $dir/post_output-speech.vec + + $train_cmd JOB=1:100 $dir/log/compute_post_output-overlapped_speech\ + ali-to-post "scp:utils/split_scp.pl -j 100 \$[JOB-1] $overlapped_speech_scp |" ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-overlapped_speech.vec.JOB + eval vector-sum $dir/post_output-overlapped_speech.vec.{`seq -s, 100`} $dir/post_output-overlapped_speech.vec +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1b.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1b.sh new file mode 100755 index 00000000000..b562a83f6c3 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1b.sh @@ -0,0 +1,240 @@ 
+#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=b + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 
4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=`perl -e "print ($num_frames_ovlp / $num_frames_sad) ** 0.25"` input=tdnn4 + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print (($num_frames_ovlp / $num_frames_sad) ** 0.25) / $num_snr_bins"` input=tdnn4 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_music/storage $dir/egs_music/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + 
${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi + diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1c.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1c.sh new file mode 100755 index 00000000000..7041b0b3e9b --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1c.sh @@ -0,0 +1,239 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=b + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=`perl -e "print ($num_frames_ovlp / $num_frames_sad) ** 0.25"` input=tdnn4 + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print (($num_frames_ovlp / $num_frames_sad) ** 0.25) / $num_snr_bins"` input=tdnn4 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + 
${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1d.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1d.sh new file mode 100755 index 00000000000..a361435baa1 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1d.sh @@ -0,0 +1,262 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=d + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=`perl -e "print ($num_frames_ovlp / $num_frames_sad) ** 0.25"` input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print (($num_frames_ovlp / $num_frames_sad) ** 0.25) / $num_snr_bins"` input=tdnn4 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + 
${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1f.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1f.sh new file mode 100755 index 00000000000..7048c40f62b --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1f.sh @@ -0,0 +1,272 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=d + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=$speech_scale input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print $speech_scale / $num_snr_bins"` input=tdnn4 max-change=0.75 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 objective-scale=$ovlp_scale input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + 
${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1g.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1g.sh new file mode 100755 index 00000000000..72e26b5347b --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1g.sh @@ -0,0 +1,275 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +# This script is same as 1e but adds max-change=0.75 for snr and overlapped_speech outputs +# and learning rate factor 0.1 for the final affine components. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=g + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=$speech_scale input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.1 + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print $speech_scale / $num_snr_bins"` input=tdnn4 max-change=0.75 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 objective-scale=$ovlp_scale input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.1 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + 
${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1h.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1h.sh new file mode 100755 index 00000000000..fb1616b9ac7 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_overlap_1h.sh @@ -0,0 +1,276 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for overlapped speech activity detection +# using statistic pooling component for long-context information. + +# This script is same as 1e but adds max-change=0.75 for snr and overlapped_speech outputs +# and learning rate factor 0.01 for the final affine components. +# Decreased learning rate factor of overlapped speech to 0.025 and 0.05 for speech. +# Changed relu-dim to 512 + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. 
+# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +chunk_width=40 # We use chunk training for training TDNN +num_chunk_per_minibatch=64 + +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +sad_data_dir=data/train_aztec_unsad_whole_corrupted_sp_hires_bp_2400 +ovlp_data_dir=data/train_aztec_unsad_seg_ovlp_corrupted_hires_bp + +#extra_left_context=79 +#extra_right_context=11 + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=g + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_utts=`cat $sad_data_dir/utt2spk $ovlp_data_dir/utt2spk | wc -l` +num_utts_subset_valid=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` +num_utts_subset_train=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 4000 ? 4000 : $n)' $num_utts` + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_ovlp_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$sad_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=512 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-6, tdnn1, tdnn2_stats) dim=512 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=512 + relu-renorm-layer name=tdnn4 dim=512 + + output-layer name=output-speech include-log-softmax=true dim=2 objective-scale=$speech_scale input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-speech.txt learning-rate-factor=0.05 + + output-layer name=output-snr include-log-softmax=false dim=$num_snr_bins objective-type=quadratic objective-scale=`perl -e "print $speech_scale / $num_snr_bins"` input=tdnn4 max-change=0.75 learning-rate-factor=0.5 + + output-layer name=output-overlapped_speech include-log-softmax=true dim=2 objective-scale=$ovlp_scale input=tdnn4 presoftmax-scale-file=$dir/presoftmax_prior_scale_output-overlapped_speech.txt max-change=0.75 learning-rate-factor=0.025 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +samples_per_iter=`perl -e "print int(400000 / $chunk_width)"` + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs_multi + if [ $stage -le 2 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_speech/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_speech/storage $dir/egs_speech/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$sad_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-snr --target-type=dense --targets-scp=$sad_data_dir/irm_targets.scp --deriv-weights-scp=$sad_data_dir/deriv_weights_manual_seg.scp" \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$sad_data_dir/speech_feat.scp --deriv-weights-scp=$sad_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --generate-egs-scp=true \ + --dir=$dir/egs_speech + fi + + if [ $stage -le 3 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs_ovlp/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs_ovlp/storage $dir/egs_ovlp/storage + fi + + . 
$dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$ovlp_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=$samples_per_iter \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/speech_feat.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights.scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\"" \ + --targets-parameters="--output-name=output-overlapped_speech --target-type=sparse --dim=2 --targets-scp=$ovlp_data_dir/overlapped_speech_labels_fixed.scp --deriv-weights-scp=$ovlp_data_dir/deriv_weights_for_overlapped_speech.scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\"" \ + --generate-egs-scp=true \ + --dir=$dir/egs_ovlp + fi + + if [ $stage -le 4 ]; then + # num_chunk_per_minibatch is multiplied by 4 to allow a buffer to use + # the same egs with a different num_chunk_per_minibatch + steps/nnet3/multilingual/get_egs.sh \ + --cmd "$train_cmd" \ + --minibatch-size $[num_chunk_per_minibatch * 4] \ + --samples-per-iter $samples_per_iter \ + 2 $dir/egs_speech $dir/egs_ovlp $dir/egs_multi + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.use-multitask-egs=true --egs.rename-multitask-outputs=false \ + 
${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=$num_chunk_per_minibatch \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$sad_data_dir \ + --targets-scp="$sad_data_dir/speech_feat.scp" \ + --dir=$dir || exit 1 +fi diff --git a/egs/aspire/s5/path.sh b/egs/aspire/s5/path.sh index 1a6fb5f891b..7fb6d91c543 100755 --- a/egs/aspire/s5/path.sh +++ b/egs/aspire/s5/path.sh @@ -2,4 +2,5 @@ export KALDI_ROOT=`pwd`/../../.. export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 . 
$KALDI_ROOT/tools/config/common_path.sh +export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH export LC_ALL=C diff --git a/egs/babel/s5c/conf/lang/101-cantonese-limitedLP.official.conf b/egs/babel/s5c/conf/lang/101-cantonese-limitedLP.official.conf index e5d60c12367..9efcdc6a164 100644 --- a/egs/babel/s5c/conf/lang/101-cantonese-limitedLP.official.conf +++ b/egs/babel/s5c/conf/lang/101-cantonese-limitedLP.official.conf @@ -92,7 +92,7 @@ oovSymbol="" lexiconFlags="--romanized --oov " # Scoring protocols (dummy GLM file to appease the scoring script) -glmFile=/export/babel/data/splits/Cantonese_Babel101/cantonese.glm +glmFile=dummy.glm lexicon_file=/export/babel/data/101-cantonese/release-babel101b-v0.4c_sub-train1/conversational/reference_materials/lexicon.sub-train1.txt cer=1 diff --git a/egs/babel/s5c/conf/lang/105-turkish-limitedLP.official.conf b/egs/babel/s5c/conf/lang/105-turkish-limitedLP.official.conf index ae4cb55f4d5..014b519f3b7 100644 --- a/egs/babel/s5c/conf/lang/105-turkish-limitedLP.official.conf +++ b/egs/babel/s5c/conf/lang/105-turkish-limitedLP.official.conf @@ -3,7 +3,7 @@ #speech corpora files location train_data_dir=/export/babel/data/105-turkish/release-current-b/conversational/training -train_data_list=/export/babel/data/splits/Turkish_Babel105/train.LimitedLP.official.list +train_data_list=/export/babel/data/splits/Turkish_Babel105/train.LimitedLP.list train_nj=16 #RADICAL DEV data files diff --git a/egs/babel/s5c/run-1-main.sh b/egs/babel/s5c/run-1-main.sh index 99d74069087..dc5ed134a04 100755 --- a/egs/babel/s5c/run-1-main.sh +++ b/egs/babel/s5c/run-1-main.sh @@ -249,7 +249,6 @@ if [ ! 
-f exp/tri5/.done ]; then touch exp/tri5/.done fi - ################################################################################ # Ready to start SGMM training ################################################################################ diff --git a/egs/babel/s5c/run-4-anydecode.sh b/egs/babel/s5c/run-4-anydecode.sh index 312d26911df..472acbfe80e 100755 --- a/egs/babel/s5c/run-4-anydecode.sh +++ b/egs/babel/s5c/run-4-anydecode.sh @@ -10,12 +10,12 @@ dir=dev10h.pem kind= data_only=false fast_path=true -skip_kws=false +skip_kws=true skip_stt=false skip_scoring=false extra_kws=true vocab_kws=false -tri5_only=false +tri5_only=true wip=0.5 echo "run-4-test.sh $@" @@ -195,7 +195,6 @@ if [ ! -f $dataset_dir/.done ] ; then else echo "Unknown type of the dataset: \"$dataset_segments\"!"; echo "Valid dataset types are: seg, uem, pem"; - exit 1 fi elif [ "$dataset_kind" == "unsupervised" ] ; then if [ "$dataset_segments" == "seg" ] ; then @@ -214,12 +213,10 @@ if [ ! -f $dataset_dir/.done ] ; then else echo "Unknown type of the dataset: \"$dataset_segments\"!"; echo "Valid dataset types are: seg, uem, pem"; - exit 1 fi else echo "Unknown kind of the dataset: \"$dataset_kind\"!"; echo "Valid dataset kinds are: supervised, unsupervised, shadow"; - exit 1 fi if [ ! -f ${dataset_dir}/.plp.done ]; then @@ -286,11 +283,11 @@ if ! 
$fast_path ; then "${lmwt_plp_extra_opts[@]}" \ ${dataset_dir} data/lang ${decode} - local/run_kws_stt_task.sh --cer $cer --max-states $max_states \ - --skip-scoring $skip_scoring --extra-kws $extra_kws --wip $wip \ - --cmd "$decode_cmd" --skip-kws $skip_kws --skip-stt $skip_stt \ - "${lmwt_plp_extra_opts[@]}" \ - ${dataset_dir} data/lang ${decode}.si + #local/run_kws_stt_task.sh --cer $cer --max-states $max_states \ + # --skip-scoring $skip_scoring --extra-kws $extra_kws --wip $wip \ + # --cmd "$decode_cmd" --skip-kws $skip_kws --skip-stt $skip_stt \ + # "${lmwt_plp_extra_opts[@]}" \ + # ${dataset_dir} data/lang ${decode}.si fi if $tri5_only; then diff --git a/egs/bn_music_speech/v1/local/run_dnn_music_id.sh b/egs/bn_music_speech/v1/local/run_dnn_music_id.sh new file mode 100755 index 00000000000..bd30387ae2f --- /dev/null +++ b/egs/bn_music_speech/v1/local/run_dnn_music_id.sh @@ -0,0 +1,130 @@ +#! /bin/bash + +set -e +set -o pipefail +set -u + +stage=-1 +segmentation_config=conf/segmentation.conf +cmd=run.pl +nj=40 + +# Viterbi options +min_silence_duration=3 # minimum number of frames for silence +min_speech_duration=3 # minimum number of frames for speech +min_music_duration=3 # minimum number of frames for music +frame_subsampling_factor=1 +music_transition_probability=0.1 +sil_transition_probability=0.1 +speech_transition_probability=0.1 +sil_prior=0.3 +speech_prior=0.4 +music_prior=0.3 + +# Decoding options +acwt=1 +beam=10 +max_active=7000 + +. 
utils/parse_options.sh + +if [ $# -ne 4 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/bn exp/nnet3_sad_snr/tdnn_b_n4/sad_bn_whole exp/nnet3_sad_snr/tdnn_b_n4/music_bn_whole exp/nnet3_sad_snr/tdnn_b_n4/segmentation_bn_whole exp/nnet3_sad_snr/tdnn_b_n4/segmentation_music_bn_whole exp/dnn_music_id" + exit 1 +fi + +data=$1 +sad_likes_dir=$2 +music_likes_dir=$3 +dir=$4 + +min_silence_duration=`perl -e "print (int($min_silence_duration / $frame_subsampling_factor))"` +min_speech_duration=`perl -e "print (int($min_speech_duration / $frame_subsampling_factor))"` +min_music_duration=`perl -e "print (int($min_music_duration / $frame_subsampling_factor))"` + +lang=$dir/lang + +if [ $stage -le 1 ]; then + mkdir -p $lang + + # Create a lang directory with phones.txt and topo with + # silence, music and speech phones. + steps/segmentation/internal/prepare_sad_lang.py \ + --phone-transition-parameters="--phone-list=1 --min-duration=$min_silence_duration --end-transition-probability=$sil_transition_probability" \ + --phone-transition-parameters="--phone-list=2 --min-duration=$min_speech_duration --end-transition-probability=$speech_transition_probability" \ + --phone-transition-parameters="--phone-list=3 --min-duration=$min_music_duration --end-transition-probability=$music_transition_probability" \ + $lang + + cp $lang/phones.txt $lang/words.txt +fi + +feat_dim=2 # dummy. We don't need this. 
+if [ $stage -le 2 ]; then + $cmd $dir/log/create_transition_model.log gmm-init-mono \ + $lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 +fi + +# Make unigram G.fst +if [ $stage -le 3 ]; then + cat > $lang/word2prior < $lang/G.fst +fi + +graph_dir=$dir/graph_test + +if [ $stage -le 4 ]; then + $cmd $dir/log/make_vad_graph.log \ + steps/segmentation/internal/make_sad_graph.sh --iter trans \ + $lang $dir $dir/graph_test || exit 1 +fi + +if [ $stage -le 5 ]; then + utils/split_data.sh $data $nj + sdata=$data/split$nj + + nj_sad=`cat $sad_likes_dir/num_jobs` + sad_likes= + for n in `seq $nj_sad`; do + sad_likes="$sad_likes $sad_likes_dir/log_likes.$n.gz" + done + + nj_music=`cat $music_likes_dir/num_jobs` + music_likes= + for n in `seq $nj_music`; do + music_likes="$music_likes $music_likes_dir/log_likes.$n.gz" + done + + decoder_opts+=(--acoustic-scale=$acwt --beam=$beam --max-active=$max_active) + $cmd JOB=1:$nj $dir/log/decode.JOB.log \ + paste-feats "ark:gunzip -c $sad_likes | extract-feature-segments ark,s,cs:- $sdata/JOB/segments ark:- |" \ + "ark,s,cs:gunzip -c $music_likes | extract-feature-segments ark,s,cs:- $sdata/JOB/segments ark:- | select-feats 1 ark:- ark:- |" \ + ark:- \| decode-faster-mapped ${decoder_opts[@]} \ + $dir/trans.mdl $graph_dir/HCLG.fst ark:- \ + ark:/dev/null ark:- \| \ + ali-to-phones --per-frame $dir/trans.mdl ark:- \ + "ark:|gzip -c > $dir/ali.JOB.gz" +fi + +include_silence=true +if [ $stage -le 6 ]; then + $cmd JOB=1:$nj $dir/log/get_class_id.JOB.log \ + ali-to-post "ark:gunzip -c $dir/ali.JOB.gz |" ark:- \| \ + post-to-feats --post-dim=4 ark:- ark:- \| \ + matrix-sum-rows --do-average ark:- ark,t:- \| \ + sid/vector_to_music_labels.pl ${include_silence:+--include-silence-in-music} '>' $dir/ratio.JOB +fi + +for n in `seq $nj`; do + cat $dir/ratio.$n +done > $dir/ratio + +cat $dir/ratio | local/print_scores.py /dev/stdin | compute-eer - diff --git 
a/egs/bn_music_speech/v1/local/run_nnet3_music_id.sh b/egs/bn_music_speech/v1/local/run_nnet3_music_id.sh new file mode 100644 index 00000000000..d96acdabaaa --- /dev/null +++ b/egs/bn_music_speech/v1/local/run_nnet3_music_id.sh @@ -0,0 +1,217 @@ +#!/bin/bash + +set -e +set -o pipefail +set -u + +. path.sh +. cmd.sh + +feat_affix=bp_vh +affix= +reco_nj=32 + +stage=-1 + +# SAD network config +iter=final +extra_left_context=100 # Set to some large value +extra_right_context=20 + + +# Configs +frame_subsampling_factor=1 + +min_silence_duration=3 # minimum number of frames for silence +min_speech_duration=3 # minimum number of frames for speech +min_music_duration=3 # minimum number of frames for music +music_transition_probability=0.1 +sil_transition_probability=0.1 +speech_transition_probability=0.1 +sil_prior=0.3 +speech_prior=0.4 +music_prior=0.3 + +# Decoding options +acwt=1 +beam=10 +max_active=7000 + +mfcc_config=conf/mfcc_hires_bp.conf + +echo $* + +. utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/bn exp/nnet3_sad_snr/tdnn_j_n4 exp/dnn_music_id" + exit 1 +fi + +# Set to true if the test data has > 8kHz sampling frequency. +do_downsampling=true + +data_dir=$1 +sad_nnet_dir=$2 +dir=$3 + +data_id=`basename $data_dir` + +export PATH="$KALDI_ROOT/tools/sph2pipe_v2.5/:$PATH" +[ ! -z `which sph2pipe` ] + +for f in $sad_nnet_dir/$iter.raw $sad_nnet_dir/post_output-speech.vec $sad_nnet_dir/post_output-music.vec; do + if [ ! -f $f ]; then + echo "$0: Could not find $f. 
See the local/segmentation/run_train_sad.sh" + exit 1 + fi +done + +mkdir -p $dir + +new_data_dir=$dir/${data_id} +if [ $stage -le 0 ]; then + utils/data/convert_data_dir_to_whole.sh $data_dir ${new_data_dir}_whole + + freq=`cat $mfcc_config | perl -pe 's/\s*#.*//g' | grep "sample-frequency=" | awk -F'=' '{if (NF == 0) print 16000; else print $2}'` + sox=`which sox` + + cat $data_dir/wav.scp | python -c "import sys +for line in sys.stdin.readlines(): + splits = line.strip().split() + if splits[-1] == '|': + out_line = line.strip() + ' $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |' + else: + out_line = 'cat {0} {1} | $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |'.format(splits[0], ' '.join(splits[1:])) + print (out_line)" > ${new_data_dir}_whole/wav.scp + + utils/copy_data_dir.sh ${new_data_dir}_whole ${new_data_dir}_whole_bp_hires +fi + +test_data_dir=${new_data_dir}_whole_bp_hires + +if [ $stage -le 1 ]; then + steps/make_mfcc.sh --mfcc-config $mfcc_config --nj $reco_nj --cmd "$train_cmd" \ + ${new_data_dir}_whole_bp_hires exp/make_hires/${data_id}_whole_bp mfcc_hires + steps/compute_cmvn_stats.sh ${new_data_dir}_whole_bp_hires exp/make_hires/${data_id}_whole_bp mfcc_hires +fi + +if [ $stage -le 2 ]; then + output_name=output-speech + post_vec=$sad_nnet_dir/post_${output_name}.vec + steps/nnet3/compute_output.sh --nj $reco_nj --cmd "$train_cmd" \ + --post-vec "$post_vec" \ + --iter $iter \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --frames-per-chunk 150 \ + --output-name $output_name \ + --frame-subsampling-factor $frame_subsampling_factor \ + --get-raw-nnet-from-am false ${test_data_dir} $sad_nnet_dir $dir/sad_${data_id}_whole_bp +fi + +if [ $stage -le 3 ]; then + output_name=output-music + post_vec=$sad_nnet_dir/post_${output_name}.vec + steps/nnet3/compute_output.sh --nj $reco_nj --cmd "$train_cmd" \ + --post-vec "$post_vec" \ + --iter $iter \ + --extra-left-context $extra_left_context \ + 
--extra-right-context $extra_right_context \ + --frames-per-chunk 150 \ + --output-name $output_name \ + --frame-subsampling-factor $frame_subsampling_factor \ + --get-raw-nnet-from-am false ${test_data_dir} $sad_nnet_dir $dir/music_${data_id}_whole_bp +fi + +if [ $stage -le 4 ]; then + $train_cmd JOB=1:$reco_nj $dir/get_average_likes.JOB.log \ + paste-feats \ + "ark:gunzip -c $dir/sad_${data_id}_whole_bp/log_likes.JOB.gz | extract-feature-segments ark:- 'utils/filter_scp.pl -f 2 ${test_data_dir}/split$reco_nj/JOB/utt2spk $data_dir/segments |' ark:- |" \ + "ark:gunzip -c $dir/music_${data_id}_whole_bp/log_likes.JOB.gz | select-feats 1 ark:- ark:- | extract-feature-segments ark:- 'utils/filter_scp.pl -f 2 ${test_data_dir}/split$reco_nj/JOB/utt2spk $data_dir/segments |' ark:- |" \ + ark:- \| \ + matrix-sum-rows --do-average ark:- ark,t:$dir/average_likes.JOB.ark + + for n in `seq $reco_nj`; do + cat $dir/average_likes.$n.ark + done | awk '{print $1" "( exp($3) + exp($5) + 0.01) / (exp($4) + 0.01)}' | \ + local/print_scores.py /dev/stdin | compute-eer - +fi + +lang=$dir/lang + +if [ $stage -le 5 ]; then + mkdir -p $lang + + # Create a lang directory with phones.txt and topo with + # silence, music and speech phones. + steps/segmentation/internal/prepare_sad_lang.py \ + --phone-transition-parameters="--phone-list=1 --min-duration=$min_silence_duration --end-transition-probability=$sil_transition_probability" \ + --phone-transition-parameters="--phone-list=2 --min-duration=$min_speech_duration --end-transition-probability=$speech_transition_probability" \ + --phone-transition-parameters="--phone-list=3 --min-duration=$min_music_duration --end-transition-probability=$music_transition_probability" \ + $lang + + cp $lang/phones.txt $lang/words.txt +fi + +feat_dim=2 # dummy. We don't need this. 
+if [ $stage -le 6 ]; then + $train_cmd $dir/log/create_transition_model.log gmm-init-mono \ + $lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 +fi + +# Make unigram G.fst +if [ $stage -le 7 ]; then + cat > $lang/word2prior < $lang/G.fst +fi + +graph_dir=$dir/graph_test + +if [ $stage -le 8 ]; then + $train_cmd $dir/log/make_vad_graph.log \ + steps/segmentation/internal/make_sad_graph.sh --iter trans \ + $lang $dir $dir/graph_test || exit 1 +fi + +seg_dir=$dir/segmentation_${data_id}_whole_bp +mkdir -p $seg_dir + +if [ $stage -le 9 ]; then + decoder_opts+=(--acoustic-scale=$acwt --beam=$beam --max-active=$max_active) + $train_cmd JOB=1:$reco_nj $dir/decode.JOB.log \ + paste-feats \ + "ark:gunzip -c $dir/sad_${data_id}_whole_bp/log_likes.JOB.gz | extract-feature-segments ark:- 'utils/filter_scp.pl -f 2 ${test_data_dir}/split$reco_nj/JOB/utt2spk $data_dir/segments |' ark:- |" \ + "ark:gunzip -c $dir/music_${data_id}_whole_bp/log_likes.JOB.gz | select-feats 1 ark:- ark:- | extract-feature-segments ark:- 'utils/filter_scp.pl -f 2 ${test_data_dir}/split$reco_nj/JOB/utt2spk $data_dir/segments |' ark:- |" \ + ark:- \| decode-faster-mapped ${decoder_opts[@]} \ + $dir/trans.mdl $graph_dir/HCLG.fst ark:- \ + ark:/dev/null ark:- \| \ + ali-to-phones --per-frame $dir/trans.mdl ark:- \ + "ark:|gzip -c > $seg_dir/ali.JOB.gz" +fi + +include_silence=true +if [ $stage -le 10 ]; then + $train_cmd JOB=1:$reco_nj $dir/log/get_class_id.JOB.log \ + ali-to-post "ark:gunzip -c $seg_dir/ali.JOB.gz |" ark:- \| \ + post-to-feats --post-dim=4 ark:- ark:- \| \ + matrix-sum-rows --do-average ark:- ark,t:- \| \ + sid/vector_to_music_labels.pl ${include_silence:+--include-silence-in-music} '>' $dir/ratio.JOB + + for n in `seq $reco_nj`; do + cat $dir/ratio.$n + done > $dir/ratio + + cat $dir/ratio | local/print_scores.py /dev/stdin | compute-eer - +fi + +# LOG (compute-eer:main():compute-eer.cc:136) Equal error rate is 0.860585%, at threshold 
1.99361 diff --git a/egs/rt/s5/cmd.sh b/egs/rt/s5/cmd.sh new file mode 120000 index 00000000000..19f7e836644 --- /dev/null +++ b/egs/rt/s5/cmd.sh @@ -0,0 +1 @@ +../../wsj/s5/cmd.sh \ No newline at end of file diff --git a/egs/rt/s5/conf/fbank.conf b/egs/rt/s5/conf/fbank.conf new file mode 100644 index 00000000000..07e1639e6ee --- /dev/null +++ b/egs/rt/s5/conf/fbank.conf @@ -0,0 +1,3 @@ +# No non-default options for now. +--num-mel-bins=40 # similar to Google's setup. + diff --git a/egs/rt/s5/conf/librispeech_mfcc.conf b/egs/rt/s5/conf/librispeech_mfcc.conf new file mode 100644 index 00000000000..45d284ad05c --- /dev/null +++ b/egs/rt/s5/conf/librispeech_mfcc.conf @@ -0,0 +1 @@ +--use-energy=false diff --git a/egs/rt/s5/conf/mfcc_hires.conf b/egs/rt/s5/conf/mfcc_hires.conf new file mode 100644 index 00000000000..434834a6725 --- /dev/null +++ b/egs/rt/s5/conf/mfcc_hires.conf @@ -0,0 +1,10 @@ +# config for high-resolution MFCC features, intended for neural network training +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--num-mel-bins=40 # similar to Google's setup. +--num-ceps=40 # there is no dimensionality reduction. +--low-freq=20 # low cutoff frequency for mel bins... this is high-bandwidth data, so + # there might be some information at the low end. +--high-freq=-400 # high cutoff frequently, relative to Nyquist of 8000 (=7600) diff --git a/egs/rt/s5/conf/mfcc_vad.conf b/egs/rt/s5/conf/mfcc_vad.conf new file mode 100644 index 00000000000..22765c6280e --- /dev/null +++ b/egs/rt/s5/conf/mfcc_vad.conf @@ -0,0 +1,5 @@ +--sample-frequency=16000 +--frame-length=25 # the default is 25. +--low-freq=20 # the default. +--high-freq=-600 # the default is zero meaning use the Nyquist (4k in this case). +--num-ceps=13 # higher than the default which is 12. 
diff --git a/egs/rt/s5/conf/pitch.conf b/egs/rt/s5/conf/pitch.conf new file mode 100644 index 00000000000..e959a19d5b8 --- /dev/null +++ b/egs/rt/s5/conf/pitch.conf @@ -0,0 +1 @@ +--sample-frequency=16000 diff --git a/egs/rt/s5/conf/vad_decode_icsi.conf b/egs/rt/s5/conf/vad_decode_icsi.conf new file mode 100644 index 00000000000..15ba288e3af --- /dev/null +++ b/egs/rt/s5/conf/vad_decode_icsi.conf @@ -0,0 +1,40 @@ +## Features paramters +window_size=100 # 1s +frames_per_gaussian=2000 + +## Phase 1 parameters +num_frames_init_silence=2000 +num_frames_init_sound=10000 +num_frames_init_sound_next=2000 +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=10000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/rt/s5/conf/vad_decode_pitch.conf b/egs/rt/s5/conf/vad_decode_pitch.conf new file mode 100644 index 00000000000..d7ba1d40093 --- /dev/null +++ b/egs/rt/s5/conf/vad_decode_pitch.conf @@ -0,0 +1,55 @@ +## Features paramters +window_size=10 # 1s +smooth_weights=false +smoothing_window=2 +smooth_mask=true + +## Phase 1 parameters +num_frames_init_silence=200 +num_frames_init_sound=200 +num_frames_init_sound_next=200 +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=2 +sil_gauss_incr=1 +sound_gauss_incr=1 +sil_frames_incr=200 +sound_frames_incr=200 +sound_frames_next_incr=200 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + 
+## Phase 2 parameters +num_frames_init_speech=5000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=7 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=1 +speech_gauss_incr_phase2=2 +num_iters_phase2=20 +window_size_phase2_init=10 +window_size_phase2_next=10 +window_size_incr_iter=5 + +num_frames_init_speech_phase2=100000 +num_frames_init_silence_phase2=200000 +num_frames_init_sound_phase2=200000 +speech_frames_incr_phase2=200000 +sil_frames_incr_phase2=200000 +sound_frames_incr_phase2=200000 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/rt/s5/conf/vad_icsi_babel.conf b/egs/rt/s5/conf/vad_icsi_babel.conf new file mode 100644 index 00000000000..70f651403f5 --- /dev/null +++ b/egs/rt/s5/conf/vad_icsi_babel.conf @@ -0,0 +1,39 @@ +## Features paramters +window_size=10 # 100 ms +frames_per_gaussian=200 + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 
+num_iters_phase3=7 + + diff --git a/egs/rt/s5/conf/vad_icsi_babel_3models.conf b/egs/rt/s5/conf/vad_icsi_babel_3models.conf new file mode 100644 index 00000000000..1196f0d2aff --- /dev/null +++ b/egs/rt/s5/conf/vad_icsi_babel_3models.conf @@ -0,0 +1,54 @@ +## Features paramters +window_size=10 # 100 ms +frames_per_gaussian=200 + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +num_frames_silence_phase3_init=2000 +num_frames_speech_phase3_init=2000 +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +sil_max_gauss_phase4=8 +speech_max_gauss_phase4=16 +sil_gauss_incr_phase3=1 +sil_gauss_incr_phase4=1 +speech_gauss_incr_phase4=2 +num_iters_phase3=5 +num_iters_phase4=5 + +## Phase 5 parameters +sil_num_gauss_init_phase5=2 +speech_num_gauss_init_phase5=2 +sil_max_gauss_phase5=5 +speech_max_gauss_phase5=12 +sil_gauss_incr_phase5=1 +speech_gauss_incr_phase5=2 +num_iters_phase5=7 + + + diff --git a/egs/rt/s5/conf/vad_icsi_rt.conf b/egs/rt/s5/conf/vad_icsi_rt.conf new file mode 100644 index 00000000000..d19038014db --- /dev/null +++ b/egs/rt/s5/conf/vad_icsi_rt.conf @@ -0,0 +1,40 @@ +## Features paramters +window_size=10 # 100 ms +frames_per_gaussian=200 + +## Phase 1 parameters +num_frames_init_silence=2000 
+num_frames_init_sound=10000 +num_frames_init_sound_next=2000 +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +#num_frames_init_speech=10000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/rt/s5/conf/vad_snr_rt.conf b/egs/rt/s5/conf/vad_snr_rt.conf new file mode 100644 index 00000000000..a1029eb8fe6 --- /dev/null +++ b/egs/rt/s5/conf/vad_snr_rt.conf @@ -0,0 +1,35 @@ +## Features paramters +window_size=5 # 5 frame. Window over which initial selection of frames + +frames_per_silence_gaussian=200 # 2s per Gaussian +frames_per_sound_gaussian=200 # 2s per Gaussian +frames_per_speech_gaussian=2000 # 20s per Gaussian + +## Phase 1 parameters +num_frames_init_silence=1000 # 10s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_silence_next=200 # 2s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=1000 # 10s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=200 # 2s - Highest zero crossing frames selected to initialize Sound GMM +num_frames_init_speech=10000 # 100s - Highest energy frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +speech_num_gauss_init=6 +sil_max_gauss=7 +sound_max_gauss=12 +speech_max_gauss=16 +sil_gauss_incr=1 +sound_gauss_incr=2 +speech_gauss_incr=2 +num_iters=10 + +## Phase 3 parameters +num_frames_init_silence_phase3=1000 # 10s - Lowest energy frames selected to 
initialize Silence GMM +num_frames_init_silence_next_phase3=200 # 2s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_speech_phase3=10000 # 100s - Highest energy frames selected to initialize Sound GMM +sil_num_gauss_init=2 +speech_num_gauss_init=6 +sil_max_gauss=7 +speech_max_gauss=16 +sil_gauss_incr=1 +speech_gauss_incr=2 +num_iters_phase3=10 diff --git a/egs/rt/s5/conf/zc_vad.conf b/egs/rt/s5/conf/zc_vad.conf new file mode 100644 index 00000000000..b5d94450709 --- /dev/null +++ b/egs/rt/s5/conf/zc_vad.conf @@ -0,0 +1,5 @@ +--sample-frequency=16000 +--frame-length=25 # the default is 25. +--dither=0.0 +--zero-crossing-threshold=1e-5 + diff --git a/egs/rt/s5/diarization b/egs/rt/s5/diarization new file mode 120000 index 00000000000..ba78a9126af --- /dev/null +++ b/egs/rt/s5/diarization @@ -0,0 +1 @@ +../../sre08/v1/diarization \ No newline at end of file diff --git a/egs/rt/s5/local/make_rt_2004_dev.pl b/egs/rt/s5/local/make_rt_2004_dev.pl new file mode 100755 index 00000000000..8a08dd268a7 --- /dev/null +++ b/egs/rt/s5/local/make_rt_2004_dev.pl @@ -0,0 +1,64 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +use strict; +use File::Basename; + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n" . + " e.g.: $0 /export/corpora5/LDC/LDC2007S11 data\n"; + exit(1); +} + +my ($db_base, $out_dir) = @ARGV; +$out_dir = "$out_dir/rt04_dev"; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(SPKR, ">", "$out_dir/utt2spk") + or die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">", "$out_dir/wav.scp") + or die "Could not open the output file $out_dir/wav.scp"; +open(RECO2FILE_AND_CHANNEL, ">", "$out_dir/reco2file_and_channel") + or die "Could not open the output file $out_dir/reco2file_and_channel"; + +open(LIST, 'find ' . $db_base . 
'/data/audio/dev04s -name "*.sph" |'); + + +my $sox =`which sox` || die "Could not find sox in PATH"; +chomp($sox); + +while (my $line = ) { + chomp($line); + my ($file_id, $path, $suffix) = fileparse($line, qr/\.[^.]*/); + if ($suffix =~ /.sph/) { + #print WAV $file_id . " $sox $line -c 1 -b 16 -t wav - |\n"; + print WAV $file_id . " sph2pipe -f wav $line |\n"; + } elsif ($suffix =~ /.wav/) { + print WAV $file_id . " $line |\n"; + } else { + die "$0: Unknown suffix $suffix in $line\n" + } + + print SPKR "$file_id $file_id\n"; + print RECO2FILE_AND_CHANNEL "$file_id $file_id 1\n"; +} + +close(LIST) || die; +close(WAV) || die; +close(SPKR) || die; + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} + +system("utils/fix_data_dir.sh $out_dir"); + +if (system( + "utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} diff --git a/egs/rt/s5/local/make_rt_2004_eval.pl b/egs/rt/s5/local/make_rt_2004_eval.pl new file mode 100755 index 00000000000..4c1286ea1cc --- /dev/null +++ b/egs/rt/s5/local/make_rt_2004_eval.pl @@ -0,0 +1,64 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +use strict; +use File::Basename; + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n" . + " e.g.: $0 /export/corpora5/LDC/LDC2007S12/package/rt04_eval data\n"; + exit(1); +} + +my ($db_base, $out_dir) = @ARGV; +$out_dir = "$out_dir/rt04_eval"; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(SPKR, ">", "$out_dir/utt2spk") + or die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">", "$out_dir/wav.scp") + or die "Could not open the output file $out_dir/wav.scp"; +open(RECO2FILE_AND_CHANNEL, ">", "$out_dir/reco2file_and_channel") + or die "Could not open the output file $out_dir/reco2file_and_channel"; + +open(LIST, 'find ' . $db_base . 
'/data/audio/eval04s -name "*.sph" |'); + +my $sox =`which sox` || die "Could not find sox in PATH"; +chomp($sox); + +while (my $line = ) { + chomp($line); + my ($file_id, $path, $suffix) = fileparse($line, qr/\.[^.]*/); + if ($suffix =~ /.sph/) { + #print WAV $file_id . " $sox $line -c 1 -b 16 -t wav - |\n"; + print WAV $file_id . " sph2pipe -f wav $line |\n"; + } elsif ($suffix =~ /.wav/) { + print WAV $file_id . " $line |\n"; + } else { + die "$0: Unknown suffix $suffix in $line\n" + } + + print SPKR "$file_id $file_id\n"; + print RECO2FILE_AND_CHANNEL "$file_id $file_id 1\n"; +} + +close(LIST) || die; +close(WAV) || die; +close(SPKR) || die; + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} + +system("utils/fix_data_dir.sh $out_dir"); + +if (system( + "utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} + diff --git a/egs/rt/s5/local/make_rt_2005_eval.pl b/egs/rt/s5/local/make_rt_2005_eval.pl new file mode 100755 index 00000000000..d48dcaae926 --- /dev/null +++ b/egs/rt/s5/local/make_rt_2005_eval.pl @@ -0,0 +1,64 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +use strict; +use File::Basename; + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n" . + " e.g.: $0 /export/corpora5/LDC/LDC2011S06 data\n"; + exit(1); +} + +my ($db_base, $out_dir) = @ARGV; +$out_dir = "$out_dir/rt05_eval"; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(SPKR, ">", "$out_dir/utt2spk") + or die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">", "$out_dir/wav.scp") + or die "Could not open the output file $out_dir/wav.scp"; +open(RECO2FILE_AND_CHANNEL, ">", "$out_dir/reco2file_and_channel") + or die "Could not open the output file $out_dir/reco2file_and_channel"; + +open(LIST, 'find ' . $db_base . 
'/data/audio/eval05s -name "*.sph" |'); + +my $sox =`which sox` || die "Could not find sox in PATH"; +chomp($sox); + +while (my $line = ) { + chomp($line); + my ($file_id, $path, $suffix) = fileparse($line, qr/\.[^.]*/); + if ($suffix =~ /.sph/) { + print WAV $file_id . " $sox $line -c 1 -b 16 -t wav - |\n"; + } elsif ($suffix =~ /.wav/) { + print WAV $file_id . " $line |\n"; + } else { + die "$0: Unknown suffix $suffix in $line\n" + } + + print SPKR "$file_id $file_id\n"; + print RECO2FILE_AND_CHANNEL "$file_id $file_id 1\n"; +} + +close(LIST) || die; +close(WAV) || die; +close(SPKR) || die; + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} + +system("utils/fix_data_dir.sh $out_dir"); + +if (system( + "utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} + + diff --git a/egs/rt/s5/local/run_prepare_rt.sh b/egs/rt/s5/local/run_prepare_rt.sh new file mode 100755 index 00000000000..c431f760dab --- /dev/null +++ b/egs/rt/s5/local/run_prepare_rt.sh @@ -0,0 +1,87 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e +set -o pipefail +set -u + +. path.sh +. cmd.sh + +mic=sdm +task=sad + +. parse_options.sh + +RT04_DEV_ROOT=/export/corpora5/LDC/LDC2007S11 +RT04_EVAL_ROOT=/export/corpora5/LDC/LDC2007S12/package/rt04_eval +RT05_EVAL_ROOT=/export/corpora5/LDC/LDC2011S06 + +if [ ! -f data/rt04_dev/.done ]; then + local/make_rt_2004_dev.pl $RT04_DEV_ROOT data + touch data/rt04_dev/.done +fi + +if [ ! -f data/rt04_eval/.done ]; then + local/make_rt_2004_eval.pl $RT04_EVAL_ROOT data + touch data/rt04_eval/.done +fi + +if [ ! 
-f data/rt05_eval/.done ]; then + local/make_rt_2005_eval.pl $RT05_EVAL_ROOT data + touch data/rt05_eval/.done +fi + +mkdir -p data/local + +dir=data/local/rt05_eval/$mic/$task +mkdir -p $dir + +if [ $task == "stt" ]; then + cp $RT05_EVAL_ROOT/data/reference/concatenated/rt05s.confmtg.050614.${task}.${mic}.stm $dir/stm +else + cp $RT05_EVAL_ROOT/data/reference/concatenated/rt05s.confmtg.050614.${task}.${mic}.rttm $dir/rttm +fi + +cp $RT05_EVAL_ROOT/data/indicies/expt_05s_${task}ul_eval05s_eng_confmtg_${mic}_1.uem $dir/uem +cat $dir/uem | awk '!/;;/{if (NF > 0) print $1}' | perl -pe 's/(.*)\.sph/$1/g' | sort -u > $dir/list +utils/subset_data_dir.sh --utt-list $dir/list data/rt05_eval data/rt05_eval_${mic}_${task} +[ -f $dir/stm ] && cp $dir/stm data/rt05_eval_${mic}_${task} +[ -f $dir/uem ] && cp $dir/uem data/rt05_eval_${mic}_${task} +[ -f $dir/rttm ] && cp $dir/rttm data/rt05_eval_${mic}_${task} + +dir=data/local/rt04_dev/$mic/$task +mkdir -p $dir + +if [ $task == "stt" ]; then + cp $RT04_DEV_ROOT/data/reference/dev04s/concatenated/dev04s.040809.${mic}.stm $dir/stm +elif [ $task == "spkr" ]; then + cp $RT04_DEV_ROOT/data/reference/dev04s/concatenated/dev04s.040809.${mic}.rttm $dir/rttm +else + cat $RT04_DEV_ROOT/data/reference/dev04s/concatenated/dev04s.040809.${mic}.rttm | spkr2sad.pl | rttmSmooth.pl -s 0 > $dir/rttm +fi +cp $RT04_DEV_ROOT/data/indices/dev04s/dev04s.${mic}.uem $dir/uem +cat $dir/uem | awk '!/;;/{if (NF > 0) print $1}' | perl -pe 's/(.*)\.sph/$1/g' | sort -u > $dir/list +utils/subset_data_dir.sh --utt-list $dir/list data/rt04_dev data/rt04_dev_${mic}_${task} +[ -f $dir/stm ] && cp $dir/stm data/rt04_dev_${mic}_${task} +[ -f $dir/uem ] && cp $dir/uem data/rt04_dev_${mic}_${task} +[ -f $dir/rttm ] && cp $dir/rttm data/rt04_dev_${mic}_${task} + +dir=data/local/rt04_eval/$mic/$task +mkdir -p $dir + +if [ $task == "stt" ]; then + cp $RT04_EVAL_ROOT/data/reference/eval04s/concatenated/eval04s.040511.${mic}.stm $dir/stm +elif [ $task == "spkr" ]; then + 
cp $RT04_EVAL_ROOT/data/reference/eval04s/concatenated/eval04s.040511.${mic}.rttm $dir/rttm +else + cat $RT04_EVAL_ROOT/data/reference/eval04s/concatenated/eval04s.040511.${mic}.rttm | spkr2sad.pl | rttmSmooth.pl -s 0 > $dir/rttm +fi +cp $RT04_EVAL_ROOT/data/indices/eval04s/eval04s.${mic}.uem $dir/uem +cat $dir/uem | awk '!/;;/{if (NF > 0) print $1}' | perl -pe 's/(.*)\.sph/$1/g' | sort -u > $dir/list +utils/subset_data_dir.sh --utt-list $dir/list data/rt04_eval data/rt04_eval_${mic}_${task} +[ -f $dir/stm ] && cp $dir/stm data/rt04_eval_${mic}_${task} +[ -f $dir/uem ] && cp $dir/uem data/rt04_eval_${mic}_${task} +[ -f $dir/rttm ] && cp $dir/rttm data/rt04_eval_${mic}_${task} diff --git a/egs/rt/s5/local/score.sh b/egs/rt/s5/local/score.sh new file mode 100755 index 00000000000..1c3e2cbe8c4 --- /dev/null +++ b/egs/rt/s5/local/score.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +# Copyright Johns Hopkins University (Author: Daniel Povey) 2012 +# Copyright University of Edinburgh (Author: Pawel Swietojanski) 2014 +# Apache 2.0 + +orig_args= +for x in "$@"; do orig_args="$orig_args '$x'"; done + +# begin configuration section. we include all the options that score_sclite.sh or +# score_basic.sh might need, or parse_options.sh will die. +cmd=run.pl +stage=0 +min_lmwt=9 # unused, +max_lmwt=15 # unused, +asclite=true +#end configuration section. + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: local/score.sh [options] " && exit; + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + echo " --stage (0|1|2) # start scoring script from part-way through." 
+ echo " --min_lmwt # minimum LM-weight for lattice rescoring " + echo " --max_lmwt # maximum LM-weight for lattice rescoring " + echo " --asclite (true/false) # score with asclite instead of sclite (overlapped speech)" + exit 1; +fi + +data=$1 + +mic=$(echo $data | awk -F '/' '{print $2}') +case $mic in + ihm*) + echo "Using sclite for IHM (close talk)," + eval local/score_asclite.sh --asclite false $orig_args + ;; + sdm*) + echo "Using asclite for overlapped speech SDM (single distant mic)," + eval local/score_asclite.sh --asclite $asclite $orig_args + ;; + mdm*) + echo "Using asclite for overlapped speech MDM (multiple distant mics)," + eval local/score_asclite.sh --asclite $asclite $orig_args + ;; + *) + echo "local/score.sh: no ihm/sdm/mdm directories found. AMI recipe assumes data/{ihm,sdm,mdm}/..." + exit 1; + ;; +esac diff --git a/egs/rt/s5/local/score_asclite.sh b/egs/rt/s5/local/score_asclite.sh new file mode 100755 index 00000000000..86b801b975d --- /dev/null +++ b/egs/rt/s5/local/score_asclite.sh @@ -0,0 +1,120 @@ +#!/bin/bash +# Copyright Johns Hopkins University (Author: Daniel Povey) 2012. Apache 2.0. +# 2014, University of Edinburgh, (Author: Pawel Swietojanski) + +# begin configuration section. +cmd=run.pl +stage=0 +min_lmwt=9 +max_lmwt=15 +reverse=false +asclite=true +overlap_spk=4 +#end configuration section. + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: local/score_asclite.sh [--cmd (run.pl|queue.pl...)] " + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + echo " --stage (0|1|2) # start scoring script from part-way through." + echo " --min_lmwt # minimum LM-weight for lattice rescoring " + echo " --max_lmwt # maximum LM-weight for lattice rescoring " + echo " --reverse (true/false) # score with time reversed features " + exit 1; +fi + +data=$1 +lang=$2 # Note: may be graph directory not lang directory, but has the necessary stuff copied.
+dir=$3 + +model=$dir/../final.mdl # assume model one level up from decoding dir. + +hubscr=$KALDI_ROOT/tools/sctk/bin/hubscr.pl +[ ! -f $hubscr ] && echo "Cannot find scoring program at $hubscr" && exit 1; +hubdir=`dirname $hubscr` + +for f in $data/stm $data/glm $lang/words.txt $lang/phones/word_boundary.int \ + $model $data/segments $data/reco2file_and_channel $dir/lat.1.gz; do + [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; +done + +name=`basename $data`; # e.g. eval2000 + +mkdir -p $dir/ascoring/log + +if [ $stage -le 0 ]; then + if $reverse; then + $cmd LMWT=$min_lmwt:$max_lmwt $dir/ascoring/log/get_ctm.LMWT.log \ + mkdir -p $dir/ascore_LMWT/ '&&' \ + lattice-1best --lm-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ + lattice-reverse ark:- ark:- \| \ + lattice-align-words --reorder=false $lang/phones/word_boundary.int $model ark:- ark:- \| \ + nbest-to-ctm ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \| \ + utils/convert_ctm.pl $data/segments $data/reco2file_and_channel \ + '>' $dir/ascore_LMWT/$name.ctm || exit 1; + else + $cmd LMWT=$min_lmwt:$max_lmwt $dir/ascoring/log/get_ctm.LMWT.log \ + mkdir -p $dir/ascore_LMWT/ '&&' \ + lattice-1best --lm-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ + lattice-align-words $lang/phones/word_boundary.int $model ark:- ark:- \| \ + nbest-to-ctm ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \| \ + utils/convert_ctm.pl $data/segments $data/reco2file_and_channel \ + '>' $dir/ascore_LMWT/$name.ctm || exit 1; + fi +fi + +if [ $stage -le 1 ]; then +# Remove some stuff we don't want to score, from the ctm. + for x in $dir/ascore_*/$name.ctm; do + cp $x $dir/tmpf; + cat $dir/tmpf | grep -i -v -E '\[noise|laughter|vocalized-noise\]' | \ + grep -i -v -E '' > $x; +# grep -i -v -E '|%HESITATION' > $x; + done +fi + +if [ $stage -le 2 ]; then + if [ "$asclite" == "true" ]; then + oname=$name + [ ! 
-z $overlap_spk ] && oname=${name}_o$overlap_spk + echo "asclite is starting" + # Run scoring, meaning of hubscr.pl options: + # -G .. produce alignment graphs, + # -v .. verbose, + # -m .. max-memory in GBs, + # -o .. max N of overlapping speakers, + # -a .. use asclite, + # -C .. compression for asclite, + # -B .. blocksize for asclite (kBs?), + # -p .. path for other components, + # -V .. skip validation of input transcripts, + # -h rt-stt .. removes non-lexical items from CTM, + $cmd LMWT=$min_lmwt:$max_lmwt $dir/ascoring/log/score.LMWT.log \ + cp $data/stm $dir/ascore_LMWT/ '&&' \ + cp $dir/ascore_LMWT/${name}.ctm $dir/ascore_LMWT/${oname}.ctm '&&' \ + $hubscr -G -v -m 1:2 -o$overlap_spk -a -C -B 8192 -p $hubdir -V -l english \ + -h rt-stt -g $data/glm -r $dir/ascore_LMWT/stm $dir/ascore_LMWT/${oname}.ctm || exit 1 + # Compress some scoring outputs : alignment info and graphs, + echo -n "compressing asclite outputs " + for LMWT in $(seq $min_lmwt $max_lmwt); do + ascore=$dir/ascore_${LMWT} + gzip -f $ascore/${oname}.ctm.filt.aligninfo.csv + cp $ascore/${oname}.ctm.filt.alignments/index.html $ascore/${oname}.ctm.filt.overlap.html + tar -C $ascore -czf $ascore/${oname}.ctm.filt.alignments.tar.gz ${oname}.ctm.filt.alignments + rm -r $ascore/${oname}.ctm.filt.alignments + echo -n "LMWT:$LMWT " + done + echo done + else + $cmd LMWT=$min_lmwt:$max_lmwt $dir/ascoring/log/score.LMWT.log \ + cp $data/stm $dir/ascore_LMWT/ '&&' \ + $hubscr -p $hubdir -V -l english -h hub5 -g $data/glm -r $dir/ascore_LMWT/stm $dir/ascore_LMWT/${name}.ctm || exit 1 + fi +fi + +exit 0 diff --git a/egs/rt/s5/local/snr b/egs/rt/s5/local/snr new file mode 120000 index 00000000000..6d422e11960 --- /dev/null +++ b/egs/rt/s5/local/snr @@ -0,0 +1 @@ +../../../wsj_noisy/s5/local/snr \ No newline at end of file diff --git a/egs/rt/s5/path.sh b/egs/rt/s5/path.sh new file mode 100755 index 00000000000..8461d980758 --- /dev/null +++ b/egs/rt/s5/path.sh @@ -0,0 +1,5 @@ +export KALDI_ROOT=`pwd`/../../.. 
+[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin/:$KALDI_ROOT/src/kwsbin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/:$KALDI_ROOT/src/nnet3bin/:$KALDI_ROOT/src/segmenterbin/:$PWD:$PATH:$KALDI_ROOT/tools/sctk/bin +export PATH=$KALDI_ROOT/tools/sph2pipe_v2.5/:$PATH +export LC_ALL=C diff --git a/egs/rt/s5/sid b/egs/rt/s5/sid new file mode 120000 index 00000000000..5cb0274b7d6 --- /dev/null +++ b/egs/rt/s5/sid @@ -0,0 +1 @@ +../../sre08/v1/sid/ \ No newline at end of file diff --git a/egs/rt/s5/steps b/egs/rt/s5/steps new file mode 120000 index 00000000000..1b186770dd1 --- /dev/null +++ b/egs/rt/s5/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/rt/s5/utils b/egs/rt/s5/utils new file mode 120000 index 00000000000..a3279dc8679 --- /dev/null +++ b/egs/rt/s5/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/egs/wsj/s5/steps/compute_cmvn_stats.sh b/egs/wsj/s5/steps/compute_cmvn_stats.sh index 9056d88691c..6e7531394a2 100755 --- a/egs/wsj/s5/steps/compute_cmvn_stats.sh +++ b/egs/wsj/s5/steps/compute_cmvn_stats.sh @@ -91,18 +91,18 @@ if $fake; then ! cat $data/spk2utt | awk -v dim=$dim '{print $1, "["; for (n=0; n < dim; n++) { printf("0 "); } print "1"; for (n=0; n < dim; n++) { printf("1 "); } print "0 ]";}' | \ copy-matrix ark:- ark,scp:$cmvndir/cmvn_$name.ark,$cmvndir/cmvn_$name.scp && \ - echo "Error creating fake CMVN stats" && exit 1; + echo "Error creating fake CMVN stats. See $logdir/cmvn_$name.log." && exit 1; elif $two_channel; then ! 
compute-cmvn-stats-two-channel $data/reco2file_and_channel scp:$data/feats.scp \ ark,scp:$cmvndir/cmvn_$name.ark,$cmvndir/cmvn_$name.scp \ - 2> $logdir/cmvn_$name.log && echo "Error computing CMVN stats (using two-channel method)" && exit 1; + 2> $logdir/cmvn_$name.log && echo "Error computing CMVN stats (using two-channel method). See $logdir/cmvn_$name.log." && exit 1; elif [ ! -z "$fake_dims" ]; then ! compute-cmvn-stats --spk2utt=ark:$data/spk2utt scp:$data/feats.scp ark:- | \ modify-cmvn-stats "$fake_dims" ark:- ark,scp:$cmvndir/cmvn_$name.ark,$cmvndir/cmvn_$name.scp && \ - echo "Error computing (partially fake) CMVN stats" && exit 1; + echo "Error computing (partially fake) CMVN stats. See $logdir/cmvn_$name.log" && exit 1; else ! compute-cmvn-stats --spk2utt=ark:$data/spk2utt scp:$data/feats.scp ark,scp:$cmvndir/cmvn_$name.ark,$cmvndir/cmvn_$name.scp \ - 2> $logdir/cmvn_$name.log && echo "Error computing CMVN stats" && exit 1; + 2> $logdir/cmvn_$name.log && echo "Error computing CMVN stats. See $logdir/cmvn_$name.log" && exit 1; fi cp $cmvndir/cmvn_$name.scp $data/cmvn.scp || exit 1; diff --git a/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py b/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py index 1f7253d4891..26fb17324dc 100644 --- a/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py +++ b/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py @@ -1,4 +1,10 @@ -import subprocess +#!/usr/bin/env python +# Copyright 2016 Tom Ko +# 2016 Vimal Manohar +# Apache 2.0 + +from __future__ import print_function +import subprocess, random, argparse, os, shlex, warnings def RunKaldiCommand(command, wait = True): """ Runs commands frequently seen in Kaldi scripts. 
These are usually a @@ -16,3 +22,415 @@ def RunKaldiCommand(command, wait = True): else: return p +class list_cyclic_iterator: + def __init__(self, list): + self.list_index = 0 + self.list = list + random.shuffle(self.list) + + def next(self): + item = self.list[self.list_index] + self.list_index = (self.list_index + 1) % len(self.list) + return item + +# This functions picks an item from the collection according to the associated probability distribution. +# The probability estimate of each item in the collection is stored in the "probability" field of +# the particular item. x : a collection (list or dictionary) where the values contain a field called probability +def PickItemWithProbability(x): + if isinstance(x, dict): + plist = list(set(x.values())) + else: + plist = x + total_p = sum(item.probability for item in plist) + p = random.uniform(0, total_p) + accumulate_p = 0 + for item in plist: + if accumulate_p + item.probability >= p: + return item + accumulate_p += item.probability + assert False, "Shouldn't get here as the accumulated probability should always equal to 1" + +# This function smooths the probability distribution in the list +def SmoothProbabilityDistribution(list, smoothing_weight=0.0, target_sum=1.0): + if len(list) > 0: + num_unspecified = 0 + accumulated_prob = 0 + for item in list: + if item.probability is None: + num_unspecified += 1 + else: + accumulated_prob += item.probability + + # Compute the probability for the items without specifying their probability + uniform_probability = 0 + if num_unspecified > 0 and accumulated_prob < 1: + uniform_probability = (1 - accumulated_prob) / float(num_unspecified) + elif num_unspecified > 0 and accumulate_prob >= 1: + warnings.warn("The sum of probabilities specified by user is larger than or equal to 1. 
" + "The items without probabilities specified will be given zero to their probabilities.") + + for item in list: + if item.probability is None: + item.probability = uniform_probability + else: + # smooth the probability + item.probability = (1 - smoothing_weight) * item.probability + smoothing_weight * uniform_probability + + # Normalize the probability + sum_p = sum(item.probability for item in list) + for item in list: + item.probability = item.probability / sum_p * target_sum + + return list + +# This function parses a file and pack the data into a dictionary +# It is useful for parsing file like wav.scp, utt2spk, text...etc +def ParseFileToDict(file, assert2fields = False, value_processor = None): + if value_processor is None: + value_processor = lambda x: x[0] + + dict = {} + for line in open(file, 'r'): + parts = line.split() + if assert2fields: + assert(len(parts) == 2) + + dict[parts[0]] = value_processor(parts[1:]) + return dict + +# This function creates a file and write the content of a dictionary into it +def WriteDictToFile(dict, file_name): + file = open(file_name, 'w') + keys = dict.keys() + keys.sort() + for key in keys: + value = dict[key] + if type(value) in [list, tuple] : + if type(value) is tuple: + value = list(value) + value.sort() + value = ' '.join([ str(x) for x in value ]) + file.write('{0} {1}\n'.format(key, value)) + file.close() + + +# This function creates the utt2uniq file from the utterance id in utt2spk file +def CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_original, prefix): + corrupted_utt2uniq = {} + # Parse the utt2spk to get the utterance id + utt2spk = ParseFileToDict(input_dir + "/utt2spk", value_processor = lambda x: " ".join(x)) + keys = utt2spk.keys() + keys.sort() + if include_original: + start_index = 0 + else: + start_index = 1 + + for i in range(start_index, num_replicas+1): + for utt_id in keys: + new_utt_id = GetNewId(utt_id, prefix, i) + corrupted_utt2uniq[new_utt_id] = utt_id + + 
WriteDictToFile(corrupted_utt2uniq, output_dir + "/utt2uniq") + +# This function generates a new id from the input id +# This is needed when we have to create multiple copies of the original data +# E.g. GetNewId("swb0035", prefix="rvb", copy=1) returns a string "rvb1_swb0035" +def GetNewId(id, prefix=None, copy=0): + if prefix is not None: + new_id = prefix + str(copy) + "_" + id + else: + new_id = id + + return new_id + +# This function replicate the entries in files like segments, utt2spk, text +def AddPrefixToFields(input_file, output_file, num_replicas, include_original, prefix, field = [0]): + list = map(lambda x: x.strip(), open(input_file)) + f = open(output_file, "w") + if include_original: + start_index = 0 + else: + start_index = 1 + + for i in range(start_index, num_replicas+1): + for line in list: + if len(line) > 0 and line[0] != ';': + split1 = line.split() + for j in field: + split1[j] = GetNewId(split1[j], prefix, i) + print(" ".join(split1), file=f) + else: + print(line, file=f) + f.close() + +def CopyDataDirFiles(input_dir, output_dir, num_replicas, include_original, prefix): + if not os.path.isfile(output_dir + "/wav.scp"): + raise Exception("CopyDataDirFiles function expects output_dir to contain wav.scp already") + + AddPrefixToFields(input_dir + "/utt2spk", output_dir + "/utt2spk", num_replicas, include_original=include_original, prefix=prefix, field = [0,1]) + RunKaldiCommand("utils/utt2spk_to_spk2utt.pl <{output_dir}/utt2spk >{output_dir}/spk2utt" + .format(output_dir = output_dir)) + + if os.path.isfile(input_dir + "/utt2uniq"): + AddPrefixToFields(input_dir + "/utt2uniq", output_dir + "/utt2uniq", num_replicas, include_original=include_original, prefix=prefix, field =[0]) + else: + # Create the utt2uniq file + CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_original, prefix) + + if os.path.isfile(input_dir + "/text"): + AddPrefixToFields(input_dir + "/text", output_dir + "/text", num_replicas, 
include_original=include_original, prefix=prefix, field =[0]) + if os.path.isfile(input_dir + "/segments"): + AddPrefixToFields(input_dir + "/segments", output_dir + "/segments", num_replicas, prefix=prefix, include_original=include_original, field = [0,1]) + if os.path.isfile(input_dir + "/reco2file_and_channel"): + AddPrefixToFields(input_dir + "/reco2file_and_channel", output_dir + "/reco2file_and_channel", num_replicas, include_original=include_original, prefix=prefix, field = [0,1]) + + AddPrefixToFields(input_dir + "/reco2dur", output_dir + "/reco2dur", num_replicas, include_original=include_original, prefix=prefix, field = [0]) + + RunKaldiCommand("utils/validate_data_dir.sh --no-feats {output_dir}" + .format(output_dir = output_dir)) + + +# This function parse the array of rir set parameter strings. +# It will assign probabilities to those rir sets which don't have a probability +# It will also check the existence of the rir list files. +def ParseSetParameterStrings(set_para_array): + set_list = [] + for set_para in set_para_array: + set = lambda: None + setattr(set, "filename", None) + setattr(set, "probability", None) + parts = set_para.split(',') + if len(parts) == 2: + set.probability = float(parts[0]) + set.filename = parts[1].strip() + else: + set.filename = parts[0].strip() + if not os.path.isfile(set.filename): + raise Exception(set.filename + " not found") + set_list.append(set) + + return SmoothProbabilityDistribution(set_list) + + +# This function creates the RIR list +# Each rir object in the list contains the following attributes: +# rir_id, room_id, receiver_position_id, source_position_id, rt60, drr, probability +# Please refer to the help messages in the parser for the meaning of these attributes +def ParseRirList(rir_set_para_array, smoothing_weight, sampling_rate = None): + rir_parser = argparse.ArgumentParser() + rir_parser.add_argument('--rir-id', type=str, required=True, help='This id is unique for each RIR and the noise may associate 
with a particular RIR by refering to this id') + rir_parser.add_argument('--room-id', type=str, required=True, help='This is the room that where the RIR is generated') + rir_parser.add_argument('--receiver-position-id', type=str, default=None, help='receiver position id') + rir_parser.add_argument('--source-position-id', type=str, default=None, help='source position id') + rir_parser.add_argument('--rt60', type=float, default=None, help='RT60 is the time required for reflections of a direct sound to decay 60 dB.') + rir_parser.add_argument('--drr', type=float, default=None, help='Direct-to-reverberant-ratio of the impulse response.') + rir_parser.add_argument('--cte', type=float, default=None, help='Early-to-late index of the impulse response.') + rir_parser.add_argument('--probability', type=float, default=None, help='probability of the impulse response.') + rir_parser.add_argument('rir_rspecifier', type=str, help="""rir rspecifier, it can be either a filename or a piped command. + E.g. data/impulses/Room001-00001.wav or "sox data/impulses/Room001-00001.wav -t wav - |" """) + + set_list = ParseSetParameterStrings(rir_set_para_array) + + rir_list = [] + for rir_set in set_list: + current_rir_list = map(lambda x: rir_parser.parse_args(shlex.split(x.strip())),open(rir_set.filename)) + for rir in current_rir_list: + if sampling_rate is not None: + # check if the rspecifier is a pipe or not + if len(rir.rir_rspecifier.split()) == 1: + rir.rir_rspecifier = "sox {0} -r {1} -t wav - |".format(rir.rir_rspecifier, sampling_rate) + else: + rir.rir_rspecifier = "{0} sox -t wav - -r {1} -t wav - |".format(rir.rir_rspecifier, sampling_rate) + + rir_list += SmoothProbabilityDistribution(current_rir_list, smoothing_weight, rir_set.probability) + + return rir_list + + +# This dunction checks if the inputs are approximately equal assuming they are floats. 
+def almost_equal(value_1, value_2, accuracy = 10**-8): + return abs(value_1 - value_2) < accuracy + +# This function converts a list of RIRs into a dictionary of RIRs indexed by the room-id. +# Its values are objects with two attributes: a local RIR list +# and the probability of the corresponding room +# Please look at the comments at ParseRirList() for the attributes that a RIR object contains +def MakeRoomDict(rir_list): + room_dict = {} + for rir in rir_list: + if rir.room_id not in room_dict: + # add new room + room_dict[rir.room_id] = lambda: None + setattr(room_dict[rir.room_id], "rir_list", []) + setattr(room_dict[rir.room_id], "probability", 0) + room_dict[rir.room_id].rir_list.append(rir) + + # the probability of the room is the sum of probabilities of its RIR + for key in room_dict.keys(): + room_dict[key].probability = sum(rir.probability for rir in room_dict[key].rir_list) + + assert almost_equal(sum(room_dict[key].probability for key in room_dict.keys()), 1.0) + + return room_dict + + +# This function creates the point-source noise list +# and the isotropic noise dictionary from the noise information file +# The isotropic noise dictionary is indexed by the room +# and its value is the corrresponding isotropic noise list +# Each noise object in the list contains the following attributes: +# noise_id, noise_type, bg_fg_type, room_linkage, probability, noise_rspecifier +# Please refer to the help messages in the parser for the meaning of these attributes +def ParseNoiseList(noise_set_para_array, smoothing_weight, sampling_rate = None): + noise_parser = argparse.ArgumentParser() + noise_parser.add_argument('--noise-id', type=str, required=True, help='noise id') + noise_parser.add_argument('--noise-type', type=str, required=True, help='the type of noise; i.e. 
isotropic or point-source', choices = ["isotropic", "point-source"]) + noise_parser.add_argument('--bg-fg-type', type=str, default="background", help='background or foreground noise, for background noises, ' + 'they will be extended before addition to cover the whole speech; for foreground noise, they will be kept ' + 'to their original duration and added at a random point of the speech.', choices = ["background", "foreground"]) + noise_parser.add_argument('--room-linkage', type=str, default=None, help='required if isotropic, should not be specified if point-source.') + noise_parser.add_argument('--probability', type=float, default=None, help='probability of the noise.') + noise_parser.add_argument('noise_rspecifier', type=str, help="""noise rspecifier, it can be either a filename or a piped command. + E.g. type5_noise_cirline_ofc_ambient1.wav or "sox type5_noise_cirline_ofc_ambient1.wav -t wav - |" """) + + set_list = ParseSetParameterStrings(noise_set_para_array) + + pointsource_noise_list = [] + iso_noise_dict = {} + for noise_set in set_list: + current_noise_list = map(lambda x: noise_parser.parse_args(shlex.split(x.strip())),open(noise_set.filename)) + current_pointsource_noise_list = [] + for noise in current_noise_list: + if sampling_rate is not None: + # check if the rspecifier is a pipe or not + if len(noise.noise_rspecifier.split()) == 1: + noise.noise_rspecifier = "sox {0} -r {1} -t wav - |".format(noise.noise_rspecifier, sampling_rate) + else: + noise.noise_rspecifier = "{0} sox -t wav - -r {1} -t wav - |".format(noise.noise_rspecifier, sampling_rate) + + if noise.noise_type == "isotropic": + if noise.room_linkage is None: + raise Exception("--room-linkage must be specified if --noise-type is isotropic") + else: + if noise.room_linkage not in iso_noise_dict: + iso_noise_dict[noise.room_linkage] = [] + iso_noise_dict[noise.room_linkage].append(noise) + else: + current_pointsource_noise_list.append(noise) + + pointsource_noise_list += 
SmoothProbabilityDistribution(current_pointsource_noise_list, smoothing_weight, noise_set.probability) + + # ensure the point-source noise probabilities sum to 1 + pointsource_noise_list = SmoothProbabilityDistribution(pointsource_noise_list, smoothing_weight, 1.0) + if len(pointsource_noise_list) > 0: + assert almost_equal(sum(noise.probability for noise in pointsource_noise_list), 1.0) + + # ensure the isotropic noise source probabilities for a given room sum to 1 + for key in iso_noise_dict.keys(): + iso_noise_dict[key] = SmoothProbabilityDistribution(iso_noise_dict[key]) + assert almost_equal(sum(noise.probability for noise in iso_noise_dict[key]), 1.0) + + return (pointsource_noise_list, iso_noise_dict) + +def AddPointSourceNoise(room, # the room selected + pointsource_noise_list, # the point source noise list + pointsource_noise_addition_probability, # Probability of adding point-source noises + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + speech_dur, # duration of the recording + max_noises_recording, # Maximum number of point-source noises that can be added + noise_addition_descriptor # descriptor to store the information of the noise added + ): + num_noises_added = 0 + if len(pointsource_noise_list) > 0 and random.random() < pointsource_noise_addition_probability and max_noises_recording >= 1: + for k in range(random.randint(1, max_noises_recording)): + num_noises_added = num_noises_added + 1 + # pick the RIR to reverberate the point-source noise + noise = PickItemWithProbability(pointsource_noise_list) + noise_rir = PickItemWithProbability(room.rir_list) + # If it is a background noise, the noise will be extended and be added to the whole speech + # if it is a foreground noise, the noise will not extended and be added at a random time of the speech + if noise.bg_fg_type == "background": + noise_rvb_command = """wav-reverberate --impulse-response="{0}" 
--duration={1}""".format(noise_rir.rir_rspecifier, speech_dur) + noise_addition_descriptor['start_times'].append(0) + noise_addition_descriptor['snrs'].append(background_snrs.next()) + noise_addition_descriptor['durations'].append(speech_dur) + noise_addition_descriptor['noise_ids'].append(noise.noise_id) + else: + noise_rvb_command = """wav-reverberate --impulse-response="{0}" """.format(noise_rir.rir_rspecifier) + noise_addition_descriptor['start_times'].append(round(random.random() * speech_dur, 2)) + noise_addition_descriptor['snrs'].append(foreground_snrs.next()) + noise_addition_descriptor['durations'].append(-1) + noise_addition_descriptor['noise_ids'].append(noise.noise_id) + + # check if the rspecifier is a pipe or not + if len(noise.noise_rspecifier.split()) == 1: + noise_addition_descriptor['noise_io'].append("{1} {0} - |".format(noise.noise_rspecifier, noise_rvb_command)) + else: + noise_addition_descriptor['noise_io'].append("{0} {1} - - |".format(noise.noise_rspecifier, noise_rvb_command)) + +# This function randomly decides whether to reverberate, and sample a RIR if it does +# It also decides whether to add the appropriate noises +# This function return the string of options to the binary wav-reverberate +def GenerateReverberationOpts(room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the point source noise list + iso_noise_dict, # the isotropic noise dictionary + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + speech_rvb_probability, # Probability of reverberating a speech signal + isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability, # Probability of adding point-source noises + speech_dur, # duration of the recording + max_noises_recording # Maximum number of point-source noises that can be added + ): + impulse_response_opts = "" + 
additive_noise_opts = "" + + noise_addition_descriptor = {'noise_io': [], + 'start_times': [], + 'snrs': [], + 'noise_ids': [], + 'durations': [] + } + # Randomly select the room + # Here the room probability is a sum of the probabilities of the RIRs recorded in the room. + room = PickItemWithProbability(room_dict) + # Randomly select the RIR in the room + speech_rir = PickItemWithProbability(room.rir_list) + if random.random() < speech_rvb_probability: + # pick the RIR to reverberate the speech + impulse_response_opts = """--impulse-response="{0}" """.format(speech_rir.rir_rspecifier) + + rir_iso_noise_list = [] + if speech_rir.room_id in iso_noise_dict: + rir_iso_noise_list = iso_noise_dict[speech_rir.room_id] + # Add the corresponding isotropic noise associated with the selected RIR + if len(rir_iso_noise_list) > 0 and random.random() < isotropic_noise_addition_probability: + isotropic_noise = PickItemWithProbability(rir_iso_noise_list) + # extend the isotropic noise to the length of the speech waveform + # check if the rspecifier is really a pipe + if len(isotropic_noise.noise_rspecifier.split()) == 1: + noise_addition_descriptor['noise_io'].append("wav-reverberate --duration={1} {0} - |".format(isotropic_noise.noise_rspecifier, speech_dur)) + else: + noise_addition_descriptor['noise_io'].append("{0} wav-reverberate --duration={1} - - |".format(isotropic_noise.noise_rspecifier, speech_dur)) + noise_addition_descriptor['start_times'].append(0) + noise_addition_descriptor['snrs'].append(background_snrs.next()) + noise_addition_descriptor['noise_ids'].append(isotropic_noise.noise_id) + noise_addition_descriptor['durations'].append(speech_dur) + + AddPointSourceNoise(room, # the room selected + pointsource_noise_list, # the point source noise list + pointsource_noise_addition_probability, # Probability of adding point-source noises + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + 
speech_dur, # duration of the recording + max_noises_recording, # Maximum number of point-source noises that can be added + noise_addition_descriptor # descriptor to store the information of the noise added + ) + + assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['start_times']) + assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['snrs']) + + return [impulse_response_opts, noise_addition_descriptor] + diff --git a/egs/wsj/s5/steps/data/make_corrupted_data_dir.py b/egs/wsj/s5/steps/data/make_corrupted_data_dir.py new file mode 100644 index 00000000000..c0fa94c2a42 --- /dev/null +++ b/egs/wsj/s5/steps/data/make_corrupted_data_dir.py @@ -0,0 +1,613 @@ +#!/usr/bin/env python +# Copyright 2016 Tom Ko +# Apache 2.0 +# script to generate reverberated data + +# we're using python 3.x style print but want it to work in python 2.x, +from __future__ import print_function +import argparse, shlex, glob, math, os, random, sys, warnings, copy, imp, ast + +import data_dir_manipulation_lib as data_lib + +sys.path.insert(0, 'steps') +import libs.common as common_lib + +def GetArgs(): + # we add required arguments as named arguments for readability + parser = argparse.ArgumentParser(description="Reverberate the data directory with an option " + "to add isotropic and point source noises. " + "Usage: reverberate_data_dir.py [options...] " + "E.g. reverberate_data_dir.py --rir-set-parameters rir_list " + "--foreground-snrs 20:10:15:5:0 --background-snrs 20:10:15:5:0 " + "--noise-list-file noise_list --speech-rvb-probability 1 --num-replications 2 " + "--random-seed 1 data/train data/train_rvb", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("--rir-set-parameters", type=str, action='append', required = True, dest = "rir_set_para_array", + help="Specifies the parameters of an RIR set. " + "Supports the specification of mixture_weight and rir_list_file_name. The mixture weight is optional. 
" + "The default mixture weight is the probability mass remaining after adding the mixture weights " + "of all the RIR lists, uniformly divided among the RIR lists without mixture weights. " + "E.g. --rir-set-parameters '0.3, rir_list' or 'rir_list' " + "the format of the RIR list file is " + "--rir-id --room-id " + "--receiver-position-id --source-position-id " + "--rt-60 --drr location " + "E.g. --rir-id 00001 --room-id 001 --receiver-position-id 001 --source-position-id 00001 " + "--rt60 0.58 --drr -4.885 data/impulses/Room001-00001.wav") + parser.add_argument("--noise-set-parameters", type=str, action='append', + default = None, dest = "noise_set_para_array", + help="Specifies the parameters of an noise set. " + "Supports the specification of mixture_weight and noise_list_file_name. The mixture weight is optional. " + "The default mixture weight is the probability mass remaining after adding the mixture weights " + "of all the noise lists, uniformly divided among the noise lists without mixture weights. " + "E.g. --noise-set-parameters '0.3, noise_list' or 'noise_list' " + "the format of the noise list file is " + "--noise-id --noise-type " + "--bg-fg-type " + "--room-linkage " + "location " + "E.g. 
--noise-id 001 --noise-type isotropic --rir-id 00019 iso_noise.wav") + parser.add_argument("--speech-segments-set-parameters", type=str, action='append', + default = None, dest = "speech_segments_set_para_array", + help="Specifies the speech segments for overlapped speech generation.\n" + "Format: [], wav_scp, segments_list\n"); + parser.add_argument("--num-replications", type=int, dest = "num_replicas", default = 1, + help="Number of replicate to generated for the data") + parser.add_argument('--foreground-snrs', type=str, dest = "foreground_snr_string", + default = '20:10:0', + help='When foreground noises are being added the script will iterate through these SNRs.') + parser.add_argument('--background-snrs', type=str, dest = "background_snr_string", + default = '20:10:0', + help='When background noises are being added the script will iterate through these SNRs.') + parser.add_argument('--overlap-snrs', type=str, dest = "overlap_snr_string", + default = "20:10:0", + help='When overlapping speech segments are being added the script will iterate through these SNRs.') + parser.add_argument('--prefix', type=str, default = None, + help='This prefix will modified for each reverberated copy, by adding additional affixes.') + parser.add_argument("--speech-rvb-probability", type=float, default = 1.0, + help="Probability of reverberating a speech signal, e.g. 0 <= p <= 1") + parser.add_argument("--pointsource-noise-addition-probability", type=float, default = 1.0, + help="Probability of adding point-source noises, e.g. 0 <= p <= 1") + parser.add_argument("--isotropic-noise-addition-probability", type=float, default = 1.0, + help="Probability of adding isotropic noises, e.g. 0 <= p <= 1") + parser.add_argument("--overlapping-speech-addition-probability", type=float, default = 1.0, + help="Probability of adding overlapping speech, e.g. 
0 <= p <= 1") + parser.add_argument("--rir-smoothing-weight", type=float, default = 0.3, + help="Smoothing weight for the RIR probabilties, e.g. 0 <= p <= 1. If p = 0, no smoothing will be done. " + "The RIR distribution will be mixed with a uniform distribution according to the smoothing weight") + parser.add_argument("--noise-smoothing-weight", type=float, default = 0.3, + help="Smoothing weight for the noise probabilties, e.g. 0 <= p <= 1. If p = 0, no smoothing will be done. " + "The noise distribution will be mixed with a uniform distribution according to the smoothing weight") + parser.add_argument("--overlapping-speech-smoothing-weight", type=float, default = 0.3, + help="The overlapping speech distribution will be mixed with a uniform distribution according to the smoothing weight") + parser.add_argument("--max-noises-per-minute", type=int, default = 2, + help="This controls the maximum number of point-source noises that could be added to a recording according to its duration") + parser.add_argument("--min-overlapping-segments-per-minute", type=int, default = 1, + help="This controls the minimum number of overlapping segments of speech that could be added to a recording per minute") + parser.add_argument("--max-overlapping-segments-per-minute", type=int, default = 5, + help="This controls the maximum number of overlapping segments of speech that could be added to a recording per minute") + parser.add_argument('--random-seed', type=int, default=0, + help='seed to be used in the randomization of impulses and noises') + parser.add_argument("--shift-output", type=str, + help="If true, the reverberated waveform will be shifted by the amount of the peak position of the RIR", + choices=['true', 'false'], default = "true") + parser.add_argument('--source-sampling-rate', type=int, default=None, + help="Sampling rate of the source data. 
If a positive integer is specified with this option, " + "the RIRs/noises will be resampled to the rate of the source data.") + parser.add_argument("--include-original-data", type=str, help="If true, the output data includes one copy of the original data", + choices=['true', 'false'], default = "false") + parser.add_argument("--output-additive-noise-dir", type=str, + action = common_lib.NullstrToNoneAction, default = None, + help="Output directory corresponding to the additive noise part of the data corruption") + parser.add_argument("--output-reverb-dir", type=str, + action = common_lib.NullstrToNoneAction, default = None, + help="Output directory corresponding to the reverberated signal part of the data corruption") + + parser.add_argument("input_dir", + help="Input data directory") + parser.add_argument("output_dir", + help="Output data directory") + + print(' '.join(sys.argv)) + + args = parser.parse_args() + args = CheckArgs(args) + + return args + +def CheckArgs(args): + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + ## Check arguments. + + if args.prefix is None: + if args.num_replicas > 1 or args.include_original_data == "true": + args.prefix = "rvb" + warnings.warn("--prefix is set to 'rvb' as more than one copy of data is generated") + + if args.output_reverb_dir is not None: + if not os.path.exists(args.output_reverb_dir): + os.makedirs(args.output_reverb_dir) + + if args.output_additive_noise_dir is not None: + if not os.path.exists(args.output_additive_noise_dir): + os.makedirs(args.output_additive_noise_dir) + + ## Check arguments. 
+ + if args.num_replicas > 1 and args.prefix is None: + args.prefix = "rvb" + warnings.warn("--prefix is set to 'rvb' as --num-replications is larger than 1.") + + if not args.num_replicas > 0: + raise Exception("--num-replications cannot be non-positive") + + if args.speech_rvb_probability < 0 or args.speech_rvb_probability > 1: + raise Exception("--speech-rvb-probability must be between 0 and 1") + + if args.pointsource_noise_addition_probability < 0 or args.pointsource_noise_addition_probability > 1: + raise Exception("--pointsource-noise-addition-probability must be between 0 and 1") + + if args.isotropic_noise_addition_probability < 0 or args.isotropic_noise_addition_probability > 1: + raise Exception("--isotropic-noise-addition-probability must be between 0 and 1") + + if args.overlapping_speech_addition_probability < 0 or args.overlapping_speech_addition_probability > 1: + raise Exception("--overlapping-speech-addition-probability must be between 0 and 1") + + if args.rir_smoothing_weight < 0 or args.rir_smoothing_weight > 1: + raise Exception("--rir-smoothing-weight must be between 0 and 1") + + if args.noise_smoothing_weight < 0 or args.noise_smoothing_weight > 1: + raise Exception("--noise-smoothing-weight must be between 0 and 1") + + if args.overlapping_speech_smoothing_weight < 0 or args.overlapping_speech_smoothing_weight > 1: + raise Exception("--overlapping-speech-smoothing-weight must be between 0 and 1") + + if args.max_noises_per_minute < 0: + raise Exception("--max-noises-per-minute cannot be negative") + + if args.min_overlapping_segments_per_minute < 0: + raise Exception("--min-overlapping-segments-per-minute cannot be negative") + + if args.max_overlapping_segments_per_minute < 0: + raise Exception("--max-overlapping-segments-per-minute cannot be negative") + + return args + +def ParseSpeechSegmentsList(speech_segments_set_para_array, smoothing_weight): + set_list = [] + for set_para in speech_segments_set_para_array: + set = lambda: None + 
setattr(set, "wav_scp", None) + setattr(set, "segments", None) + setattr(set, "probability", None) + parts = set_para.split(',') + if len(parts) == 3: + set.probability = float(parts[0]) + set.wav_scp = parts[1].strip() + set.segments = parts[2].strip() + else: + set.wav_scp = parts[0].strip() + set.segments = parts[1].strip() + if not os.path.isfile(set.wav_scp): + raise Exception(set.wav_scp + " not found") + if not os.path.isfile(set.segments): + raise Exception(set.segments + " not found") + set_list.append(set) + + data_lib.SmoothProbabilityDistribution(set_list) + + segments_list = [] + for segments_set in set_list: + current_segments_list = [] + + wav_dict = {} + for s in open(segments_set.wav_scp): + parts = s.strip().split() + wav_dict[parts[0]] = ' '.join(parts[1:]) + + for s in open(segments_set.segments): + parts = s.strip().split() + current_segment = argparse.Namespace() + current_segment.utt_id = parts[0] + current_segment.probability = None + + start_time = float(parts[2]) + end_time = float(parts[3]) + + current_segment.duration = (end_time - start_time) + + wav_rxfilename = wav_dict[parts[1]] + if wav_rxfilename.split()[-1] == '|': + current_segment.wav_rxfilename = "{0} sox -t wav - -t wav - trim {1} {2} |".format(wav_rxfilename, start_time, end_time - start_time) + else: + current_segment.wav_rxfilename = "sox {0} -t wav - trim {1} {2} |".format(wav_rxfilename, start_time, end_time - start_time) + + current_segments_list.append(current_segment) + + segments_list += data_lib.SmoothProbabilityDistribution(current_segments_list, smoothing_weight, segments_set.probability) + + return segments_list + +def AddOverlappingSpeech(room, # the room selected + speech_segments_list, # the speech list + overlapping_speech_addition_probability, # Probability of another speech waveform + snrs, # the SNR for adding the foreground speech + speech_dur, # duration of the recording + min_overlapping_speech_segments, # Minimum number of speech signals that can be 
added + max_overlapping_speech_segments, # Maximum number of speech signals that can be added + overlapping_speech_descriptor # descriptor to store the information of the overlapping speech + ): + if (len(speech_segments_list) > 0 and random.random() < overlapping_speech_addition_probability + and max_overlapping_speech_segments >= 1): + for k in range(1, random.randint(min_overlapping_speech_segments, max_overlapping_speech_segments) + 1): + # pick the overlapping_speech speech signal and the RIR to + # reverberate the overlapping_speech speech signal + speech_segment = data_lib.PickItemWithProbability(speech_segments_list) + rir = data_lib.PickItemWithProbability(room.rir_list) + + speech_rvb_command = """wav-reverberate --impulse-response="{0}" --shift-output=true """.format(rir.rir_rspecifier) + overlapping_speech_descriptor['start_times'].append( + round(random.random() + * max(speech_dur - speech_segment.duration, 0), 2)) + overlapping_speech_descriptor['snrs'].append(snrs.next()) + overlapping_speech_descriptor['utt_ids'].append(speech_segment.utt_id) + overlapping_speech_descriptor['durations'].append(speech_segment.duration) + + if len(speech_segment.wav_rxfilename.split()) == 1: + overlapping_speech_descriptor['speech_segments'].append("{1} {0} - |".format(speech_segment.wav_rxfilename, speech_rvb_command)) + else: + overlapping_speech_descriptor['speech_segments'].append("{0} {1} - - |".format(speech_segment.wav_rxfilename, speech_rvb_command)) + +# This function randomly decides whether to reverberate, and sample a RIR if it does +# It also decides whether to add the appropriate noises +# This function return the string of options to the binary wav-reverberate +def GenerateReverberationAndOverlappedSpeechOpts( + room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the point source noise list + iso_noise_dict, # the isotropic noise dictionary + foreground_snrs, # the SNR for adding the foreground 
def GenerateReverberationAndOverlappedSpeechOpts(
        room_dict, # the room dictionary, please refer to MakeRoomDict() for the format
        pointsource_noise_list, # the point source noise list
        iso_noise_dict, # the isotropic noise dictionary
        foreground_snrs, # the SNR for adding the foreground noises
        background_snrs, # the SNR for adding the background noises
        speech_segments_list,
        overlap_snrs,
        speech_rvb_probability, # Probability of reverberating a speech signal
        isotropic_noise_addition_probability, # Probability of adding isotropic noises
        pointsource_noise_addition_probability, # Probability of adding point-source noises
        overlapping_speech_addition_probability, # Probability of adding overlapping speech segments
        speech_dur, # duration of the recording
        max_noises_recording, # Maximum number of point-source noises that can be added
        min_overlapping_segments_recording, # Minimum number of overlapping segments that can be added
        max_overlapping_segments_recording # Maximum number of overlapping segments that can be added
        ):
    """Randomly decides whether to reverberate this recording and which
    isotropic/point-source noises and overlapping speech segments to mix in.

    Returns [impulse_response_opts, noise_addition_descriptor,
    overlapping_speech_descriptor], the pieces needed to build the
    wav-reverberate command line.
    """
    impulse_response_opts = ""

    noise_addition_descriptor = {'noise_io': [],
                                 'start_times': [],
                                 'snrs': [],
                                 'noise_ids': [],
                                 'durations': []
                                 }

    # Select the room; its probability is the sum of the probabilities of the
    # RIRs recorded in it.  Then select an RIR within that room.
    room = data_lib.PickItemWithProbability(room_dict)
    speech_rir = data_lib.PickItemWithProbability(room.rir_list)
    if random.random() < speech_rvb_probability:
        # Reverberate the speech with the selected RIR.
        impulse_response_opts = """--impulse-response="{0}" """.format(speech_rir.rir_rspecifier)

    # Add the isotropic noise associated with the selected RIR (if any),
    # extended to the length of the speech waveform.
    rir_iso_noise_list = iso_noise_dict.get(speech_rir.room_id, [])
    if len(rir_iso_noise_list) > 0 and random.random() < isotropic_noise_addition_probability:
        isotropic_noise = data_lib.PickItemWithProbability(rir_iso_noise_list)
        if len(isotropic_noise.noise_rspecifier.split()) == 1:
            # A plain file, not a pipe.
            noise_io = "wav-reverberate --duration={1} {0} - |".format(isotropic_noise.noise_rspecifier, speech_dur)
        else:
            noise_io = "{0} wav-reverberate --duration={1} - - |".format(isotropic_noise.noise_rspecifier, speech_dur)
        noise_addition_descriptor['noise_io'].append(noise_io)
        noise_addition_descriptor['start_times'].append(0)
        noise_addition_descriptor['snrs'].append(background_snrs.next())
        noise_addition_descriptor['noise_ids'].append(isotropic_noise.noise_id)
        noise_addition_descriptor['durations'].append(speech_dur)

    data_lib.AddPointSourceNoise(room, # the room selected
                                 pointsource_noise_list, # the point source noise list
                                 pointsource_noise_addition_probability,
                                 foreground_snrs,
                                 background_snrs,
                                 speech_dur,
                                 max_noises_recording,
                                 noise_addition_descriptor)

    # All parallel lists in the descriptor must stay the same length.
    for field in ['start_times', 'snrs', 'noise_ids', 'durations']:
        assert (len(noise_addition_descriptor['noise_io'])
                == len(noise_addition_descriptor[field]))

    overlapping_speech_descriptor = {'speech_segments': [],
                                     'start_times': [],
                                     'snrs': [],
                                     'utt_ids': [],
                                     'durations': []
                                     }

    AddOverlappingSpeech(room,
                         speech_segments_list, # speech segments list
                         overlapping_speech_addition_probability,
                         overlap_snrs,
                         speech_dur,
                         min_overlapping_segments_recording,
                         max_overlapping_segments_recording,
                         overlapping_speech_descriptor)

    return [impulse_response_opts, noise_addition_descriptor,
            overlapping_speech_descriptor]
def GenerateReverberatedWavScpWithOverlappedSpeech(
        wav_scp, # a dictionary whose values are the Kaldi-IO strings of the speech recordings
        durations, # a dictionary whose values are the duration (in sec) of the speech recordings
        output_dir, # output directory to write the corrupted wav.scp
        room_dict, # the room dictionary, please refer to MakeRoomDict() for the format
        pointsource_noise_list, # the point source noise list
        iso_noise_dict, # the isotropic noise dictionary
        foreground_snr_array, # the SNR for adding the foreground noises
        background_snr_array, # the SNR for adding the background noises
        speech_segments_list, # list of speech segments to create overlapped speech
        overlap_snr_array, # the SNR for adding overlapping speech
        num_replicas, # Number of replicas to generate for the data
        include_original, # include a copy of the original data
        prefix, # prefix for the id of the corrupted utterances
        speech_rvb_probability, # Probability of reverberating a speech signal
        shift_output, # option whether to shift the output waveform
        isotropic_noise_addition_probability, # Probability of adding isotropic noises
        pointsource_noise_addition_probability, # Probability of adding point-source noises
        max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration
        overlapping_speech_addition_probability,
        min_overlapping_segments_per_minute,
        max_overlapping_segments_per_minute,
        output_reverb_dir = None,
        output_additive_noise_dir = None
        ):
    """Generates the corrupted wav.scp (and optionally reverb-only and
    additive-noise-only wav.scp files) by composing wav-reverberate pipes
    for each replica of each recording, e.g.:
    wav-reverberate --duration=t --impulse-response=rir.wav
        --additive-signals='n1.wav,n2.wav' --snrs='s1,s2' --start-times='t1,t2'
        input.wav output.wav
    Also writes overlapped_segments_info.txt describing the overlapping
    segments added to each new recording.
    """
    foreground_snrs = data_lib.list_cyclic_iterator(foreground_snr_array)
    background_snrs = data_lib.list_cyclic_iterator(background_snr_array)
    overlap_snrs = data_lib.list_cyclic_iterator(overlap_snr_array)

    corrupted_wav_scp = {}
    reverb_wav_scp = {}
    additive_noise_wav_scp = {}
    overlapping_segments_info = {}

    # sorted() instead of keys();sort() so this also works under python 3.
    keys = sorted(wav_scp.keys())

    # Replica index 0 is reserved for the original (uncorrupted) data.
    if include_original:
        start_index = 0
    else:
        start_index = 1

    for i in range(start_index, num_replicas + 1):
        for recording_id in keys:
            wav_original_pipe = wav_scp[recording_id]
            # check if it is really a pipe
            if len(wav_original_pipe.split()) == 1:
                wav_original_pipe = "cat {0} |".format(wav_original_pipe)
            speech_dur = durations[recording_id]
            max_noises_recording = math.floor(max_noises_per_minute * speech_dur / 60)
            min_overlapping_segments_recording = max(math.floor(min_overlapping_segments_per_minute * speech_dur / 60), 1)
            max_overlapping_segments_recording = math.ceil(max_overlapping_segments_per_minute * speech_dur / 60)

            [impulse_response_opts, noise_addition_descriptor,
             overlapping_speech_descriptor] = GenerateReverberationAndOverlappedSpeechOpts(
                room_dict = room_dict,
                pointsource_noise_list = pointsource_noise_list,
                iso_noise_dict = iso_noise_dict,
                foreground_snrs = foreground_snrs,
                background_snrs = background_snrs,
                speech_segments_list = speech_segments_list,
                overlap_snrs = overlap_snrs,
                speech_rvb_probability = speech_rvb_probability,
                isotropic_noise_addition_probability = isotropic_noise_addition_probability,
                pointsource_noise_addition_probability = pointsource_noise_addition_probability,
                overlapping_speech_addition_probability = overlapping_speech_addition_probability,
                speech_dur = speech_dur,
                max_noises_recording = max_noises_recording,
                min_overlapping_segments_recording = min_overlapping_segments_recording,
                max_overlapping_segments_recording = max_overlapping_segments_recording)

            # Noises and overlapping speech are passed to wav-reverberate as
            # one combined set of additive signals.
            additive_noise_opts = ""
            if (len(noise_addition_descriptor['noise_io']) > 0 or
                    len(overlapping_speech_descriptor['speech_segments']) > 0):
                signals = (noise_addition_descriptor['noise_io']
                           + overlapping_speech_descriptor['speech_segments'])
                start_times = (noise_addition_descriptor['start_times']
                               + overlapping_speech_descriptor['start_times'])
                snrs = (noise_addition_descriptor['snrs']
                        + overlapping_speech_descriptor['snrs'])
                additive_noise_opts += "--additive-signals='{0}' ".format(','.join(signals))
                additive_noise_opts += "--start-times='{0}' ".format(','.join([str(x) for x in start_times]))
                additive_noise_opts += "--snrs='{0}' ".format(','.join([str(x) for x in snrs]))

            reverberate_opts = impulse_response_opts + additive_noise_opts

            new_recording_id = data_lib.GetNewId(recording_id, prefix, i)

            # prefix using index 0 is reserved for original data
            # e.g. rvb0_swb0035 corresponds to the swb0035 recording in original data
            if reverberate_opts == "" or i == 0:
                wav_corrupted_pipe = "{0}".format(wav_original_pipe)
            else:
                wav_corrupted_pipe = "{0} wav-reverberate --shift-output={1} {2} - - |".format(wav_original_pipe, shift_output, reverberate_opts)

            corrupted_wav_scp[new_recording_id] = wav_corrupted_pipe

            if output_reverb_dir is not None:
                if impulse_response_opts == "":
                    wav_reverb_pipe = "{0}".format(wav_original_pipe)
                else:
                    wav_reverb_pipe = "{0} wav-reverberate --shift-output={1} --reverb-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts)
                reverb_wav_scp[new_recording_id] = wav_reverb_pipe

            if output_additive_noise_dir is not None:
                # Fix: the bare `assert False` gave no hint of what went
                # wrong; keep the AssertionError type but say which
                # recording had no additive signal.
                assert additive_noise_opts != "", (
                    "additive_noise_opts is empty for recording {0} although "
                    "--output-additive-noise-dir was supplied".format(recording_id))
                additive_noise_wav_scp[new_recording_id] = "{0} wav-reverberate --shift-output={1} --additive-noise-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts)

            if len(overlapping_speech_descriptor['speech_segments']) > 0:
                overlapping_segments_info[new_recording_id] = [
                    ':'.join(x)
                    for x in zip(overlapping_speech_descriptor['utt_ids'],
                                 [str(x) for x in overlapping_speech_descriptor['start_times']],
                                 [str(x) for x in overlapping_speech_descriptor['durations']])
                    ]

    data_lib.WriteDictToFile(corrupted_wav_scp, output_dir + "/wav.scp")

    # Write for each new recording the utterance-id, start time and duration
    # of the overlapping segments that were added to it.
    data_lib.WriteDictToFile(overlapping_segments_info, output_dir + "/overlapped_segments_info.txt")

    if output_reverb_dir is not None:
        data_lib.WriteDictToFile(reverb_wav_scp, output_reverb_dir + "/wav.scp")

    if output_additive_noise_dir is not None:
        data_lib.WriteDictToFile(additive_noise_wav_scp, output_additive_noise_dir + "/wav.scp")
+ "/wav.scp") + + +# This function creates multiple copies of the necessary files, e.g. utt2spk, wav.scp ... +def CreateReverberatedCopy(input_dir, + output_dir, + room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the point source noise list + iso_noise_dict, # the isotropic noise dictionary + speech_segments_list, + foreground_snr_string, # the SNR for adding the foreground noises + background_snr_string, # the SNR for adding the background noises + overlap_snr_string, # the SNR for overlapping speech + num_replicas, # Number of replicate to generated for the data + include_original, # include a copy of the original data + prefix, # prefix for the id of the corrupted utterances + speech_rvb_probability, # Probability of reverberating a speech signal + shift_output, # option whether to shift the output waveform + isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability, # Probability of adding point-source noises + max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration + overlapping_speech_addition_probability, + min_overlapping_segments_per_minute, + max_overlapping_segments_per_minute, + output_reverb_dir = None, + output_additive_noise_dir = None + ): + + wav_scp = data_lib.ParseFileToDict(input_dir + "/wav.scp", value_processor = lambda x: " ".join(x)) + if not os.path.isfile(input_dir + "/reco2dur"): + print("Getting the duration of the recordings..."); + read_entire_file="false" + for value in wav_scp.values(): + # we will add more checks for sox commands which modify the header as we come across these cases in our data + if "sox" in value and "speed" in value: + read_entire_file="true" + break + data_lib.RunKaldiCommand("wav-to-duration --read-entire-file={1} scp:{0}/wav.scp ark,t:{0}/reco2dur".format(input_dir, read_entire_file)) + durations = 
data_lib.ParseFileToDict(input_dir + "/reco2dur", value_processor = lambda x: float(x[0])) + foreground_snr_array = map(lambda x: float(x), foreground_snr_string.split(':')) + background_snr_array = map(lambda x: float(x), background_snr_string.split(':')) + overlap_snr_array = map(lambda x: float(x), overlap_snr_string.split(':')) + + GenerateReverberatedWavScpWithOverlappedSpeech( + wav_scp = wav_scp, + durations = durations, + output_dir = output_dir, + room_dict = room_dict, + pointsource_noise_list = pointsource_noise_list, + iso_noise_dict = iso_noise_dict, + foreground_snr_array = foreground_snr_array, + background_snr_array = background_snr_array, + speech_segments_list = speech_segments_list, + overlap_snr_array = overlap_snr_array, + num_replicas = num_replicas, include_original=include_original, prefix = prefix, + speech_rvb_probability = speech_rvb_probability, + shift_output = shift_output, + isotropic_noise_addition_probability = isotropic_noise_addition_probability, + pointsource_noise_addition_probability = pointsource_noise_addition_probability, + max_noises_per_minute = max_noises_per_minute, + overlapping_speech_addition_probability = overlapping_speech_addition_probability, + min_overlapping_segments_per_minute = min_overlapping_segments_per_minute, + max_overlapping_segments_per_minute = max_overlapping_segments_per_minute, + output_reverb_dir = output_reverb_dir, + output_additive_noise_dir = output_additive_noise_dir) + + data_lib.CopyDataDirFiles(input_dir, output_dir, num_replicas, include_original=include_original, prefix=prefix) + + if output_reverb_dir is not None: + data_lib.CopyDataDirFiles(input_dir, output_reverb_dir, num_replicas, include_original=include_original, prefix=prefix) + + if output_additive_noise_dir is not None: + data_lib.CopyDataDirFiles(input_dir, output_additive_noise_dir, num_replicas, include_original=include_original, prefix=prefix) + + +def Main(): + args = GetArgs() + random.seed(args.random_seed) + rir_list = 
def Main():
    """Entry point: parses arguments, loads the RIR/noise/speech-segment
    lists and creates the corrupted copy of the data directory."""
    args = GetArgs()
    random.seed(args.random_seed)

    rir_list = data_lib.ParseRirList(args.rir_set_para_array, args.rir_smoothing_weight, args.source_sampling_rate)
    print("Number of RIRs is {0}".format(len(rir_list)))

    pointsource_noise_list = []
    iso_noise_dict = {}
    if args.noise_set_para_array is not None:
        pointsource_noise_list, iso_noise_dict = data_lib.ParseNoiseList(args.noise_set_para_array, args.noise_smoothing_weight, args.source_sampling_rate)
        print("Number of point-source noises is {0}".format(len(pointsource_noise_list)))
        print("Number of isotropic noises is {0}".format(sum(len(iso_noise_dict[key]) for key in iso_noise_dict.keys())))
    room_dict = data_lib.MakeRoomDict(rir_list)

    # The option arrives as the string "true"/"false".
    include_original = (args.include_original_data == "true")

    speech_segments_list = ParseSpeechSegmentsList(args.speech_segments_set_para_array, args.overlapping_speech_smoothing_weight)

    CreateReverberatedCopy(input_dir = args.input_dir,
                           output_dir = args.output_dir,
                           room_dict = room_dict,
                           pointsource_noise_list = pointsource_noise_list,
                           iso_noise_dict = iso_noise_dict,
                           speech_segments_list = speech_segments_list,
                           foreground_snr_string = args.foreground_snr_string,
                           background_snr_string = args.background_snr_string,
                           overlap_snr_string = args.overlap_snr_string,
                           num_replicas = args.num_replicas,
                           include_original = include_original,
                           prefix = args.prefix,
                           speech_rvb_probability = args.speech_rvb_probability,
                           shift_output = args.shift_output,
                           isotropic_noise_addition_probability = args.isotropic_noise_addition_probability,
                           pointsource_noise_addition_probability = args.pointsource_noise_addition_probability,
                           max_noises_per_minute = args.max_noises_per_minute,
                           overlapping_speech_addition_probability = args.overlapping_speech_addition_probability,
                           min_overlapping_segments_per_minute = args.min_overlapping_segments_per_minute,
                           max_overlapping_segments_per_minute = args.max_overlapping_segments_per_minute,
                           output_reverb_dir = args.output_reverb_dir,
                           output_additive_noise_dir = args.output_additive_noise_dir)

if __name__ == "__main__":
    Main()
def GetArgs():
    """Parses the command-line options; returns the validated args namespace.

    BUG FIX: the --output-additive-noise-dir and --output-reverb-dir options
    previously referenced the undefined name `common_train_lib`; this module
    is imported as `common_lib` (import libs.common as common_lib), so the
    old code raised a NameError as soon as GetArgs() ran.
    """
    # we add required arguments as named arguments for readability
    parser = argparse.ArgumentParser(description="Reverberate the data directory with an option "
                                     "to add isotropic and point source noises. "
                                     "Usage: reverberate_data_dir.py [options...] "
                                     "E.g. reverberate_data_dir.py --rir-set-parameters rir_list "
                                     "--foreground-snrs 20:10:15:5:0 --background-snrs 20:10:15:5:0 "
                                     "--noise-list-file noise_list --speech-rvb-probability 1 --num-replications 2 "
                                     "--random-seed 1 data/train data/train_rvb",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--rir-set-parameters", type=str, action='append', required = True, dest = "rir_set_para_array",
                        help="Specifies the parameters of an RIR set. "
                        "Supports the specification of mixture_weight and rir_list_file_name. The mixture weight is optional. "
                        "The default mixture weight is the probability mass remaining after adding the mixture weights "
                        "of all the RIR lists, uniformly divided among the RIR lists without mixture weights. "
                        "E.g. --rir-set-parameters '0.3, rir_list' or 'rir_list' "
                        "the format of the RIR list file is "
                        "--rir-id --room-id "
                        "--receiver-position-id --source-position-id "
                        "--rt-60 --drr location "
                        "E.g. --rir-id 00001 --room-id 001 --receiver-position-id 001 --source-position-id 00001 "
                        "--rt60 0.58 --drr -4.885 data/impulses/Room001-00001.wav")
    parser.add_argument("--noise-set-parameters", type=str, action='append',
                        default = None, dest = "noise_set_para_array",
                        help="Specifies the parameters of an noise set. "
                        "Supports the specification of mixture_weight and noise_list_file_name. The mixture weight is optional. "
                        "The default mixture weight is the probability mass remaining after adding the mixture weights "
                        "of all the noise lists, uniformly divided among the noise lists without mixture weights. "
                        "E.g. --noise-set-parameters '0.3, noise_list' or 'noise_list' "
                        "the format of the noise list file is "
                        "--noise-id --noise-type "
                        "--bg-fg-type "
                        "--room-linkage "
                        "location "
                        "E.g. --noise-id 001 --noise-type isotropic --rir-id 00019 iso_noise.wav")
    # (stray trailing ';' removed here)
    parser.add_argument("--speech-segments-set-parameters", type=str, action='append',
                        default = None, dest = "speech_segments_set_para_array",
                        help="Specifies the speech segments for overlapped speech generation.\n"
                        "Format: [], wav_scp, segments_list\n")
    parser.add_argument("--num-replications", type=int, dest = "num_replicas", default = 1,
                        help="Number of replicate to generated for the data")
    parser.add_argument('--foreground-snrs', type=str, dest = "foreground_snr_string",
                        default = '20:10:0',
                        help='When foreground noises are being added the script will iterate through these SNRs.')
    parser.add_argument('--background-snrs', type=str, dest = "background_snr_string",
                        default = '20:10:0',
                        help='When background noises are being added the script will iterate through these SNRs.')
    parser.add_argument('--overlap-snrs', type=str, dest = "overlap_snr_string",
                        default = "20:10:0",
                        help='When overlapping speech segments are being added the script will iterate through these SNRs.')
    parser.add_argument('--prefix', type=str, default = None,
                        help='This prefix will modified for each reverberated copy, by adding additional affixes.')
    parser.add_argument("--speech-rvb-probability", type=float, default = 1.0,
                        help="Probability of reverberating a speech signal, e.g. 0 <= p <= 1")
    parser.add_argument("--pointsource-noise-addition-probability", type=float, default = 1.0,
                        help="Probability of adding point-source noises, e.g. 0 <= p <= 1")
    parser.add_argument("--isotropic-noise-addition-probability", type=float, default = 1.0,
                        help="Probability of adding isotropic noises, e.g. 0 <= p <= 1")
    parser.add_argument("--overlapping-speech-addition-probability", type=float, default = 1.0,
                        help="Probability of adding overlapping speech, e.g. 0 <= p <= 1")
    parser.add_argument("--rir-smoothing-weight", type=float, default = 0.3,
                        help="Smoothing weight for the RIR probabilties, e.g. 0 <= p <= 1. If p = 0, no smoothing will be done. "
                        "The RIR distribution will be mixed with a uniform distribution according to the smoothing weight")
    parser.add_argument("--noise-smoothing-weight", type=float, default = 0.3,
                        help="Smoothing weight for the noise probabilties, e.g. 0 <= p <= 1. If p = 0, no smoothing will be done. "
                        "The noise distribution will be mixed with a uniform distribution according to the smoothing weight")
    parser.add_argument("--overlapping-speech-smoothing-weight", type=float, default = 0.3,
                        help="The overlapping speech distribution will be mixed with a uniform distribution according to the smoothing weight")
    parser.add_argument("--max-noises-per-minute", type=int, default = 2,
                        help="This controls the maximum number of point-source noises that could be added to a recording according to its duration")
    parser.add_argument("--min-overlapping-segments-per-minute", type=int, default = 1,
                        help="This controls the minimum number of overlapping segments of speech that could be added to a recording per minute")
    parser.add_argument("--max-overlapping-segments-per-minute", type=int, default = 5,
                        help="This controls the maximum number of overlapping segments of speech that could be added to a recording per minute")
    parser.add_argument('--random-seed', type=int, default=0,
                        help='seed to be used in the randomization of impulses and noises')
    parser.add_argument("--shift-output", type=str,
                        help="If true, the reverberated waveform will be shifted by the amount of the peak position of the RIR",
                        choices=['true', 'false'], default = "true")
    parser.add_argument("--output-additive-noise-dir", type=str,
                        action = common_lib.NullstrToNoneAction, default = None,
                        help="Output directory corresponding to the additive noise part of the data corruption")
    parser.add_argument("--output-reverb-dir", type=str,
                        action = common_lib.NullstrToNoneAction, default = None,
                        help="Output directory corresponding to the reverberated signal part of the data corruption")

    parser.add_argument("input_dir",
                        help="Input data directory")
    parser.add_argument("output_dir",
                        help="Output data directory")

    print(' '.join(sys.argv))

    args = parser.parse_args()
    args = CheckArgs(args)

    return args
def CheckArgs(args):
    """Validates the parsed arguments, creates the output directories and
    normalises empty-string directory options to None.  Returns args
    (mutated in place); raises Exception on any out-of-range option."""
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Treat an empty-string directory option as "not given"; create the
    # directory when one was actually requested.
    if args.output_reverb_dir == "":
        args.output_reverb_dir = None
    if args.output_reverb_dir is not None and not os.path.exists(args.output_reverb_dir):
        os.makedirs(args.output_reverb_dir)

    if args.output_additive_noise_dir == "":
        args.output_additive_noise_dir = None
    if args.output_additive_noise_dir is not None and not os.path.exists(args.output_additive_noise_dir):
        os.makedirs(args.output_additive_noise_dir)

    ## Check arguments.
    if args.num_replicas > 1 and args.prefix is None:
        args.prefix = "rvb"
        warnings.warn("--prefix is set to 'rvb' as --num-replications is larger than 1.")

    if not args.num_replicas > 0:
        raise Exception("--num-replications cannot be non-positive")

    # All probability/weight options must lie in [0, 1].
    for value, name in [
            (args.speech_rvb_probability, "speech-rvb-probability"),
            (args.pointsource_noise_addition_probability, "pointsource-noise-addition-probability"),
            (args.isotropic_noise_addition_probability, "isotropic-noise-addition-probability"),
            (args.overlapping_speech_addition_probability, "overlapping-speech-addition-probability"),
            (args.rir_smoothing_weight, "rir-smoothing-weight"),
            (args.noise_smoothing_weight, "noise-smoothing-weight"),
            (args.overlapping_speech_smoothing_weight, "overlapping-speech-smoothing-weight")]:
        if value < 0 or value > 1:
            raise Exception("--" + name + " must be between 0 and 1")

    # Per-minute rate limits must be non-negative.
    for value, name in [
            (args.max_noises_per_minute, "max-noises-per-minute"),
            (args.min_overlapping_segments_per_minute, "min-overlapping-segments-per-minute"),
            (args.max_overlapping_segments_per_minute, "max-overlapping-segments-per-minute")]:
        if value < 0:
            raise Exception("--" + name + " cannot be negative")

    return args
def ParseSpeechSegmentsList(speech_segments_set_para_array, smoothing_weight):
    """Parses the --speech-segments-set-parameters options.

    Each option is "[probability,] wav_scp, segments_file".  Returns a list
    of segment objects (utt_id, duration, probability, wav_rxfilename) where
    wav_rxfilename is a sox pipe trimming the segment out of its recording.
    Raises Exception if a listed file does not exist.

    FIX: use argparse.Namespace for the per-set records instead of the
    `set = lambda: None` + setattr() hack, which shadowed the builtin `set`
    and was inconsistent with the Namespace objects used for the segments
    below.
    """
    set_list = []
    for set_para in speech_segments_set_para_array:
        segments_set = argparse.Namespace()
        segments_set.wav_scp = None
        segments_set.segments = None
        segments_set.probability = None
        parts = set_para.split(',')
        if len(parts) == 3:
            # Explicit mixture weight given.
            segments_set.probability = float(parts[0])
            segments_set.wav_scp = parts[1].strip()
            segments_set.segments = parts[2].strip()
        else:
            segments_set.wav_scp = parts[0].strip()
            segments_set.segments = parts[1].strip()
        if not os.path.isfile(segments_set.wav_scp):
            raise Exception(segments_set.wav_scp + " not found")
        if not os.path.isfile(segments_set.segments):
            raise Exception(segments_set.segments + " not found")
        set_list.append(segments_set)

    data_lib.SmoothProbabilityDistribution(set_list)

    segments_list = []
    for segments_set in set_list:
        current_segments_list = []

        # recording-id -> wav rxfilename (rest of the wav.scp line)
        wav_dict = {}
        for line in open(segments_set.wav_scp):
            parts = line.strip().split()
            wav_dict[parts[0]] = ' '.join(parts[1:])

        for line in open(segments_set.segments):
            parts = line.strip().split()
            current_segment = argparse.Namespace()
            current_segment.utt_id = parts[0]
            current_segment.probability = None

            start_time = float(parts[2])
            end_time = float(parts[3])

            current_segment.duration = (end_time - start_time)

            # Build a pipe that extracts just this segment with sox.
            wav_rxfilename = wav_dict[parts[1]]
            if wav_rxfilename.split()[-1] == '|':
                current_segment.wav_rxfilename = "{0} sox -t wav - -t wav - trim {1} {2} |".format(wav_rxfilename, start_time, end_time - start_time)
            else:
                current_segment.wav_rxfilename = "sox {0} -t wav - trim {1} {2} |".format(wav_rxfilename, start_time, end_time - start_time)

            current_segments_list.append(current_segment)

        segments_list += data_lib.SmoothProbabilityDistribution(
            current_segments_list, smoothing_weight, segments_set.probability)

    return segments_list
def AddOverlappingSpeech(room, # the room selected
                         speech_segments_list, # the speech list
                         overlapping_speech_addition_probability, # Probability of another speech waveform
                         snrs, # the SNR for adding the foreground speech
                         speech_dur, # duration of the recording
                         min_overlapping_speech_segments, # Minimum number of speech signals that can be added
                         max_overlapping_speech_segments, # Maximum number of speech signals that can be added
                         overlapping_speech_descriptor # descriptor to store the information of the overlapping speech
                         ):
    """Randomly selects speech segments, reverberates each with an RIR from
    the given room, and appends their pipes/start-times/SNRs/ids/durations
    to overlapping_speech_descriptor (mutated in place)."""
    if (len(speech_segments_list) > 0 and random.random() < overlapping_speech_addition_probability
            and max_overlapping_speech_segments >= 1):
        for k in range(random.randint(min_overlapping_speech_segments, max_overlapping_speech_segments)):
            # pick the overlapping speech signal and the RIR to
            # reverberate the overlapping speech signal
            speech_segment = data_lib.PickItemWithProbability(speech_segments_list)
            rir = data_lib.PickItemWithProbability(room.rir_list)

            speech_rvb_command = """wav-reverberate --impulse-response="{0}" --shift-output=true """.format(rir.rir_rspecifier)
            # FIX: clamp the latest start time to max(speech_dur - duration, 0)
            # so the overlapping segment does not begin after the end of the
            # recording; this matches the sibling implementation of this
            # function elsewhere in this change set.
            overlapping_speech_descriptor['start_times'].append(
                round(random.random()
                      * max(speech_dur - speech_segment.duration, 0), 2))
            overlapping_speech_descriptor['snrs'].append(snrs.next())
            overlapping_speech_descriptor['utt_ids'].append(speech_segment.utt_id)
            overlapping_speech_descriptor['durations'].append(speech_segment.duration)

            if len(speech_segment.wav_rxfilename.split()) == 1:
                # Plain file: let wav-reverberate read it directly.
                overlapping_speech_descriptor['speech_segments'].append("{1} {0} - |".format(speech_segment.wav_rxfilename, speech_rvb_command))
            else:
                # Already a pipe: append the reverberation stage to it.
                overlapping_speech_descriptor['speech_segments'].append("{0} {1} - - |".format(speech_segment.wav_rxfilename, speech_rvb_command))
def GenerateReverberationAndOverlappedSpeechOpts(
        room_dict, # the room dictionary, please refer to MakeRoomDict() for the format
        pointsource_noise_list, # the point source noise list
        iso_noise_dict, # the isotropic noise dictionary
        foreground_snrs, # the SNR for adding the foreground noises
        background_snrs, # the SNR for adding the background noises
        speech_segments_list,
        overlap_snrs,
        speech_rvb_probability, # Probability of reverberating a speech signal
        isotropic_noise_addition_probability, # Probability of adding isotropic noises
        pointsource_noise_addition_probability, # Probability of adding point-source noises
        overlapping_speech_addition_probability, # Probability of adding overlapping speech segments
        speech_dur, # duration of the recording
        max_noises_recording, # Maximum number of point-source noises that can be added
        min_overlapping_segments_recording, # Minimum number of overlapping segments that can be added
        max_overlapping_segments_recording # Maximum number of overlapping segments that can be added
        ):
    """Randomly decides whether to reverberate this recording and which
    isotropic/point-source noises and overlapping speech segments to mix in.

    Returns [impulse_response_opts, noise_addition_descriptor,
    overlapping_speech_descriptor] used to build the wav-reverberate command.

    BUG FIX: the length assertion below used the key 'utt_ids', which does
    not exist in noise_addition_descriptor (its keys are noise_io,
    start_times, snrs, noise_ids, durations) and therefore raised a
    KeyError; it now checks 'noise_ids'.  A stray debug print was also
    removed.
    """
    impulse_response_opts = ""

    noise_addition_descriptor = {'noise_io': [],
                                 'start_times': [],
                                 'snrs': [],
                                 'noise_ids': [],
                                 'durations': []
                                 }

    # Randomly select the room.  The room probability is a sum of the
    # probabilities of the RIRs recorded in the room.
    room = data_lib.PickItemWithProbability(room_dict)
    # Randomly select the RIR in the room.
    speech_rir = data_lib.PickItemWithProbability(room.rir_list)
    if random.random() < speech_rvb_probability:
        # pick the RIR to reverberate the speech
        impulse_response_opts = """--impulse-response="{0}" """.format(speech_rir.rir_rspecifier)

    rir_iso_noise_list = []
    if speech_rir.room_id in iso_noise_dict:
        rir_iso_noise_list = iso_noise_dict[speech_rir.room_id]
    # Add the corresponding isotropic noise associated with the selected RIR,
    # extended to the length of the speech waveform.
    if len(rir_iso_noise_list) > 0 and random.random() < isotropic_noise_addition_probability:
        isotropic_noise = data_lib.PickItemWithProbability(rir_iso_noise_list)
        # check if it is really a pipe
        if len(isotropic_noise.noise_rspecifier.split()) == 1:
            noise_addition_descriptor['noise_io'].append("wav-reverberate --duration={1} {0} - |".format(isotropic_noise.noise_rspecifier, speech_dur))
        else:
            noise_addition_descriptor['noise_io'].append("{0} wav-reverberate --duration={1} - - |".format(isotropic_noise.noise_rspecifier, speech_dur))
        noise_addition_descriptor['start_times'].append(0)
        noise_addition_descriptor['snrs'].append(background_snrs.next())
        noise_addition_descriptor['noise_ids'].append(isotropic_noise.noise_id)
        noise_addition_descriptor['durations'].append(speech_dur)

    data_lib.AddPointSourceNoise(room, # the room selected
                                 pointsource_noise_list, # the point source noise list
                                 pointsource_noise_addition_probability,
                                 foreground_snrs,
                                 background_snrs,
                                 speech_dur,
                                 max_noises_recording,
                                 noise_addition_descriptor)

    # All parallel lists in the descriptor must stay the same length.
    assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['start_times'])
    assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['snrs'])
    assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['noise_ids'])
    assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['durations'])

    overlapping_speech_descriptor = {'speech_segments': [],
                                     'start_times': [],
                                     'snrs': [],
                                     'utt_ids': [],
                                     'durations': []
                                     }

    AddOverlappingSpeech(room,
                         speech_segments_list, # speech segments list
                         overlapping_speech_addition_probability,
                         overlap_snrs,
                         speech_dur,
                         min_overlapping_segments_recording,
                         max_overlapping_segments_recording,
                         overlapping_speech_descriptor)

    return [impulse_response_opts, noise_addition_descriptor,
            overlapping_speech_descriptor]
wav-reverberate will be like: +# wav-reverberate --duration=t --impulse-response=rir.wav +# --additive-signals='noise1.wav,noise2.wav' --snrs='snr1,snr2' --start-times='s1,s2' input.wav output.wav +def GenerateReverberatedWavScpWithOverlappedSpeech( + wav_scp, # a dictionary whose values are the Kaldi-IO strings of the speech recordings + durations, # a dictionary whose values are the duration (in sec) of the speech recordings + output_dir, # output directory to write the corrupted wav.scp + room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the point source noise list + iso_noise_dict, # the isotropic noise dictionary + foreground_snr_array, # the SNR for adding the foreground noises + background_snr_array, # the SNR for adding the background noises + speech_segments_list, # list of speech segments to create overlapped speech + overlap_snr_array, # the SNR for adding overlapping speech + num_replicas, # Number of replicate to generated for the data + prefix, # prefix for the id of the corrupted utterances + speech_rvb_probability, # Probability of reverberating a speech signal + shift_output, # option whether to shift the output waveform + isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability, # Probability of adding point-source noises + max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration + overlapping_speech_addition_probability, + min_overlapping_segments_per_minute, + max_overlapping_segments_per_minute, + output_reverb_dir = None, + output_additive_noise_dir = None, + ): + foreground_snrs = data_lib.list_cyclic_iterator(foreground_snr_array) + background_snrs = data_lib.list_cyclic_iterator(background_snr_array) + overlap_snrs = data_lib.list_cyclic_iterator(overlap_snr_array) + + corrupted_wav_scp = {} + reverb_wav_scp = {} + additive_noise_wav_scp = {} + 
overlapping_segments_info = {} + + keys = wav_scp.keys() + keys.sort() + for i in range(1, num_replicas+1): + for recording_id in keys: + wav_original_pipe = wav_scp[recording_id] + # check if it is really a pipe + if len(wav_original_pipe.split()) == 1: + wav_original_pipe = "cat {0} |".format(wav_original_pipe) + speech_dur = durations[recording_id] + max_noises_recording = math.floor(max_noises_per_minute * speech_dur / 60) + min_overlapping_segments_recording = max(math.floor(min_overlapping_segments_per_minute * speech_dur / 60), 1) + max_overlapping_segments_recording = math.floor(max_overlapping_segments_per_minute * speech_dur / 60) + + [impulse_response_opts, noise_addition_descriptor, + overlapping_speech_descriptor] = GenerateReverberationAndOverlappedSpeechOpts( + room_dict = room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list = pointsource_noise_list, # the point source noise list + iso_noise_dict = iso_noise_dict, # the isotropic noise dictionary + foreground_snrs = foreground_snrs, # the SNR for adding the foreground noises + background_snrs = background_snrs, # the SNR for adding the background noises + speech_segments_list = speech_segments_list, # Speech segments for creating overlapped speech + overlap_snrs = overlap_snrs, # the SNR for adding overlapping speech + speech_rvb_probability = speech_rvb_probability, # Probability of reverberating a speech signal + isotropic_noise_addition_probability = isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability = pointsource_noise_addition_probability, # Probability of adding point-source noises + overlapping_speech_addition_probability = overlapping_speech_addition_probability, + speech_dur = speech_dur, # duration of the recording + max_noises_recording = max_noises_recording, # Maximum number of point-source noises that can be added + min_overlapping_segments_recording = 
min_overlapping_segments_recording, + max_overlapping_segments_recording = max_overlapping_segments_recording + ) + + additive_noise_opts = "" + + if (len(noise_addition_descriptor['noise_io']) > 0 or + len(overlapping_speech_descriptor['speech_segments']) > 0): + additive_noise_opts += ("--additive-signals='{0}' " + .format(',' + .join(noise_addition_descriptor['noise_io'] + + overlapping_speech_descriptor['speech_segments'])) + ) + additive_noise_opts += ("--start-times='{0}' " + .format(',' + .join(map(lambda x:str(x), noise_addition_descriptor['start_times'] + + overlapping_speech_descriptor['start_times']))) + ) + additive_noise_opts += ("--snrs='{0}' " + .format(',' + .join(map(lambda x:str(x), noise_addition_descriptor['snrs'] + + overlapping_speech_descriptor['snrs']))) + ) + + reverberate_opts = impulse_response_opts + additive_noise_opts + + new_recording_id = data_lib.GetNewId(recording_id, prefix, i) + + if reverberate_opts == "": + wav_corrupted_pipe = "{0}".format(wav_original_pipe) + else: + wav_corrupted_pipe = "{0} wav-reverberate --shift-output={1} {2} - - |".format(wav_original_pipe, shift_output, reverberate_opts) + + corrupted_wav_scp[new_recording_id] = wav_corrupted_pipe + + if output_reverb_dir is not None: + if impulse_response_opts == "": + wav_reverb_pipe = "{0}".format(wav_original_pipe) + else: + wav_reverb_pipe = "{0} wav-reverberate --shift-output={1} --reverb-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts) + reverb_wav_scp[new_recording_id] = wav_reverb_pipe + + if output_additive_noise_dir is not None: + if additive_noise_opts != "": + wav_additive_noise_pipe = "{0} wav-reverberate --shift-output={1} --additive-noise-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts) + additive_noise_wav_scp[new_recording_id] = wav_additive_noise_pipe + + if len(overlapping_speech_descriptor['speech_segments']) > 0: + overlapping_segments_info[new_recording_id] 
= [ + ':'.join(x) + for x in zip(overlapping_speech_descriptor['utt_ids'], + [ str(x) for x in overlapping_speech_descriptor['start_times'] ], + [ str(x) for x in overlapping_speech_descriptor['durations'] ]) + ] + + data_lib.WriteDictToFile(corrupted_wav_scp, output_dir + "/wav.scp") + + # Write for each new recording, the id, start time and durations + # of the overlapping segments + data_lib.WriteDictToFile(overlapping_segments_info, output_dir + "/overlapped_segments_info.txt") + + if output_reverb_dir is not None: + data_lib.WriteDictToFile(reverb_wav_scp, output_reverb_dir + "/wav.scp") + + if output_additive_noise_dir is not None: + data_lib.WriteDictToFile(additive_noise_wav_scp, output_additive_noise_dir + "/wav.scp") + +# This function creates multiple copies of the necessary files, e.g. utt2spk, wav.scp ... +def CreateReverberatedCopy(input_dir, + output_dir, + room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the point source noise list + iso_noise_dict, # the isotropic noise dictionary + speech_segments_list, + foreground_snr_string, # the SNR for adding the foreground noises + background_snr_string, # the SNR for adding the background noises + overlap_snr_string, # the SNR for overlapping speech + num_replicas, # Number of replicate to generated for the data + prefix, # prefix for the id of the corrupted utterances + speech_rvb_probability, # Probability of reverberating a speech signal + shift_output, # option whether to shift the output waveform + isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability, # Probability of adding point-source noises + max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration + overlapping_speech_addition_probability, + min_overlapping_segments_per_minute, + max_overlapping_segments_per_minute, + output_reverb_dir = None, + 
output_additive_noise_dir = None + ): + + wav_scp = data_lib.ParseFileToDict(input_dir + "/wav.scp", value_processor = lambda x: " ".join(x)) + if not os.path.isfile(input_dir + "/reco2dur"): + print("Getting the duration of the recordings..."); + read_entire_file="false" + for value in wav_scp.values(): + # we will add more checks for sox commands which modify the header as we come across these cases in our data + if "sox" in value and "speed" in value: + read_entire_file="true" + break + data_lib.RunKaldiCommand("wav-to-duration --read-entire-file={1} scp:{0}/wav.scp ark,t:{0}/reco2dur".format(input_dir, read_entire_file)) + durations = data_lib.ParseFileToDict(input_dir + "/reco2dur", value_processor = lambda x: float(x[0])) + foreground_snr_array = map(lambda x: float(x), foreground_snr_string.split(':')) + background_snr_array = map(lambda x: float(x), background_snr_string.split(':')) + overlap_snr_array = map(lambda x: float(x), overlap_snr_string.split(':')) + + GenerateReverberatedWavScpWithOverlappedSpeech( + wav_scp = wav_scp, + durations = durations, + output_dir = output_dir, + room_dict = room_dict, + pointsource_noise_list = pointsource_noise_list, + iso_noise_dict = iso_noise_dict, + foreground_snr_array = foreground_snr_array, + background_snr_array = background_snr_array, + speech_segments_list = speech_segments_list, + overlap_snr_array = overlap_snr_array, + num_replicas = num_replicas, prefix = prefix, + speech_rvb_probability = speech_rvb_probability, + shift_output = shift_output, + isotropic_noise_addition_probability = isotropic_noise_addition_probability, + pointsource_noise_addition_probability = pointsource_noise_addition_probability, + max_noises_per_minute = max_noises_per_minute, + overlapping_speech_addition_probability = overlapping_speech_addition_probability, + min_overlapping_segments_per_minute = min_overlapping_segments_per_minute, + max_overlapping_segments_per_minute = max_overlapping_segments_per_minute, + output_reverb_dir 
= output_reverb_dir, + output_additive_noise_dir = output_additive_noise_dir) + + data_lib.CopyDataDirFiles(input_dir, output_dir, num_replicas, prefix) + data_lib.AddPrefixToFields(input_dir + "/reco2dur", output_dir + "/reco2dur", num_replicas, prefix, field = [0]) + + if output_reverb_dir is not None: + data_lib.CopyDataDirFiles(input_dir, output_reverb_dir, num_replicas, prefix) + data_lib.AddPrefixToFields(input_dir + "/reco2dur", output_reverb_dir + "/reco2dur", num_replicas, prefix, field = [0]) + + if output_additive_noise_dir is not None: + data_lib.CopyDataDirFiles(input_dir, output_additive_noise_dir, num_replicas, prefix) + data_lib.AddPrefixToFields(input_dir + "/reco2dur", output_additive_noise_dir + "/reco2dur", num_replicas, prefix, field = [0]) + + +def Main(): + args = GetArgs() + random.seed(args.random_seed) + rir_list = data_lib.ParseRirList(args.rir_set_para_array, args.rir_smoothing_weight) + print("Number of RIRs is {0}".format(len(rir_list))) + pointsource_noise_list = [] + iso_noise_dict = {} + if args.noise_set_para_array is not None: + pointsource_noise_list, iso_noise_dict = data_lib.ParseNoiseList(args.noise_set_para_array, args.noise_smoothing_weight) + print("Number of point-source noises is {0}".format(len(pointsource_noise_list))) + print("Number of isotropic noises is {0}".format(sum(len(iso_noise_dict[key]) for key in iso_noise_dict.keys()))) + room_dict = data_lib.MakeRoomDict(rir_list) + + speech_segments_list = ParseSpeechSegmentsList(args.speech_segments_set_para_array, args.overlapping_speech_smoothing_weight) + + CreateReverberatedCopy(input_dir = args.input_dir, + output_dir = args.output_dir, + room_dict = room_dict, + pointsource_noise_list = pointsource_noise_list, + iso_noise_dict = iso_noise_dict, + speech_segments_list = speech_segments_list, + foreground_snr_string = args.foreground_snr_string, + background_snr_string = args.background_snr_string, + overlap_snr_string = args.overlap_snr_string, + num_replicas = 
args.num_replicas, + prefix = args.prefix, + speech_rvb_probability = args.speech_rvb_probability, + shift_output = args.shift_output, + isotropic_noise_addition_probability = args.isotropic_noise_addition_probability, + pointsource_noise_addition_probability = args.pointsource_noise_addition_probability, + max_noises_per_minute = args.max_noises_per_minute, + overlapping_speech_addition_probability = args.overlapping_speech_addition_probability, + min_overlapping_segments_per_minute = args.min_overlapping_segments_per_minute, + max_overlapping_segments_per_minute = args.max_overlapping_segments_per_minute, + output_reverb_dir = args.output_reverb_dir, + output_additive_noise_dir = args.output_additive_noise_dir) + +if __name__ == "__main__": + Main() + + diff --git a/egs/wsj/s5/steps/data/reverberate_data_dir.py b/egs/wsj/s5/steps/data/reverberate_data_dir.py index 0083efa4939..c9a4d918c91 100755 --- a/egs/wsj/s5/steps/data/reverberate_data_dir.py +++ b/egs/wsj/s5/steps/data/reverberate_data_dir.py @@ -5,9 +5,11 @@ # we're using python 3.x style print but want it to work in python 2.x, from __future__ import print_function -import argparse, shlex, glob, math, os, random, sys, warnings, copy, imp, ast +import argparse, glob, math, os, random, sys, warnings, copy, imp, ast -data_lib = imp.load_source('dml', 'steps/data/data_dir_manipulation_lib.py') +import data_dir_manipulation_lib as data_lib +sys.path.insert(0, 'steps') +import libs.common as common_lib def GetArgs(): # we add required arguments as named arguments for readability @@ -20,7 +22,7 @@ def GetArgs(): "--random-seed 1 data/train data/train_rvb", formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("--rir-set-parameters", type=str, action='append', required = True, dest = "rir_set_para_array", + parser.add_argument("--rir-set-parameters", type=str, action='append', required = True, dest = "rir_set_para_array", help="Specifies the parameters of an RIR set. 
" "Supports the specification of mixture_weight and rir_list_file_name. The mixture weight is optional. " "The default mixture weight is the probability mass remaining after adding the mixture weights " @@ -71,6 +73,13 @@ def GetArgs(): "the RIRs/noises will be resampled to the rate of the source data.") parser.add_argument("--include-original-data", type=str, help="If true, the output data includes one copy of the original data", choices=['true', 'false'], default = "false") + parser.add_argument("--output-additive-noise-dir", type=str, + action = common_lib.NullstrToNoneAction, default = None, + help="Output directory corresponding to the additive noise part of the data corruption") + parser.add_argument("--output-reverb-dir", type=str, + action = common_lib.NullstrToNoneAction, default = None, + help="Output directory corresponding to the reverberated signal part of the data corruption") + parser.add_argument("input_dir", help="Input data directory") parser.add_argument("output_dir", @@ -87,12 +96,27 @@ def CheckArgs(args): if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) - ## Check arguments + ## Check arguments. + if args.prefix is None: if args.num_replicas > 1 or args.include_original_data == "true": args.prefix = "rvb" warnings.warn("--prefix is set to 'rvb' as more than one copy of data is generated") + if args.output_reverb_dir is not None: + if not os.path.exists(args.output_reverb_dir): + os.makedirs(args.output_reverb_dir) + + if args.output_additive_noise_dir is not None: + if not os.path.exists(args.output_additive_noise_dir): + os.makedirs(args.output_additive_noise_dir) + + ## Check arguments. 
+ + if args.num_replicas > 1 and args.prefix is None: + args.prefix = "rvb" + warnings.warn("--prefix is set to 'rvb' as --num-replications is larger than 1.") + if not args.num_replicas > 0: raise Exception("--num-replications cannot be non-positive") @@ -104,7 +128,7 @@ def CheckArgs(args): if args.isotropic_noise_addition_probability < 0 or args.isotropic_noise_addition_probability > 1: raise Exception("--isotropic-noise-addition-probability must be between 0 and 1") - + if args.rir_smoothing_weight < 0 or args.rir_smoothing_weight > 1: raise Exception("--rir-smoothing-weight must be between 0 and 1") @@ -113,208 +137,20 @@ def CheckArgs(args): if args.max_noises_per_minute < 0: raise Exception("--max-noises-per-minute cannot be negative") - + if args.source_sampling_rate is not None and args.source_sampling_rate <= 0: raise Exception("--source-sampling-rate cannot be non-positive") return args -class list_cyclic_iterator: - def __init__(self, list): - self.list_index = 0 - self.list = list - random.shuffle(self.list) - - def next(self): - item = self.list[self.list_index] - self.list_index = (self.list_index + 1) % len(self.list) - return item - - -# This functions picks an item from the collection according to the associated probability distribution. -# The probability estimate of each item in the collection is stored in the "probability" field of -# the particular item. 
x : a collection (list or dictionary) where the values contain a field called probability -def PickItemWithProbability(x): - if isinstance(x, dict): - plist = list(set(x.values())) - else: - plist = x - total_p = sum(item.probability for item in plist) - p = random.uniform(0, total_p) - accumulate_p = 0 - for item in plist: - if accumulate_p + item.probability >= p: - return item - accumulate_p += item.probability - assert False, "Shouldn't get here as the accumulated probability should always equal to 1" - - -# This function parses a file and pack the data into a dictionary -# It is useful for parsing file like wav.scp, utt2spk, text...etc -def ParseFileToDict(file, assert2fields = False, value_processor = None): - if value_processor is None: - value_processor = lambda x: x[0] - - dict = {} - for line in open(file, 'r'): - parts = line.split() - if assert2fields: - assert(len(parts) == 2) - - dict[parts[0]] = value_processor(parts[1:]) - return dict - -# This function creates a file and write the content of a dictionary into it -def WriteDictToFile(dict, file_name): - file = open(file_name, 'w') - keys = dict.keys() - keys.sort() - for key in keys: - value = dict[key] - if type(value) in [list, tuple] : - if type(value) is tuple: - value = list(value) - value.sort() - value = ' '.join(str(value)) - file.write('{0} {1}\n'.format(key, value)) - file.close() - - -# This function creates the utt2uniq file from the utterance id in utt2spk file -def CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_original, prefix): - corrupted_utt2uniq = {} - # Parse the utt2spk to get the utterance id - utt2spk = ParseFileToDict(input_dir + "/utt2spk", value_processor = lambda x: " ".join(x)) - keys = utt2spk.keys() - keys.sort() - if include_original: - start_index = 0 - else: - start_index = 1 - - for i in range(start_index, num_replicas+1): - for utt_id in keys: - new_utt_id = GetNewId(utt_id, prefix, i) - corrupted_utt2uniq[new_utt_id] = utt_id - - 
WriteDictToFile(corrupted_utt2uniq, output_dir + "/utt2uniq") - - -def AddPointSourceNoise(noise_addition_descriptor, # descriptor to store the information of the noise added - room, # the room selected - pointsource_noise_list, # the point source noise list - pointsource_noise_addition_probability, # Probability of adding point-source noises - foreground_snrs, # the SNR for adding the foreground noises - background_snrs, # the SNR for adding the background noises - speech_dur, # duration of the recording - max_noises_recording # Maximum number of point-source noises that can be added - ): - if len(pointsource_noise_list) > 0 and random.random() < pointsource_noise_addition_probability and max_noises_recording >= 1: - for k in range(random.randint(1, max_noises_recording)): - # pick the RIR to reverberate the point-source noise - noise = PickItemWithProbability(pointsource_noise_list) - noise_rir = PickItemWithProbability(room.rir_list) - # If it is a background noise, the noise will be extended and be added to the whole speech - # if it is a foreground noise, the noise will not extended and be added at a random time of the speech - if noise.bg_fg_type == "background": - noise_rvb_command = """wav-reverberate --impulse-response="{0}" --duration={1}""".format(noise_rir.rir_rspecifier, speech_dur) - noise_addition_descriptor['start_times'].append(0) - noise_addition_descriptor['snrs'].append(background_snrs.next()) - else: - noise_rvb_command = """wav-reverberate --impulse-response="{0}" """.format(noise_rir.rir_rspecifier) - noise_addition_descriptor['start_times'].append(round(random.random() * speech_dur, 2)) - noise_addition_descriptor['snrs'].append(foreground_snrs.next()) - - # check if the rspecifier is a pipe or not - if len(noise.noise_rspecifier.split()) == 1: - noise_addition_descriptor['noise_io'].append("{1} {0} - |".format(noise.noise_rspecifier, noise_rvb_command)) - else: - noise_addition_descriptor['noise_io'].append("{0} {1} - - 
|".format(noise.noise_rspecifier, noise_rvb_command)) - - return noise_addition_descriptor - - -# This function randomly decides whether to reverberate, and sample a RIR if it does -# It also decides whether to add the appropriate noises -# This function return the string of options to the binary wav-reverberate -def GenerateReverberationOpts(room_dict, # the room dictionary, please refer to MakeRoomDict() for the format - pointsource_noise_list, # the point source noise list - iso_noise_dict, # the isotropic noise dictionary - foreground_snrs, # the SNR for adding the foreground noises - background_snrs, # the SNR for adding the background noises - speech_rvb_probability, # Probability of reverberating a speech signal - isotropic_noise_addition_probability, # Probability of adding isotropic noises - pointsource_noise_addition_probability, # Probability of adding point-source noises - speech_dur, # duration of the recording - max_noises_recording # Maximum number of point-source noises that can be added - ): - reverberate_opts = "" - noise_addition_descriptor = {'noise_io': [], - 'start_times': [], - 'snrs': []} - # Randomly select the room - # Here the room probability is a sum of the probabilities of the RIRs recorded in the room. 
- room = PickItemWithProbability(room_dict) - # Randomly select the RIR in the room - speech_rir = PickItemWithProbability(room.rir_list) - if random.random() < speech_rvb_probability: - # pick the RIR to reverberate the speech - reverberate_opts += """--impulse-response="{0}" """.format(speech_rir.rir_rspecifier) - - rir_iso_noise_list = [] - if speech_rir.room_id in iso_noise_dict: - rir_iso_noise_list = iso_noise_dict[speech_rir.room_id] - # Add the corresponding isotropic noise associated with the selected RIR - if len(rir_iso_noise_list) > 0 and random.random() < isotropic_noise_addition_probability: - isotropic_noise = PickItemWithProbability(rir_iso_noise_list) - # extend the isotropic noise to the length of the speech waveform - # check if the rspecifier is a pipe or not - if len(isotropic_noise.noise_rspecifier.split()) == 1: - noise_addition_descriptor['noise_io'].append("wav-reverberate --duration={1} {0} - |".format(isotropic_noise.noise_rspecifier, speech_dur)) - else: - noise_addition_descriptor['noise_io'].append("{0} wav-reverberate --duration={1} - - |".format(isotropic_noise.noise_rspecifier, speech_dur)) - noise_addition_descriptor['start_times'].append(0) - noise_addition_descriptor['snrs'].append(background_snrs.next()) - - noise_addition_descriptor = AddPointSourceNoise(noise_addition_descriptor, # descriptor to store the information of the noise added - room, # the room selected - pointsource_noise_list, # the point source noise list - pointsource_noise_addition_probability, # Probability of adding point-source noises - foreground_snrs, # the SNR for adding the foreground noises - background_snrs, # the SNR for adding the background noises - speech_dur, # duration of the recording - max_noises_recording # Maximum number of point-source noises that can be added - ) - - assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['start_times']) - assert len(noise_addition_descriptor['noise_io']) == 
len(noise_addition_descriptor['snrs']) - if len(noise_addition_descriptor['noise_io']) > 0: - reverberate_opts += "--additive-signals='{0}' ".format(','.join(noise_addition_descriptor['noise_io'])) - reverberate_opts += "--start-times='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['start_times']))) - reverberate_opts += "--snrs='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['snrs']))) - - return reverberate_opts - -# This function generates a new id from the input id -# This is needed when we have to create multiple copies of the original data -# E.g. GetNewId("swb0035", prefix="rvb", copy=1) returns a string "rvb1_swb0035" -def GetNewId(id, prefix=None, copy=0): - if prefix is not None: - new_id = prefix + str(copy) + "_" + id - else: - new_id = id - - return new_id - - # This is the main function to generate pipeline command for the corruption # The generic command of wav-reverberate will be like: -# wav-reverberate --duration=t --impulse-response=rir.wav +# wav-reverberate --duration=t --impulse-response=rir.wav # --additive-signals='noise1.wav,noise2.wav' --snrs='snr1,snr2' --start-times='s1,s2' input.wav output.wav def GenerateReverberatedWavScp(wav_scp, # a dictionary whose values are the Kaldi-IO strings of the speech recordings durations, # a dictionary whose values are the duration (in sec) of the speech recordings - output_dir, # output directory to write the corrupted wav.scp + output_dir, # output directory to write the corrupted wav.scp room_dict, # the room dictionary, please refer to MakeRoomDict() for the format pointsource_noise_list, # the point source noise list iso_noise_dict, # the isotropic noise dictionary @@ -327,13 +163,20 @@ def GenerateReverberatedWavScp(wav_scp, # a dictionary whose values are the Kal shift_output, # option whether to shift the output waveform isotropic_noise_addition_probability, # Probability of adding isotropic noises pointsource_noise_addition_probability, # Probability 
of adding point-source noises - max_noises_per_minute # maximum number of point-source noises that can be added to a recording according to its duration + max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration + output_reverb_dir = None, + output_additive_noise_dir = None ): - foreground_snrs = list_cyclic_iterator(foreground_snr_array) - background_snrs = list_cyclic_iterator(background_snr_array) + foreground_snrs = data_lib.list_cyclic_iterator(foreground_snr_array) + background_snrs = data_lib.list_cyclic_iterator(background_snr_array) corrupted_wav_scp = {} + reverb_wav_scp = {} + additive_noise_wav_scp = {} keys = wav_scp.keys() keys.sort() + + additive_signals_info = {} + if include_original: start_index = 0 else: @@ -346,51 +189,71 @@ def GenerateReverberatedWavScp(wav_scp, # a dictionary whose values are the Kal if len(wav_original_pipe.split()) == 1: wav_original_pipe = "cat {0} |".format(wav_original_pipe) speech_dur = durations[recording_id] - max_noises_recording = math.floor(max_noises_per_minute * speech_dur / 60) - - reverberate_opts = GenerateReverberationOpts(room_dict, # the room dictionary, please refer to MakeRoomDict() for the format - pointsource_noise_list, # the point source noise list - iso_noise_dict, # the isotropic noise dictionary - foreground_snrs, # the SNR for adding the foreground noises - background_snrs, # the SNR for adding the background noises - speech_rvb_probability, # Probability of reverberating a speech signal - isotropic_noise_addition_probability, # Probability of adding isotropic noises - pointsource_noise_addition_probability, # Probability of adding point-source noises - speech_dur, # duration of the recording - max_noises_recording # Maximum number of point-source noises that can be added - ) + max_noises_recording = math.ceil(max_noises_per_minute * speech_dur / 60) + + [impulse_response_opts, noise_addition_descriptor] = 
data_lib.GenerateReverberationOpts(room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the point source noise list + iso_noise_dict, # the isotropic noise dictionary + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + speech_rvb_probability, # Probability of reverberating a speech signal + isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability, # Probability of adding point-source noises + speech_dur, # duration of the recording + max_noises_recording # Maximum number of point-source noises that can be added + ) + additive_noise_opts = "" + + if len(noise_addition_descriptor['noise_io']) > 0: + additive_noise_opts += "--additive-signals='{0}' ".format(','.join(noise_addition_descriptor['noise_io'])) + additive_noise_opts += "--start-times='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['start_times']))) + additive_noise_opts += "--snrs='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['snrs']))) + + reverberate_opts = impulse_response_opts + additive_noise_opts + + new_recording_id = data_lib.GetNewId(recording_id, prefix, i) # prefix using index 0 is reserved for original data e.g. 
rvb0_swb0035 corresponds to the swb0035 recording in original data if reverberate_opts == "" or i == 0: - wav_corrupted_pipe = "{0}".format(wav_original_pipe) + wav_corrupted_pipe = "{0}".format(wav_original_pipe) else: wav_corrupted_pipe = "{0} wav-reverberate --shift-output={1} {2} - - |".format(wav_original_pipe, shift_output, reverberate_opts) - new_recording_id = GetNewId(recording_id, prefix, i) corrupted_wav_scp[new_recording_id] = wav_corrupted_pipe - WriteDictToFile(corrupted_wav_scp, output_dir + "/wav.scp") + if output_reverb_dir is not None: + if impulse_response_opts == "": + wav_reverb_pipe = "{0}".format(wav_original_pipe) + else: + wav_reverb_pipe = "{0} wav-reverberate --shift-output={1} --reverb-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts) + reverb_wav_scp[new_recording_id] = wav_reverb_pipe + if output_additive_noise_dir is not None: + if additive_noise_opts != "": + wav_additive_noise_pipe = "{0} wav-reverberate --shift-output={1} --additive-noise-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts) + additive_noise_wav_scp[new_recording_id] = wav_additive_noise_pipe -# This function replicate the entries in files like segments, utt2spk, text -def AddPrefixToFields(input_file, output_file, num_replicas, include_original, prefix, field = [0]): - list = map(lambda x: x.strip(), open(input_file)) - f = open(output_file, "w") - if include_original: - start_index = 0 - else: - start_index = 1 - - for i in range(start_index, num_replicas+1): - for line in list: - if len(line) > 0 and line[0] != ';': - split1 = line.split() - for j in field: - split1[j] = GetNewId(split1[j], prefix, i) - print(" ".join(split1), file=f) - else: - print(line, file=f) - f.close() + if additive_noise_opts != "": + additive_signals_info[new_recording_id] = [ + ':'.join(x) + for x in zip(noise_addition_descriptor['noise_ids'], + [ str(x) for x in noise_addition_descriptor['start_times'] 
], + [ str(x) for x in noise_addition_descriptor['durations'] ]) + ] + + # Write for each new recording, the id, start time and durations + # of the signals. Duration is -1 for the foreground noise and needs to + # be extracted separately if required by determining the durations + # using the wav file + data_lib.WriteDictToFile(additive_signals_info, output_dir + "/additive_signals_info.txt") + + data_lib.WriteDictToFile(corrupted_wav_scp, output_dir + "/wav.scp") + + if output_reverb_dir is not None: + data_lib.WriteDictToFile(reverb_wav_scp, output_reverb_dir + "/wav.scp") + + if output_additive_noise_dir is not None: + data_lib.WriteDictToFile(additive_noise_wav_scp, output_additive_noise_dir + "/wav.scp") # This function creates multiple copies of the necessary files, e.g. utt2spk, wav.scp ... @@ -408,10 +271,12 @@ def CreateReverberatedCopy(input_dir, shift_output, # option whether to shift the output waveform isotropic_noise_addition_probability, # Probability of adding isotropic noises pointsource_noise_addition_probability, # Probability of adding point-source noises - max_noises_per_minute # maximum number of point-source noises that can be added to a recording according to its duration + max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration + output_reverb_dir = None, + output_additive_noise_dir = None ): - - wav_scp = ParseFileToDict(input_dir + "/wav.scp", value_processor = lambda x: " ".join(x)) + + wav_scp = data_lib.ParseFileToDict(input_dir + "/wav.scp", value_processor = lambda x: " ".join(x)) if not os.path.isfile(input_dir + "/reco2dur"): print("Getting the duration of the recordings..."); read_entire_file="false" @@ -421,225 +286,38 @@ def CreateReverberatedCopy(input_dir, read_entire_file="true" break data_lib.RunKaldiCommand("wav-to-duration --read-entire-file={1} scp:{0}/wav.scp ark,t:{0}/reco2dur".format(input_dir, read_entire_file)) - durations = ParseFileToDict(input_dir + 
"/reco2dur", value_processor = lambda x: float(x[0])) + durations = data_lib.ParseFileToDict(input_dir + "/reco2dur", value_processor = lambda x: float(x[0])) foreground_snr_array = map(lambda x: float(x), foreground_snr_string.split(':')) background_snr_array = map(lambda x: float(x), background_snr_string.split(':')) GenerateReverberatedWavScp(wav_scp, durations, output_dir, room_dict, pointsource_noise_list, iso_noise_dict, - foreground_snr_array, background_snr_array, num_replicas, include_original, prefix, - speech_rvb_probability, shift_output, isotropic_noise_addition_probability, - pointsource_noise_addition_probability, max_noises_per_minute) + foreground_snr_array, background_snr_array, num_replicas, include_original, prefix, + speech_rvb_probability, shift_output, isotropic_noise_addition_probability, + pointsource_noise_addition_probability, max_noises_per_minute, + output_reverb_dir = output_reverb_dir, + output_additive_noise_dir = output_additive_noise_dir) - AddPrefixToFields(input_dir + "/utt2spk", output_dir + "/utt2spk", num_replicas, include_original, prefix, field = [0,1]) - data_lib.RunKaldiCommand("utils/utt2spk_to_spk2utt.pl <{output_dir}/utt2spk >{output_dir}/spk2utt" - .format(output_dir = output_dir)) + data_lib.CopyDataDirFiles(input_dir, output_dir, num_replicas, include_original, prefix) - if os.path.isfile(input_dir + "/utt2uniq"): - AddPrefixToFields(input_dir + "/utt2uniq", output_dir + "/utt2uniq", num_replicas, include_original, prefix, field =[0]) - else: - # Create the utt2uniq file - CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_original, prefix) - - if os.path.isfile(input_dir + "/text"): - AddPrefixToFields(input_dir + "/text", output_dir + "/text", num_replicas, include_original, prefix, field =[0]) - if os.path.isfile(input_dir + "/segments"): - AddPrefixToFields(input_dir + "/segments", output_dir + "/segments", num_replicas, include_original, prefix, field = [0,1]) - if os.path.isfile(input_dir + 
"/reco2file_and_channel"): - AddPrefixToFields(input_dir + "/reco2file_and_channel", output_dir + "/reco2file_and_channel", num_replicas, include_original, prefix, field = [0,1]) - - data_lib.RunKaldiCommand("utils/validate_data_dir.sh --no-feats {output_dir}" - .format(output_dir = output_dir)) - - -# This function smooths the probability distribution in the list -def SmoothProbabilityDistribution(list, smoothing_weight=0.0, target_sum=1.0): - if len(list) > 0: - num_unspecified = 0 - accumulated_prob = 0 - for item in list: - if item.probability is None: - num_unspecified += 1 - else: - accumulated_prob += item.probability - - # Compute the probability for the items without specifying their probability - uniform_probability = 0 - if num_unspecified > 0 and accumulated_prob < 1: - uniform_probability = (1 - accumulated_prob) / float(num_unspecified) - elif num_unspecified > 0 and accumulate_prob >= 1: - warnings.warn("The sum of probabilities specified by user is larger than or equal to 1. " - "The items without probabilities specified will be given zero to their probabilities.") - - for item in list: - if item.probability is None: - item.probability = uniform_probability - else: - # smooth the probability - item.probability = (1 - smoothing_weight) * item.probability + smoothing_weight * uniform_probability - - # Normalize the probability - sum_p = sum(item.probability for item in list) - for item in list: - item.probability = item.probability / sum_p * target_sum - - return list - - -# This function parse the array of rir set parameter strings. -# It will assign probabilities to those rir sets which don't have a probability -# It will also check the existence of the rir list files. 
-def ParseSetParameterStrings(set_para_array): - set_list = [] - for set_para in set_para_array: - set = lambda: None - setattr(set, "filename", None) - setattr(set, "probability", None) - parts = set_para.split(',') - if len(parts) == 2: - set.probability = float(parts[0]) - set.filename = parts[1].strip() - else: - set.filename = parts[0].strip() - if not os.path.isfile(set.filename): - raise Exception(set.filename + " not found") - set_list.append(set) - - return SmoothProbabilityDistribution(set_list) - - -# This function creates the RIR list -# Each rir object in the list contains the following attributes: -# rir_id, room_id, receiver_position_id, source_position_id, rt60, drr, probability -# Please refer to the help messages in the parser for the meaning of these attributes -def ParseRirList(rir_set_para_array, smoothing_weight, sampling_rate = None): - rir_parser = argparse.ArgumentParser() - rir_parser.add_argument('--rir-id', type=str, required=True, help='This id is unique for each RIR and the noise may associate with a particular RIR by refering to this id') - rir_parser.add_argument('--room-id', type=str, required=True, help='This is the room that where the RIR is generated') - rir_parser.add_argument('--receiver-position-id', type=str, default=None, help='receiver position id') - rir_parser.add_argument('--source-position-id', type=str, default=None, help='source position id') - rir_parser.add_argument('--rt60', type=float, default=None, help='RT60 is the time required for reflections of a direct sound to decay 60 dB.') - rir_parser.add_argument('--drr', type=float, default=None, help='Direct-to-reverberant-ratio of the impulse response.') - rir_parser.add_argument('--cte', type=float, default=None, help='Early-to-late index of the impulse response.') - rir_parser.add_argument('--probability', type=float, default=None, help='probability of the impulse response.') - rir_parser.add_argument('rir_rspecifier', type=str, help="""rir rspecifier, it can be 
either a filename or a piped command. - E.g. data/impulses/Room001-00001.wav or "sox data/impulses/Room001-00001.wav -t wav - |" """) - - set_list = ParseSetParameterStrings(rir_set_para_array) - - rir_list = [] - for rir_set in set_list: - current_rir_list = map(lambda x: rir_parser.parse_args(shlex.split(x.strip())),open(rir_set.filename)) - for rir in current_rir_list: - if sampling_rate is not None: - # check if the rspecifier is a pipe or not - if len(rir.rir_rspecifier.split()) == 1: - rir.rir_rspecifier = "sox {0} -r {1} -t wav - |".format(rir.rir_rspecifier, sampling_rate) - else: - rir.rir_rspecifier = "{0} sox -t wav - -r {1} -t wav - |".format(rir.rir_rspecifier, sampling_rate) - - rir_list += SmoothProbabilityDistribution(current_rir_list, smoothing_weight, rir_set.probability) - - return rir_list - - -# This dunction checks if the inputs are approximately equal assuming they are floats. -def almost_equal(value_1, value_2, accuracy = 10**-8): - return abs(value_1 - value_2) < accuracy - -# This function converts a list of RIRs into a dictionary of RIRs indexed by the room-id. 
-# Its values are objects with two attributes: a local RIR list -# and the probability of the corresponding room -# Please look at the comments at ParseRirList() for the attributes that a RIR object contains -def MakeRoomDict(rir_list): - room_dict = {} - for rir in rir_list: - if rir.room_id not in room_dict: - # add new room - room_dict[rir.room_id] = lambda: None - setattr(room_dict[rir.room_id], "rir_list", []) - setattr(room_dict[rir.room_id], "probability", 0) - room_dict[rir.room_id].rir_list.append(rir) - - # the probability of the room is the sum of probabilities of its RIR - for key in room_dict.keys(): - room_dict[key].probability = sum(rir.probability for rir in room_dict[key].rir_list) - - assert almost_equal(sum(room_dict[key].probability for key in room_dict.keys()), 1.0) - - return room_dict - - -# This function creates the point-source noise list -# and the isotropic noise dictionary from the noise information file -# The isotropic noise dictionary is indexed by the room -# and its value is the corrresponding isotropic noise list -# Each noise object in the list contains the following attributes: -# noise_id, noise_type, bg_fg_type, room_linkage, probability, noise_rspecifier -# Please refer to the help messages in the parser for the meaning of these attributes -def ParseNoiseList(noise_set_para_array, smoothing_weight, sampling_rate = None): - noise_parser = argparse.ArgumentParser() - noise_parser.add_argument('--noise-id', type=str, required=True, help='noise id') - noise_parser.add_argument('--noise-type', type=str, required=True, help='the type of noise; i.e. 
isotropic or point-source', choices = ["isotropic", "point-source"]) - noise_parser.add_argument('--bg-fg-type', type=str, default="background", help='background or foreground noise, for background noises, ' - 'they will be extended before addition to cover the whole speech; for foreground noise, they will be kept ' - 'to their original duration and added at a random point of the speech.', choices = ["background", "foreground"]) - noise_parser.add_argument('--room-linkage', type=str, default=None, help='required if isotropic, should not be specified if point-source.') - noise_parser.add_argument('--probability', type=float, default=None, help='probability of the noise.') - noise_parser.add_argument('noise_rspecifier', type=str, help="""noise rspecifier, it can be either a filename or a piped command. - E.g. type5_noise_cirline_ofc_ambient1.wav or "sox type5_noise_cirline_ofc_ambient1.wav -t wav - |" """) - - set_list = ParseSetParameterStrings(noise_set_para_array) - - pointsource_noise_list = [] - iso_noise_dict = {} - for noise_set in set_list: - current_noise_list = map(lambda x: noise_parser.parse_args(shlex.split(x.strip())),open(noise_set.filename)) - current_pointsource_noise_list = [] - for noise in current_noise_list: - if sampling_rate is not None: - # check if the rspecifier is a pipe or not - if len(noise.noise_rspecifier.split()) == 1: - noise.noise_rspecifier = "sox {0} -r {1} -t wav - |".format(noise.noise_rspecifier, sampling_rate) - else: - noise.noise_rspecifier = "{0} sox -t wav - -r {1} -t wav - |".format(noise.noise_rspecifier, sampling_rate) + if output_reverb_dir is not None: + data_lib.CopyDataDirFiles(input_dir, output_reverb_dir, num_replicas, include_original, prefix) - if noise.noise_type == "isotropic": - if noise.room_linkage is None: - raise Exception("--room-linkage must be specified if --noise-type is isotropic") - else: - if noise.room_linkage not in iso_noise_dict: - iso_noise_dict[noise.room_linkage] = [] - 
iso_noise_dict[noise.room_linkage].append(noise) - else: - current_pointsource_noise_list.append(noise) - - pointsource_noise_list += SmoothProbabilityDistribution(current_pointsource_noise_list, smoothing_weight, noise_set.probability) - - # ensure the point-source noise probabilities sum to 1 - pointsource_noise_list = SmoothProbabilityDistribution(pointsource_noise_list, smoothing_weight, 1.0) - if len(pointsource_noise_list) > 0: - assert almost_equal(sum(noise.probability for noise in pointsource_noise_list), 1.0) - - # ensure the isotropic noise source probabilities for a given room sum to 1 - for key in iso_noise_dict.keys(): - iso_noise_dict[key] = SmoothProbabilityDistribution(iso_noise_dict[key]) - assert almost_equal(sum(noise.probability for noise in iso_noise_dict[key]), 1.0) - - return (pointsource_noise_list, iso_noise_dict) + if output_additive_noise_dir is not None: + data_lib.CopyDataDirFiles(input_dir, output_additive_noise_dir, num_replicas, include_original, prefix) def Main(): args = GetArgs() random.seed(args.random_seed) - rir_list = ParseRirList(args.rir_set_para_array, args.rir_smoothing_weight, args.source_sampling_rate) + rir_list = data_lib.ParseRirList(args.rir_set_para_array, args.rir_smoothing_weight, args.source_sampling_rate) print("Number of RIRs is {0}".format(len(rir_list))) pointsource_noise_list = [] iso_noise_dict = {} if args.noise_set_para_array is not None: - pointsource_noise_list, iso_noise_dict = ParseNoiseList(args.noise_set_para_array, args.noise_smoothing_weight, args.source_sampling_rate) + pointsource_noise_list, iso_noise_dict = data_lib.ParseNoiseList(args.noise_set_para_array, args.noise_smoothing_weight, args.source_sampling_rate) print("Number of point-source noises is {0}".format(len(pointsource_noise_list))) print("Number of isotropic noises is {0}".format(sum(len(iso_noise_dict[key]) for key in iso_noise_dict.keys()))) - room_dict = MakeRoomDict(rir_list) + room_dict = data_lib.MakeRoomDict(rir_list) if 
args.include_original_data == "true": include_original = True @@ -660,8 +338,9 @@ def Main(): shift_output = args.shift_output, isotropic_noise_addition_probability = args.isotropic_noise_addition_probability, pointsource_noise_addition_probability = args.pointsource_noise_addition_probability, - max_noises_per_minute = args.max_noises_per_minute) + max_noises_per_minute = args.max_noises_per_minute, + output_reverb_dir = args.output_reverb_dir, + output_additive_noise_dir = args.output_additive_noise_dir) if __name__ == "__main__": Main() - diff --git a/egs/wsj/s5/steps/data/split_wavs_randomly.py b/egs/wsj/s5/steps/data/split_wavs_randomly.py new file mode 100755 index 00000000000..b4c3b660ddd --- /dev/null +++ b/egs/wsj/s5/steps/data/split_wavs_randomly.py @@ -0,0 +1,114 @@ +#! /usr/bin/env python + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +from __future__ import print_function +import argparse +import random + +def get_args(): + parser = argparse.ArgumentParser(description="""This script converts a + wav.scp into split wav.scp that can be converted into noise-set-paramters + that can be passed to steps/data/reverberate_data_dir.py. 
def get_args():
    """Parses and returns the command-line arguments.

    BUGFIX: the original had 'arparse.ArgumentDefaultsHelpFormatter' (typo),
    which raises NameError as soon as the script starts.
    """
    parser = argparse.ArgumentParser(description="""This script converts a
    wav.scp into split wav.scp that can be converted into noise-set-paramters
    that can be passed to steps/data/reverberate_data_dir.py. The wav files in
    wav.scp is trimmed randomly into pieces based on options such options such
    as --max-duration, --skip-initial-duration and --num-parts-per-minute.""",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--max-duration", type=float, default=30,
                        help="Maximum duration in seconds of the created "
                        "signal pieces")
    parser.add_argument("--min-duration", type=float, default=0.5,
                        help="Minimum duration in seconds of the created "
                        "signal pieces")
    parser.add_argument("--skip-initial-duration", type=float, default=5,
                        help="The duration in seconds of the original signal "
                        "that will be ignored while creating the pieces")
    parser.add_argument("--num-parts-per-minute", type=int, default=3,
                        help="Used to control the number of parts to create "
                        "from a recording")
    parser.add_argument("--sampling-rate", type=float, default=8000,
                        help="Required sampling rate of the output signals.")
    parser.add_argument('--random-seed', type=int, default=0,
                        help='seed to be used in the random split of signals')
    parser.add_argument("wav_scp", type=str,
                        help="The input wav.scp")
    parser.add_argument("reco2dur", type=str,
                        help="""Durations of the recordings corresponding to the
                        input wav.scp""")
    parser.add_argument("out_utt2dur", type=str,
                        help="Output utt2dur corresponding to split wavs")
    parser.add_argument("out_wav_scp", type=str,
                        help="Output wav.scp corresponding to split wavs")

    args = parser.parse_args()

    return args


def get_noise_set(reco, reco_dur, wav_rspecifier_split, sampling_rate,
                  num_parts, max_duration, min_duration,
                  skip_initial_duration):
    """Randomly selects num_parts pieces of the recording 'reco' and returns
    a list of (utt_id, wav_rspecifier, duration) tuples, one per piece.

    reco               -- recording id; piece i is named "<reco>-<i+1>"
    reco_dur           -- duration of the whole recording in seconds
    wav_rspecifier_split -- the wav.scp entry split into tokens; a single
                          token is treated as a plain wav file, multiple
                          tokens as a piped command ending in '|'
    BUGFIX: both rspecifier assignments below were missing a closing
    parenthesis in the original, which made the whole file a SyntaxError.
    """
    noise_set = []
    for i in range(num_parts):
        utt = "{0}-{1}".format(reco, i+1)

        # Random start point anywhere in [skip_initial_duration, reco_dur].
        start_time = round(random.random() * (reco_dur - skip_initial_duration)
                           + skip_initial_duration, 2)
        # Random length in [min_duration, max_duration], clipped so the piece
        # does not run past the end of the recording.  NOTE(review): if
        # start_time lands very near the end, duration can be below
        # min_duration (or 0) — presumably acceptable; confirm with callers.
        duration = min(round(random.random() * (max_duration-min_duration)
                             + min_duration, 2),
                       reco_dur - start_time)

        if len(wav_rspecifier_split) == 1:
            # Plain wav file: read, resample and trim it with sox directly.
            rspecifier = ("sox -D {wav} -r {sr} -t wav - "
                          "trim {st} {dur} |".format(
                              wav=wav_rspecifier_split[0],
                              sr=sampling_rate, st=start_time, dur=duration))
        else:
            # Piped command: append a sox stage that resamples and trims.
            rspecifier = ("{wav} sox -D -t wav - -r {sr} -t wav - "
                          "trim {st} {dur} |".format(
                              wav=" ".join(wav_rspecifier_split),
                              sr=sampling_rate, st=start_time, dur=duration))

        noise_set.append( (utt, rspecifier, duration) )
    return noise_set


def main():
    """Reads wav.scp and reco2dur, splits each recording into randomly placed
    pieces (about --num-parts-per-minute per minute of audio), and writes the
    resulting wav.scp and utt2dur for the pieces."""
    args = get_args()
    random.seed(args.random_seed)

    # Parse reco2dur into a dict: recording-id -> duration in seconds.
    reco2dur = {}
    for line in open(args.reco2dur):
        parts = line.strip().split()
        if len(parts) != 2:
            raise Exception(
                "Expecting reco2dur to contain lines of the format "
                " ; Got {0}".format(line))
        reco2dur[parts[0]] = float(parts[1])

    out_wav_scp = open(args.out_wav_scp, 'w')
    out_utt2dur = open(args.out_utt2dur, 'w')

    for line in open(args.wav_scp):
        parts = line.strip().split()
        reco = parts[0]
        dur = reco2dur[reco]

        # Scale the number of pieces with the recording length.
        num_parts = int(float(args.num_parts_per_minute) / 60 * reco2dur[reco])

        noise_set = get_noise_set(
            reco, reco2dur[reco], wav_rspecifier_split=parts[1:],
            sampling_rate=args.sampling_rate, num_parts=num_parts,
            max_duration=args.max_duration, min_duration=args.min_duration,
            skip_initial_duration=args.skip_initial_duration)

        for utt, rspecifier, dur in noise_set:
            print ("{0} {1}".format(utt, rspecifier), file=out_wav_scp)
            print ("{0} {1}".format(utt, dur), file=out_utt2dur)

    out_wav_scp.close()
    out_utt2dur.close()
/usr/bin/env python + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +from __future__ import print_function +import argparse, random + +def GetArgs(): + parser = argparse.ArgumentParser(description="""This script converts a wav.scp +into noise-set-paramters that can be passed to steps/data/reverberate_data_dir.py.""") + + parser.add_argument("wav_scp", type=str, + help = "The input wav.scp") + parser.add_argument("noise_list", type=str, + help = "File to write the output noise-set-parameters") + + args = parser.parse_args() + + return args + +def Main(): + args = GetArgs() + + noise_list = open(args.noise_list, 'w') + + for line in open(args.wav_scp): + parts = line.strip().split() + + print ('''--noise-id {reco} --noise-type point-source \ +--bg-fg-type foreground "{wav}"'''.format( + reco = parts[0], + wav = " ".join(parts[1:])), file = noise_list) + + noise_list.close() + +if __name__ == '__main__': + Main() + diff --git a/egs/wsj/s5/steps/libs/__init__.py b/egs/wsj/s5/steps/libs/__init__.py index 013c95d0b3f..8f3540643c8 100644 --- a/egs/wsj/s5/steps/libs/__init__.py +++ b/egs/wsj/s5/steps/libs/__init__.py @@ -8,4 +8,4 @@ import common -__all__ = ["common"] +__all__ = ["common", "data"] diff --git a/egs/wsj/s5/steps/libs/common.py b/egs/wsj/s5/steps/libs/common.py index 9d01fae3027..f9bd87bef44 100644 --- a/egs/wsj/s5/steps/libs/common.py +++ b/egs/wsj/s5/steps/libs/common.py @@ -79,10 +79,13 @@ class KaldiCommandException(Exception): kaldi command that caused the error and the error string captured. 
""" def __init__(self, command, err=None): + import re Exception.__init__(self, "There was an error while running the command " - "{0}\n{1}\n{2}".format(command, "-"*10, - "" if err is None else err)) + "{0}\n{1}\n{2}".format( + re.sub('\s+', ' ', command).strip(), + "-"*10, + "" if err is None else err)) class BackgroundProcessHandler(): @@ -165,17 +168,20 @@ def add_process(self, t): self.start() def is_process_done(self, t): - p, command = t + p, command, exit_on_failure = t if p.poll() is None: return False return True def ensure_process_is_done(self, t): - p, command = t + p, command, exit_on_failure = t logger.debug("Waiting for process '{0}' to end".format(command)) [stdout, stderr] = p.communicate() if p.returncode is not 0: - raise KaldiCommandException(command, stderr) + print("There was an error while running the command " + "{0}\n{1}\n{2}".format(command, "-"*10, stderr)) + if exit_on_failure: + os._exit(1) def ensure_processes_are_done(self): self.__process_queue.reverse() @@ -192,7 +198,8 @@ def debug(self): logger.info("Process '{0}' is running".format(command)) -def run_job(command, wait=True, background_process_handler=None): +def run_job(command, wait=True, background_process_handler=None, + exit_on_failure=False): """ Runs a kaldi job, usually using a script such as queue.pl and run.pl, and redirects the stdout and stderr to the parent process's streams. @@ -206,12 +213,14 @@ class that is instantiated by the top-level script. If this is wait: If True, wait until the process is completed. However, if the background_process_handler is provided, this option will be ignored and the process will be run in the background. + exit_on_failure: If True, will exit from the script on failure. + Only applicable when background_process_handler is specified. 
""" p = subprocess.Popen(command, shell=True) if background_process_handler is not None: wait = False - background_process_handler.add_process((p, command)) + background_process_handler.add_process((p, command, exit_on_failure)) if wait: p.communicate() @@ -222,7 +231,8 @@ class that is instantiated by the top-level script. If this is return p -def run_kaldi_command(command, wait=True, background_process_handler=None): +def run_kaldi_command(command, wait=True, background_process_handler=None, + exit_on_failure=False): """ Runs commands frequently seen in Kaldi scripts and captures the stdout and stderr. These are usually a sequence of commands connected by pipes, so we use @@ -235,6 +245,8 @@ class that is instantiated by the top-level script. If this is wait: If True, wait until the process is completed. However, if the background_process_handler is provided, this option will be ignored and the process will be run in the background. + exit_on_failure: If True, will exit from the script on failure. + Only applicable when background_process_handler is specified. """ p = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, @@ -242,7 +254,7 @@ class that is instantiated by the top-level script. 
def get_frame_shift(data_dir):
    """Returns the frame shift (in seconds) of the data directory, as
    reported by utils/data/get_frame_shift.sh."""
    frame_shift = common_lib.run_kaldi_command(
        "utils/data/get_frame_shift.sh {0}".format(data_dir))[0]
    return float(frame_shift.strip())


def generate_utt2dur(data_dir):
    """Ensures <data_dir>/utt2dur exists by running
    utils/data/get_utt2dur.sh."""
    common_lib.run_kaldi_command(
        "utils/data/get_utt2dur.sh {0}".format(data_dir))


def get_utt2dur(data_dir):
    """Returns a dict mapping utterance-id -> duration in seconds,
    generating <data_dir>/utt2dur first if necessary."""
    generate_utt2dur(data_dir)
    utt2dur = {}
    for line in open('{0}/utt2dur'.format(data_dir), 'r').readlines():
        parts = line.split()
        utt2dur[parts[0]] = float(parts[1])
    return utt2dur


def get_utt2uniq(data_dir):
    """Reads <data_dir>/utt2uniq, if it exists, and returns the pair
    (utt2uniq, uniq2utt): a dict from utterance-id to its "uniq" id and the
    inverse multi-map.  Returns (None, None) if the file does not exist."""
    utt2uniq_file = '{0}/utt2uniq'.format(data_dir)
    if not os.path.exists(utt2uniq_file):
        return None, None
    utt2uniq = {}
    uniq2utt = {}
    for line in open(utt2uniq_file, 'r').readlines():
        parts = line.split()
        utt2uniq[parts[0]] = parts[1]
        # BUGFIX: dict.has_key() was removed in Python 3; use 'in' instead.
        if parts[1] in uniq2utt:
            uniq2utt[parts[1]].append(parts[0])
        else:
            uniq2utt[parts[1]] = [parts[0]]
    return utt2uniq, uniq2utt


def get_num_frames(data_dir, utts = None):
    """Returns the total number of frames in the data directory, computed as
    total duration divided by the frame shift.  If 'utts' is given, only
    those utterance-ids are counted; otherwise all of utt2dur is used."""
    generate_utt2dur(data_dir)
    frame_shift = get_frame_shift(data_dir)
    total_duration = 0
    utt2dur = get_utt2dur(data_dir)
    if utts is None:
        utts = utt2dur.keys()
    for utt in utts:
        total_duration = total_duration + utt2dur[utt]
    return int(float(total_duration) / frame_shift)


def create_data_links(file_names):
    """Creates data links for the given files with
    utils/create_data_link.pl.  The files are deleted first because
    create_data_link.pl returns with code 1 if they already exist."""
    for file_name in file_names:
        try_to_delete(file_name)
    common_lib.run_kaldi_command(
        " utils/create_data_link.pl {0}".format(" ".join(file_names)))


def try_to_delete(file_name):
    """Removes file_name if possible; silently ignores a missing file."""
    try:
        os.remove(file_name)
    except OSError:
        pass
def parse_dropout_option(num_archives_to_process, dropout_option):
    """Parses the --trainer.dropout-schedule option string.

    The option is a space-separated list of 'pattern=schedule' entries (the
    'pattern=' part may be omitted for a global schedule).  Returns a list of
    (component_name_pattern, dropout_values) pairs, where dropout_values is
    the output of _parse_dropout_string()."""
    components = dropout_option.strip().split(' ')
    dropout_schedule = []
    for component in components:
        parts = component.split('=')

        if len(parts) == 2:
            component_name = parts[0]
            this_dropout_str = parts[1]
        elif len(parts) == 1:
            component_name = '*'
            this_dropout_str = parts[0]
        else:
            raise Exception("The dropout schedule must be specified in the "
                            "format 'pattern1=func1 patter2=func2' where "
                            "the pattern can be omitted for a global function "
                            "for all components.\n"
                            "Got {0} in {1}".format(component, dropout_option))

        this_dropout_values = _parse_dropout_string(
            num_archives_to_process, this_dropout_str)
        dropout_schedule.append((component_name, this_dropout_values))
    return dropout_schedule


def _parse_dropout_string(num_archives_to_process, dropout_str):
    """Parses a comma-separated dropout schedule such as '0.0,0.5,0.0' or
    '0.0,0.3@0.25,0.0'.  Returns a list of (num_archives, proportion) pairs
    sorted in decreasing order of num_archives: the first value applies at 0
    archives, the last at num_archives_to_process, and intermediate values
    either at the '@' data fraction or, by default, at the halfway point."""
    dropout_values = []
    parts = dropout_str.strip().split(',')
    try:
        if len(parts) < 2:
            raise Exception("dropout proportion string must specify "
                            "at least the start and end dropouts")

        dropout_values.append((0, float(parts[0])))
        # BUGFIX: the loop must exclude the final element (range stops at
        # len(parts) - 1); the original iterated over it too, adding the end
        # value a second time at the halfway point and corrupting the
        # schedule.
        for i in range(1, len(parts) - 1):
            value_x_pair = parts[i].split('@')
            if len(value_x_pair) == 1:
                # No '@': this value takes effect at half of training.
                dropout_proportion = float(parts[i])
                dropout_values.append((0.5 * num_archives_to_process,
                                       dropout_proportion))
            else:
                assert len(value_x_pair) == 2
                dropout_proportion, data_fraction = value_x_pair
                dropout_values.append(
                    (float(data_fraction) * num_archives_to_process,
                     float(dropout_proportion)))

        dropout_values.append((num_archives_to_process, float(parts[-1])))
    except Exception as e:
        logger.error("Unable to parse dropout proportion string {0}. "
                     "See help for option "
                     "--dropout-schedule.".format(dropout_str))
        raise e

    # Reverse sort so that it is easy to retrieve the dropout proportion
    # for a particular data fraction.
    dropout_values.sort(key=lambda x: x[0], reverse=True)
    for num_archives, proportion in dropout_values:
        assert num_archives <= num_archives_to_process and num_archives >= 0
        assert proportion <= 1 and proportion >= 0

    return dropout_values


def get_dropout_proportions(dropout_schedule,
                            num_archives_processed):
    """Returns a list of (component_name_pattern, dropout_proportion) pairs
    for the current training position, interpolating each component's
    schedule at num_archives_processed."""
    dropout_proportions = []
    for component_name, component_dropout_schedule in dropout_schedule:
        dropout_proportions.append(
            (component_name,
             _get_component_dropout(component_dropout_schedule,
                                    num_archives_processed)))
    return dropout_proportions


def _get_component_dropout(dropout_schedule, num_archives_processed):
    """Linearly interpolates the (num_archives, proportion) schedule
    (sorted in decreasing order of num_archives) at num_archives_processed.

    BUGFIXES vs the original:
    - selection now uses '<=' so that a query exactly at a schedule point
      (including the end of training) no longer trips the assertion;
    - the interpolation result adds initial_dropout; without it, e.g. a
      0->0.5->0 schedule returned a negative proportion past the midpoint.
    """
    if num_archives_processed == 0:
        # The schedule always ends with the (0, initial_dropout) entry.
        assert dropout_schedule[-1][0] == 0
        return dropout_schedule[-1][1]
    try:
        (dropout_schedule_index, initial_num_archives,
         initial_dropout) = next((i, tup[0], tup[1])
                                 for i, tup in enumerate(dropout_schedule)
                                 if tup[0] <= num_archives_processed)
    except StopIteration as e:
        logger.error("Could not find num_archives in dropout schedule "
                     "corresponding to num_archives_processed {0}.\n"
                     "Maybe something wrong with the parsed "
                     "dropout schedule {1}.".format(
                         num_archives_processed, dropout_schedule))
        raise e

    if dropout_schedule_index == 0:
        # Exactly at (or beyond) the last schedule point; no interpolation.
        return dropout_schedule[0][1]

    final_num_archives, final_dropout = dropout_schedule[
        dropout_schedule_index - 1]
    assert (num_archives_processed >= initial_num_archives
            and num_archives_processed < final_num_archives)

    return ((num_archives_processed - initial_num_archives)
            * (final_dropout - initial_dropout)
            / (final_num_archives - initial_num_archives)
            + initial_dropout)


def apply_dropout(dropout_proportions, raw_model_string):
    """Appends an nnet3-copy --edits stage to raw_model_string that sets the
    dropout proportion of each matching component, and returns the new
    model-reading command string."""
    edit_config_lines = []

    for component_name, dropout_proportion in dropout_proportions:
        edit_config_lines.append(
            "set-dropout-proportion name={0} proportion={1}".format(
                component_name, dropout_proportion))

    return ("""{raw_model_string} nnet3-copy --edits='{edits}' \
- - |""".format(raw_model_string=raw_model_string,
                edits=";".join(edit_config_lines)))
More general should precede less general patterns, as they are applied sequentially.""") + self.parser.add_argument("--trainer.compute-per-dim-accuracy", + dest='compute_per_dim_accuracy', + type=str, choices=['true', 'false'], + default=False, + action=common_lib.StrToBoolAction, + help="Compute train and validation " + "accuracy per-dim") # General options self.parser.add_argument("--stage", type=int, default=-4, @@ -864,4 +999,4 @@ def __init__(self, if __name__ == '__main__': - self_test() + _self_test() diff --git a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py index 3e732313612..c8891873a76 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py @@ -13,6 +13,7 @@ import logging import math import os +import random import time import libs.common as common_lib @@ -30,7 +31,9 @@ def train_new_models(dir, iter, srand, num_jobs, shuffle_buffer_size, minibatch_size_str, cache_read_opt, run_opts, frames_per_eg=-1, - min_deriv_time=None, max_deriv_time_relative=None): + min_deriv_time=None, max_deriv_time_relative=None, + extra_egs_copy_cmd="", use_multitask_egs=False, + rename_multitask_outputs=False): """ Called from train_one_iteration(), this model does one iteration of training with 'num_jobs' jobs, and writes files like exp/tdnn_a/24.{1,2,3,..}.raw @@ -80,6 +83,51 @@ def train_new_models(dir, iter, srand, num_jobs, cache_write_opt = "--write-cache={dir}/cache.{iter}".format( dir=dir, iter=iter+1) + if use_multitask_egs: + output_rename_opt = "" + if rename_multitask_outputs: + output_rename_opt = ( + "--output=ark:{egs_dir}" + "/output.{archive_index}".format( + egs_dir=egs_dir, archive_index=archive_index)) + egs_rspecifier = ( + "ark,bg:nnet3-copy-egs {frame_opts} {context_opts} " + "{output_rename_opt} " + "--weights=ark:{egs_dir}/weight.{archive_index} " + "scp:{egs_dir}/egs.{archive_index} ark:- 
| " + "{extra_egs_copy_cmd}" + "nnet3-merge-egs --minibatch-size={minibatch_size_str} " + "--measure-output-frames=false " + "--discard-partial-minibatches=true ark:- ark:- | " + "nnet3-shuffle-egs " + "--buffer-size={shuffle_buffer_size} --srand={srand} " + "ark:- ark:- |".format( + frame_opts=("" if chunk_level_training + else "--frame={0}".format(frame)), + context_opts=context_opts, egs_dir=egs_dir, + output_rename_opt=output_rename_opt, + archive_index=archive_index, srand=iter + srand, + shuffle_buffer_size=shuffle_buffer_size, + extra_egs_copy_cmd=extra_egs_copy_cmd, + minibatch_size_str=minibatch_size_str)) + else: + egs_rspecifier = ( + "ark,bg:nnet3-copy-egs {frame_opts} {context_opts} " + "ark:{egs_dir}/egs.{archive_index}.ark ark:- |" + "{extra_egs_copy_cmd}" + "nnet3-shuffle-egs --buffer-size={shuffle_buffer_size} " + "--srand={srand} ark:- ark:- | " + "nnet3-merge-egs --minibatch-size={minibatch_size_str} " + "--measure-output-frames=false " + "--discard-partial-minibatches=true ark:- ark:- |".format( + frame_opts=("" if chunk_level_training + else "--frame={0}".format(frame)), + context_opts=context_opts, egs_dir=egs_dir, + archive_index=archive_index, srand=iter + srand, + shuffle_buffer_size=shuffle_buffer_size, + extra_egs_copy_cmd=extra_egs_copy_cmd, + minibatch_size_str=minibatch_size_str)) + process_handle = common_lib.run_job( """{command} {train_queue_opt} {dir}/log/train.{iter}.{job}.log \ nnet3-train {parallel_train_opts} {cache_read_opt} \ @@ -87,31 +135,18 @@ def train_new_models(dir, iter, srand, num_jobs, --momentum={momentum} \ --max-param-change={max_param_change} \ {deriv_time_opts} "{raw_model}" \ - "ark,bg:nnet3-copy-egs {frame_opts} {context_opts} """ - """ark:{egs_dir}/egs.{archive_index}.ark ark:- |""" - """nnet3-shuffle-egs --buffer-size={shuffle_buffer_size} """ - """--srand={srand} ark:- ark:- | """ - """nnet3-merge-egs --minibatch-size={minibatch_size_str} """ - """--measure-output-frames=false """ - 
"""--discard-partial-minibatches=true ark:- ark:- |" \ + "{egs_rspecifier}" \ {dir}/{next_iter}.{job}.raw""".format( command=run_opts.command, train_queue_opt=run_opts.train_queue_opt, - dir=dir, iter=iter, srand=iter + srand, - next_iter=iter + 1, - job=job, + dir=dir, iter=iter, next_iter=iter + 1, job=job, parallel_train_opts=run_opts.parallel_train_opts, cache_read_opt=cache_read_opt, cache_write_opt=cache_write_opt, - frame_opts=("" - if chunk_level_training - else "--frame={0}".format(frame)), momentum=momentum, max_param_change=max_param_change, deriv_time_opts=" ".join(deriv_time_opts), - raw_model=raw_model_string, context_opts=context_opts, - egs_dir=egs_dir, archive_index=archive_index, - shuffle_buffer_size=shuffle_buffer_size, - minibatch_size_str=minibatch_size_str), wait=False) + raw_model=raw_model_string, + egs_rspecifier=egs_rspecifier), wait=False) processes.append(process_handle) @@ -138,7 +173,10 @@ def train_one_iteration(dir, iter, srand, egs_dir, min_deriv_time=None, max_deriv_time_relative=None, shrinkage_value=1.0, dropout_edit_string="", get_raw_nnet_from_am=True, - background_process_handler=None): + background_process_handler=None, + extra_egs_copy_cmd="", use_multitask_egs=False, + rename_multitask_outputs=False, + compute_per_dim_accuracy=False): """ Called from steps/nnet3/train_*.py scripts for one iteration of neural network training @@ -183,7 +221,9 @@ def train_one_iteration(dir, iter, srand, egs_dir, left_context=left_context, right_context=right_context, run_opts=run_opts, get_raw_nnet_from_am=get_raw_nnet_from_am, wait=False, - background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=extra_egs_copy_cmd, + compute_per_dim_accuracy=compute_per_dim_accuracy) if iter > 0: # Runs in the background @@ -193,7 +233,8 @@ def train_one_iteration(dir, iter, srand, egs_dir, run_opts=run_opts, wait=False, get_raw_nnet_from_am=get_raw_nnet_from_am, - 
background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=extra_egs_copy_cmd) # an option for writing cache (storing pairs of nnet-computations # and computation-requests) during training. @@ -271,14 +312,16 @@ def train_one_iteration(dir, iter, srand, egs_dir, num_archives_processed=num_archives_processed, num_archives=num_archives, raw_model_string=raw_model_string, egs_dir=egs_dir, - left_context=left_context, right_context=right_context, momentum=momentum, max_param_change=cur_max_param_change, shuffle_buffer_size=shuffle_buffer_size, minibatch_size_str=cur_minibatch_size_str, cache_read_opt=cache_read_opt, run_opts=run_opts, frames_per_eg=frames_per_eg, min_deriv_time=min_deriv_time, - max_deriv_time_relative=max_deriv_time_relative) + max_deriv_time_relative=max_deriv_time_relative, + extra_egs_copy_cmd=extra_egs_copy_cmd, + use_multitask_egs=use_multitask_egs, + rename_multitask_outputs=rename_multitask_outputs) [models_to_average, best_model] = common_train_lib.get_successful_models( num_jobs, '{0}/log/train.{1}.%.log'.format(dir, iter)) @@ -335,25 +378,25 @@ def compute_preconditioning_matrix(dir, egs_dir, num_lda_jobs, run_opts, # Write stats with the same format as stats for LDA. 
common_lib.run_job( - """{command} JOB=1:{num_lda_jobs} {dir}/log/get_lda_stats.JOB.log \ - nnet3-acc-lda-stats --rand-prune={rand_prune} \ - {dir}/init.raw "ark:{egs_dir}/egs.JOB.ark" \ - {dir}/JOB.lda_stats""".format( - command=run_opts.command, - num_lda_jobs=num_lda_jobs, - dir=dir, - egs_dir=egs_dir, - rand_prune=rand_prune)) + """{command} JOB=1:{num_lda_jobs} {dir}/log/get_lda_stats.JOB.log """ + """ nnet3-acc-lda-stats --rand-prune={rand_prune}""" + """ {dir}/init.raw "ark:{egs_dir}/egs.JOB.ark" """ + """ {dir}/JOB.lda_stats""".format( + command=run_opts.command, + num_lda_jobs=num_lda_jobs, + dir=dir, + egs_dir=egs_dir, + rand_prune=rand_prune)) # the above command would have generated dir/{1..num_lda_jobs}.lda_stats lda_stat_files = map(lambda x: '{0}/{1}.lda_stats'.format(dir, x), range(1, num_lda_jobs + 1)) common_lib.run_job( - """{command} {dir}/log/sum_transform_stats.log \ - sum-lda-accs {dir}/lda_stats {lda_stat_files}""".format( - command=run_opts.command, - dir=dir, lda_stat_files=" ".join(lda_stat_files))) + "{command} {dir}/log/sum_transform_stats.log " + "sum-lda-accs {dir}/lda_stats {lda_stat_files}".format( + command=run_opts.command, + dir=dir, lda_stat_files=" ".join(lda_stat_files))) for file in lda_stat_files: try: @@ -367,11 +410,11 @@ def compute_preconditioning_matrix(dir, egs_dir, num_lda_jobs, run_opts, # variant of an LDA transform but without dimensionality reduction. 
common_lib.run_job( - """{command} {dir}/log/get_transform.log \ - nnet-get-feature-transform {lda_opts} {dir}/lda.mat \ - {dir}/lda_stats""".format( - command=run_opts.command, dir=dir, - lda_opts=lda_opts if lda_opts is not None else "")) + "{command} {dir}/log/get_transform.log" + " nnet-get-feature-transform {lda_opts} {dir}/lda.mat" + " {dir}/lda_stats".format( + command=run_opts.command, dir=dir, + lda_opts=lda_opts if lda_opts is not None else "")) common_lib.force_symlink("../lda.mat", "{0}/configs/lda.mat".format(dir)) @@ -379,7 +422,9 @@ def compute_preconditioning_matrix(dir, egs_dir, num_lda_jobs, run_opts, def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, right_context, run_opts, wait=False, background_process_handler=None, - get_raw_nnet_from_am=True): + get_raw_nnet_from_am=True, + extra_egs_copy_cmd="", + compute_per_dim_accuracy=False): if get_raw_nnet_from_am: model = "nnet3-am-copy --raw=true {dir}/{iter}.mdl - |".format( dir=dir, iter=iter) @@ -389,38 +434,61 @@ def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, context_opts = "--left-context={lc} --right-context={rc}".format( lc=left_context, rc=right_context) + if os.path.isfile("{0}/valid_diagnostic.egs".format(egs_dir)): + valid_diagnostic_egs = "ark:{0}/valid_diagnostic.egs".format(egs_dir) + else: + valid_diagnostic_egs = "scp:{0}/valid_diagnostic.egs.1".format( + egs_dir) + + opts = [] + if compute_per_dim_accuracy: + opts.append("--compute-per-dim-accuracy") + common_lib.run_job( """ {command} {dir}/log/compute_prob_valid.{iter}.log \ nnet3-compute-prob "{model}" \ - "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/valid_diagnostic.egs ark:- | \ + "ark,bg:nnet3-copy-egs {opts} {context_opts} \ + ark:{egs_dir}/valid_diagnostic.egs ark:- |{extra_egs_copy_cmd} \ nnet3-merge-egs --minibatch-size=1:64 ark:- \ ark:- |" """.format(command=run_opts.command, dir=dir, iter=iter, + opts=' '.join(opts), context_opts=context_opts, model=model, - 
egs_dir=egs_dir), + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd), wait=wait, background_process_handler=background_process_handler) + if os.path.isfile("{0}/train_diagnostic.egs".format(egs_dir)): + train_diagnostic_egs = "ark:{0}/train_diagnostic.egs".format(egs_dir) + else: + train_diagnostic_egs = "scp:{0}/train_diagnostic.egs.1".format( + egs_dir) + common_lib.run_job( - """{command} {dir}/log/compute_prob_train.{iter}.log \ - nnet3-compute-prob "{model}" \ - "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/train_diagnostic.egs ark:- | \ - nnet3-merge-egs --minibatch-size=1:64 ark:- \ - ark:- |" """.format(command=run_opts.command, - dir=dir, - iter=iter, - context_opts=context_opts, - model=model, - egs_dir=egs_dir), + """{command} {dir}/log/compute_prob_train.{iter}.log""" + """ nnet3-compute-prob {opts} "{model}" """ + """ "ark,bg:nnet3-copy-egs {context_opts}""" + """ {egs_rspecifier} ark:- | {extra_egs_copy_cmd}""" + """ nnet3-merge-egs --minibatch-size=1:64 ark:-""" + """ ark:- |" """.format(command=run_opts.command, + opts=' '.join(opts), + dir=dir, + iter=iter, + egs_rspecifier=train_diagnostic_egs, + context_opts=context_opts, + model=model, + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd), wait=wait, background_process_handler=background_process_handler) def compute_progress(dir, iter, egs_dir, left_context, right_context, - run_opts, background_process_handler=None, wait=False, - get_raw_nnet_from_am=True): + run_opts, + background_process_handler=None, wait=False, + get_raw_nnet_from_am=True, + extra_egs_copy_cmd=""): if get_raw_nnet_from_am: prev_model = "nnet3-am-copy --raw=true {0}/{1}.mdl - |".format( dir, iter - 1) @@ -432,21 +500,28 @@ def compute_progress(dir, iter, egs_dir, left_context, right_context, context_opts = "--left-context={lc} --right-context={rc}".format( lc=left_context, rc=right_context) + if os.path.isfile("{0}/train_diagnostic.egs".format(egs_dir)): + train_diagnostic_egs = 
"ark:{0}/train_diagnostic.egs".format(egs_dir) + else: + train_diagnostic_egs = "scp:{0}/train_diagnostic.egs.1".format( + egs_dir) + common_lib.run_job( - """{command} {dir}/log/progress.{iter}.log \ - nnet3-info "{model}" '&&' \ - nnet3-show-progress --use-gpu=no "{prev_model}" "{model}" \ - "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/train_diagnostic.egs ark:- | \ - nnet3-merge-egs --minibatch-size=1:64 ark:- \ - ark:- |" """.format(command=run_opts.command, - dir=dir, - iter=iter, - model=model, - context_opts=context_opts, - prev_model=prev_model, - egs_dir=egs_dir), - wait=wait, background_process_handler=background_process_handler) + """{command} {dir}/log/progress.{iter}.log nnet3-info "{model}" """ + """ '&&' nnet3-show-progress --use-gpu=no "{prev_model}" "{model}" """ + """ "ark,bg:nnet3-copy-egs {context_opts}""" + """ {egs_rspecifier} ark:- |{extra_egs_copy_cmd}""" + """ nnet3-merge-egs --minibatch-size=1:64 ark:-""" + """ ark:- |" """.format(command=run_opts.command, + dir=dir, + iter=iter, + egs_rspecifier=train_diagnostic_egs, + model=model, + context_opts=context_opts, + prev_model=prev_model, + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd), + wait=wait, background_process_handler=background_process_handler) def combine_models(dir, num_iters, models_to_combine, egs_dir, @@ -454,7 +529,8 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, minibatch_size_str, run_opts, background_process_handler=None, chunk_width=None, get_raw_nnet_from_am=True, - sum_to_one_penalty=0.0): + sum_to_one_penalty=0.0, + extra_egs_copy_cmd="", compute_per_dim_accuracy=False): """ Function to do model combination In the nnet3 setup, the logic @@ -489,6 +565,11 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, context_opts = "--left-context={lc} --right-context={rc}".format( lc=left_context, rc=right_context) + if os.path.isfile("{0}/combine.egs".format(egs_dir)): + combine_egs = "ark:{0}/combine.egs".format(egs_dir) + 
else: + combine_egs = "scp:{0}/combine.egs.1".format(egs_dir) + common_lib.run_job( """{command} {combine_queue_opt} {dir}/log/combine.log \ nnet3-combine --num-iters=80 \ @@ -497,19 +578,21 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, --enforce-positive-weights=true \ --verbose=3 {raw_models} \ "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/combine.egs ark:- | \ + {egs_rspecifier} ark:- |{extra_egs_copy_cmd} \ nnet3-merge-egs --measure-output-frames=false \ --minibatch-size={mbsize} ark:- ark:- |" \ "{out_model}" """.format(command=run_opts.command, combine_queue_opt=run_opts.combine_queue_opt, dir=dir, raw_models=" ".join(raw_model_strings), + egs_rspecifier=combine_egs, hard_enforce=(sum_to_one_penalty <= 0), penalty=sum_to_one_penalty, context_opts=context_opts, mbsize=minibatch_size_str, out_model=out_model, - egs_dir=egs_dir)) + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd)) # Compute the probability of the final, combined model with # the same subset we used for the previous compute_probs, as the @@ -519,14 +602,18 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, dir=dir, iter='combined', egs_dir=egs_dir, left_context=left_context, right_context=right_context, run_opts=run_opts, wait=False, - background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=extra_egs_copy_cmd, + compute_per_dim_accuracy=compute_per_dim_accuracy) else: compute_train_cv_probabilities( dir=dir, iter='final', egs_dir=egs_dir, left_context=left_context, right_context=right_context, run_opts=run_opts, wait=False, background_process_handler=background_process_handler, - get_raw_nnet_from_am=False) + get_raw_nnet_from_am=False, + extra_egs_copy_cmd=extra_egs_copy_cmd, + compute_per_dim_accuracy=compute_per_dim_accuracy) def get_realign_iters(realign_times, num_iters, @@ -639,7 +726,8 @@ def adjust_am_priors(dir, input_model, avg_posterior_vector, 
output_model, def compute_average_posterior(dir, iter, egs_dir, num_archives, prior_subset_size, left_context, right_context, - run_opts, get_raw_nnet_from_am=True): + run_opts, get_raw_nnet_from_am=True, + extra_egs_copy_cmd=""): """ Computes the average posterior of the network Note: this just uses CPUs, using a smallish subset of data. """ @@ -663,7 +751,7 @@ def compute_average_posterior(dir, iter, egs_dir, num_archives, """{command} JOB=1:{num_jobs_compute_prior} {prior_queue_opt} \ {dir}/log/get_post.{iter}.JOB.log \ nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/egs.{egs_part}.ark ark:- \| \ + ark:{egs_dir}/egs.{egs_part}.ark ark:- \| {extra_egs_copy_cmd}\ nnet3-subset-egs --srand=JOB --n={prior_subset_size} \ ark:- ark:- \| \ nnet3-merge-egs --measure-output-frames=true \ @@ -679,7 +767,8 @@ def compute_average_posterior(dir, iter, egs_dir, num_archives, iter=iter, prior_subset_size=prior_subset_size, egs_dir=egs_dir, egs_part=egs_part, context_opts=context_opts, - prior_gpu_opt=run_opts.prior_gpu_opt)) + prior_gpu_opt=run_opts.prior_gpu_opt, + extra_egs_copy_cmd=extra_egs_copy_cmd)) # make sure there is time for $dir/post.{iter}.*.vec to appear. time.sleep(5) diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py index 59b6006accb..4537fe9ee16 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py @@ -369,7 +369,8 @@ def set_default_configs(self): # note: self.config['input'] is a descriptor, '[-1]' means output # the most recent layer. 
- self.config = { 'input':'[-1]' } + self.config = {'input': '[-1]', + 'dim': -1} def check_configs(self): @@ -474,6 +475,7 @@ def set_default_configs(self): 'param-stddev' : 0.0, 'bias-stddev' : 0.0, 'output-delay' : 0, + 'objective-scale': 1.0, 'ng-affine-options' : '' } @@ -484,7 +486,7 @@ def check_configs(self): "".format(self.config['dim'])) if self.config['objective-type'] != 'linear' and \ - self.config['objective_type'] != 'quadratic': + self.config['objective-type'] != 'quadratic': raise RuntimeError("In output-layer, objective-type has" " invalid value {0}" "".format(self.config['objective-type'])) @@ -537,6 +539,7 @@ def get_full_config(self): bias_stddev = self.config['bias-stddev'] output_delay = self.config['output-delay'] max_change = self.config['max-change'] + objective_scale = self.config['objective-scale'] ng_affine_options = self.config['ng-affine-options'] # note: ref.config is used only for getting the left-context and @@ -578,6 +581,18 @@ def get_full_config(self): ans.append((config_name, line)) cur_node = '{0}.fixed-scale'.format(self.name) + if objective_scale != 1.0: + line = ('component name={0}.objective-scale' + ' type=ScaleGradientComponent scale={1} dim={2}' + ''.format(self.name, objective_scale, output_dim)) + ans.append((config_name, line)) + + line = ('component-node name={0}.objective-scale' + ' component={0}.objective-scale input={1}' + ''.format(self.name, cur_node)) + ans.append((config_name, line)) + cur_node = '{0}.objective-scale'.format(self.name) + if include_log_softmax: line = ('component name={0}.log-softmax' ' type=LogSoftmaxComponent dim={1}' @@ -593,7 +608,9 @@ def get_full_config(self): if output_delay != 0: cur_node = 'Offset({0}, {1})'.format(cur_node, output_delay) - line = ('output-node name={0} input={1}'.format(self.name, cur_node)) + line = ('output-node name={0} input={1} ' + 'objective={2}'.format( + self.name, cur_node, objective_type)) ans.append((config_name, line)) return ans @@ -636,7 +653,24 @@ def 
set_default_configs(self): 'max-change' : 0.75, 'self-repair-scale' : 1.0e-05, 'target-rms' : 1.0, - 'ng-affine-options' : ''} + 'ng-affine-options' : '', + 'add-log-stddev' : False } + + def set_derived_configs(self): + output_dim = self.config['dim'] + # If not set, the output-dim defaults to the input-dim. + if output_dim <= 0: + self.config['dim'] = output_dim = self.descriptors['input']['dim'] + + if self.config['add-log-stddev']: + split_layer_name = self.layer_type.split('-') + assert split_layer_name[-1] == 'layer' + nonlinearities = split_layer_name[:-1] + + for nonlinearity in nonlinearities: + if nonlinearity == "renorm": + output_dim += 1 + self.config['output-dim'] = output_dim def check_configs(self): if self.config['dim'] < 0: @@ -660,12 +694,7 @@ def output_name(self, auxiliary_output=None): return '{0}.{1}'.format(self.name, last_nonlinearity) def output_dim(self, auxiliary_output = None): - output_dim = self.config['dim'] - # If not set, the output-dim defaults to the input-dim. - if output_dim <= 0: - output_dim = self.descriptors['input']['dim'] - return output_dim - + return self.config['output-dim'] def get_full_config(self): ans = [] @@ -695,10 +724,13 @@ def _generate_config(self): return self._add_components(input_desc, input_dim, nonlinearities) def _add_components(self, input_desc, input_dim, nonlinearities): - output_dim = self.output_dim() + output_dim = self.config['dim'] self_repair_scale = self.config['self-repair-scale'] target_rms = self.config['target-rms'] max_change = self.config['max-change'] + ng_opt_str = self.config['ng-affine-options'] + add_log_stddev = ("true" if self.config['add-log-stddev'] + else "false") ng_affine_options = self.config['ng-affine-options'] configs = [] @@ -745,8 +777,11 @@ def _add_components(self, input_desc, input_dim, nonlinearities): line = ('component name={0}.{1}' ' type=NormalizeComponent dim={2}' ' target-rms={3}' + ' add-log-stddev={4}' ''.format(self.name, nonlinearity, output_dim, -
target_rms, add_log_stddev)) + if self.config['add-log-stddev']: + output_dim += 1 else: raise RuntimeError("Unknown nonlinearity type: {0}" diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py index fa356d15a18..188e0ec4322 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py @@ -5,3 +5,4 @@ from basic_layers import * from lstm import * +from stats_layer import * diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py index 918d8bd2fb2..2c5424cdf94 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py @@ -27,7 +27,8 @@ 'lstm-layer' : xlayers.XconfigLstmLayer, 'lstmp-layer' : xlayers.XconfigLstmpLayer, 'fast-lstm-layer' : xlayers.XconfigFastLstmLayer, - 'fast-lstmp-layer' : xlayers.XconfigFastLstmpLayer + 'fast-lstmp-layer' : xlayers.XconfigFastLstmpLayer, + 'stats-layer': xlayers.XconfigStatsLayer } # Turn a config line and a list of previous layers into diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py new file mode 100644 index 00000000000..e49a4fa3df6 --- /dev/null +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py @@ -0,0 +1,141 @@ +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +""" This module contains the statistics extraction and pooling layer. +""" + +from __future__ import print_function +import re +from libs.nnet3.xconfig.basic_layers import XconfigLayerBase + + +class XconfigStatsLayer(XconfigLayerBase): + """This class is for parsing lines like + stats-layer name=tdnn1-stats config=mean+stddev(-99:3:9:99) input=tdnn1 + + This adds statistics-pooling and statistics-extraction components. 
An + example string is 'mean(-99:3:9::99)', which means, compute the mean of + data within a window of -99 to +99, with distinct means computed every 9 + frames (we round to get the appropriate one), and with the input extracted + on multiples of 3 frames (so this will force the input to this layer to be + evaluated every 3 frames). Another example string is + 'mean+stddev(-99:3:9:99)', which will also cause the standard deviation to + be computed. + + The dimension is worked out from the input. mean and stddev add a + dimension of input_dim each to the output dimension. If counts is + specified, an additional dimension is added to the output to store log + counts. + + Parameters of the class, and their defaults: + input='[-1]' [Descriptor giving the input of the layer.] + dim=-1 [Output dimension of layer. If provided, must match the + dimension computed from input] + config='' [Required. Defines what stats must be computed.] + """ + def __init__(self, first_token, key_to_value, prev_names=None): + assert first_token in ['stats-layer'] + XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names) + + def set_default_configs(self): + self.config = {'input': '[-1]', + 'dim': -1, + 'config': ''} + + def set_derived_configs(self): + config_string = self.config['config'] + if config_string == '': + raise RuntimeError("config has to be non-empty", + self.str()) + m = re.search("(mean|mean\+stddev|mean\+count|mean\+stddev\+count)" + "\((-?\d+):(-?\d+):(-?\d+):(-?\d+)\)", + config_string) + if m is None: + raise RuntimeError("Invalid statistic-config string: {0}".format( + config_string), self) + + self._output_stddev = (m.group(1) in ['mean+stddev', + 'mean+stddev+count']) + self._output_log_counts = (m.group(1) in ['mean+count', + 'mean+stddev+count']) + self._left_context = -int(m.group(2)) + self._input_period = int(m.group(3)) + self._stats_period = int(m.group(4)) + self._right_context = int(m.group(5)) + + output_dim = (self.descriptors['input']['dim'] + 
* (2 if self._output_stddev else 1) + + (1 if self._output_log_counts else 0)) + + if self.config['dim'] > 0 and self.config['dim'] != output_dim: + raise RuntimeError( + "Invalid dim supplied {0:d} != " + "actual output dim {1:d}".format( + self.config['dim'], output_dim)) + self.config['dim'] = output_dim + + def check_configs(self): + if not (self._left_context > 0 and self._right_context > 0 + and self._input_period > 0 and self._stats_period > 0 + and self._left_context % self._stats_period == 0 + and self._right_context % self._stats_period == 0 + and self._stats_period % self._input_period == 0): + raise RuntimeError( + "Invalid configuration of statistics-extraction: {0}".format( + self.config['config']), self) + super(XconfigStatsLayer, self).check_configs() + + def _generate_config(self): + input_desc = self.descriptors['input']['final-string'] + input_dim = self.descriptors['input']['dim'] + + configs = [] + configs.append( + 'component name={name}-extraction-{lc}-{rc} ' + 'type=StatisticsExtractionComponent input-dim={dim} ' + 'input-period={input_period} output-period={output_period} ' + 'include-variance={var} '.format( + name=self.name, lc=self._left_context, rc=self._right_context, + dim=input_dim, input_period=self._input_period, + output_period=self._stats_period, + var='true' if self._output_stddev else 'false')) + configs.append( + 'component-node name={name}-extraction-{lc}-{rc} ' + 'component={name}-extraction-{lc}-{rc} input={input} '.format( + name=self.name, lc=self._left_context, rc=self._right_context, + input=input_desc)) + + stats_dim = 1 + input_dim * (2 if self._output_stddev else 1) + configs.append( + 'component name={name}-pooling-{lc}-{rc} ' + 'type=StatisticsPoolingComponent input-dim={dim} ' + 'input-period={input_period} left-context={lc} right-context={rc} ' + 'num-log-count-features={count} output-stddevs={var} '.format( + name=self.name, lc=self._left_context, rc=self._right_context, + dim=stats_dim,
input_period=self._stats_period, + count=1 if self._output_log_counts else 0, + var='true' if self._output_stddev else 'false')) + configs.append( + 'component-node name={name}-pooling-{lc}-{rc} ' + 'component={name}-pooling-{lc}-{rc} ' + 'input={name}-extraction-{lc}-{rc} '.format( + name=self.name, lc=self._left_context, rc=self._right_context)) + return configs + + def output_name(self, auxiliary_output=None): + return 'Round({name}-pooling-{lc}-{rc}, {period})'.format( + name=self.name, lc=self._left_context, + rc=self._right_context, period=self._stats_period) + + def output_dim(self, auxiliary_outputs=None): + return self.config['dim'] + + def get_full_config(self): + ans = [] + config_lines = self._generate_config() + + for line in config_lines: + for config_name in ['ref', 'final']: + ans.append((config_name, line)) + + return ans diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/utils.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/utils.py index 3d958568717..76477300884 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/utils.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/utils.py @@ -484,7 +484,7 @@ def parse_config_line(orig_config_line): # treats splitting on space as a special case that may give zero fields. config_line = orig_config_line.split('#')[0] # Note: this set of allowed characters may have to be expanded in future. - x = re.search('[^a-zA-Z0-9\.\-\(\)@_=,/\s"]', config_line) + x = re.search('[^a-zA-Z0-9\.\-\(\)@_=,/+:\s"]', config_line) if x is not None: bad_char = x.group(0) if bad_char == "'": diff --git a/egs/wsj/s5/steps/make_mfcc.sh b/egs/wsj/s5/steps/make_mfcc.sh index ddb63a0e6fb..5362e7fa9d9 100755 --- a/egs/wsj/s5/steps/make_mfcc.sh +++ b/egs/wsj/s5/steps/make_mfcc.sh @@ -10,7 +10,7 @@ nj=4 cmd=run.pl mfcc_config=conf/mfcc.conf compress=true -write_utt2num_frames=false # if true writes utt2num_frames +write_utt2num_frames=true # if true writes utt2num_frames # End configuration section. 
echo "$0 $@" # Print the command line for logging diff --git a/egs/wsj/s5/steps/make_mfcc_pitch.sh b/egs/wsj/s5/steps/make_mfcc_pitch.sh index ff9a7d2f5f3..4a2808b811f 100755 --- a/egs/wsj/s5/steps/make_mfcc_pitch.sh +++ b/egs/wsj/s5/steps/make_mfcc_pitch.sh @@ -96,6 +96,12 @@ for n in $(seq $nj); do utils/create_data_link.pl $mfcc_pitch_dir/raw_mfcc_pitch_$name.$n.ark done +if $write_utt2num_frames; then + write_num_frames_opt="--write-num-frames=ark,t:$logdir/utt2num_frames.JOB" +else + write_num_frames_opt= +fi + if [ -f $data/segments ]; then echo "$0 [info]: segments file exists: using that." split_segments="" @@ -111,7 +117,7 @@ if [ -f $data/segments ]; then $cmd JOB=1:$nj $logdir/make_mfcc_pitch_${name}.JOB.log \ paste-feats --length-tolerance=$paste_length_tolerance "$mfcc_feats" "$pitch_feats" ark:- \| \ - copy-feats --compress=$compress ark:- \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ ark,scp:$mfcc_pitch_dir/raw_mfcc_pitch_$name.JOB.ark,$mfcc_pitch_dir/raw_mfcc_pitch_$name.JOB.scp \ || exit 1; @@ -129,7 +135,7 @@ else $cmd JOB=1:$nj $logdir/make_mfcc_pitch_${name}.JOB.log \ paste-feats --length-tolerance=$paste_length_tolerance "$mfcc_feats" "$pitch_feats" ark:- \| \ - copy-feats --compress=$compress ark:- \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ ark,scp:$mfcc_pitch_dir/raw_mfcc_pitch_$name.JOB.ark,$mfcc_pitch_dir/raw_mfcc_pitch_$name.JOB.scp \ || exit 1; @@ -147,6 +153,13 @@ for n in $(seq $nj); do cat $mfcc_pitch_dir/raw_mfcc_pitch_$name.$n.scp || exit 1; done > $data/feats.scp +if $write_utt2num_frames; then + for n in $(seq $nj); do + cat $logdir/utt2num_frames.$n || exit 1; + done > $data/utt2num_frames || exit 1 + rm $logdir/utt2num_frames.* +fi + rm $logdir/wav_${name}.*.scp $logdir/segments.* 2>/dev/null nf=`cat $data/feats.scp | wc -l` diff --git a/egs/wsj/s5/steps/nnet3/chain/train.py b/egs/wsj/s5/steps/nnet3/chain/train.py index 19276817ea0..6c086e345cf 100755 ---
a/egs/wsj/s5/steps/nnet3/chain/train.py +++ b/egs/wsj/s5/steps/nnet3/chain/train.py @@ -145,6 +145,7 @@ def get_args(): steps/nnet3/get_saturation.pl) exceeds this threshold we scale the parameter matrices with the shrink-value.""") + # RNN-specific training options parser.add_argument("--trainer.deriv-truncate-margin", type=int, dest='deriv_truncate_margin', default=None, @@ -358,7 +359,7 @@ def train(args, run_opts, background_process_handler): [egs_left_context, egs_right_context, frames_per_eg_str, num_archives] = ( - common_train_lib.verify_egs_dir(egs_dir, feat_dim, + common_train_lib.verify_egs_dir(egs_dir, feat_dim, ivector_dim, ivector_id, egs_left_context, egs_right_context, egs_left_context_initial, @@ -413,6 +414,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): args.initial_effective_lrate, args.final_effective_lrate) + if args.dropout_schedule is not None: + dropout_schedule = common_train_lib.parse_dropout_option( + num_archives_to_process, args.dropout_schedule) + min_deriv_time = None max_deriv_time_relative = None if args.deriv_truncate_margin is not None: diff --git a/egs/wsj/s5/steps/nnet3/components.py b/egs/wsj/s5/steps/nnet3/components.py index 3fb92117d78..c811297cda8 100644 --- a/egs/wsj/s5/steps/nnet3/components.py +++ b/egs/wsj/s5/steps/nnet3/components.py @@ -6,6 +6,7 @@ import sys import warnings import copy +import re from operator import itemgetter def GetSumDescriptor(inputs): @@ -30,17 +31,33 @@ def AddInputLayer(config_lines, feat_dim, splice_indexes=[0], ivector_dim=0): components = config_lines['components'] component_nodes = config_lines['component-nodes'] output_dim = 0 - components.append('input-node name=input dim=' + str(feat_dim)) - list = [('Offset(input, {0})'.format(n) if n != 0 else 'input') for n in splice_indexes] - output_dim += len(splice_indexes) * feat_dim + components.append('input-node name=input dim={0}'.format(feat_dim)) + prev_layer_output = {'descriptor': "input", + 'dimension': 
feat_dim} + inputs = [] + for n in splice_indexes: + try: + offset = int(n) + if offset == 0: + inputs.append(prev_layer_output['descriptor']) + else: + inputs.append('Offset({0}, {1})'.format( + prev_layer_output['descriptor'], offset)) + output_dim += prev_layer_output['dimension'] + except ValueError: + stats = StatisticsConfig(n, prev_layer_output) + stats_layer = stats.AddLayer(config_lines, "Tdnn_stats_{0}".format(0)) + inputs.append(stats_layer['descriptor']) + output_dim += stats_layer['dimension'] + if ivector_dim > 0: - components.append('input-node name=ivector dim=' + str(ivector_dim)) - list.append('ReplaceIndex(ivector, t, 0)') + components.append('input-node name=ivector dim={0}'.format(ivector_dim)) + inputs.append('ReplaceIndex(ivector, t, 0)') output_dim += ivector_dim - if len(list) > 1: - splice_descriptor = "Append({0})".format(", ".join(list)) + if len(inputs) > 1: + splice_descriptor = "Append({0})".format(", ".join(inputs)) else: - splice_descriptor = list[0] + splice_descriptor = inputs[0] print(splice_descriptor) return {'descriptor': splice_descriptor, 'dimension': output_dim} @@ -55,6 +72,35 @@ def AddNoOpLayer(config_lines, name, input): return {'descriptor': '{0}_noop'.format(name), 'dimension': input['dimension']} +def AddGradientScaleLayer(config_lines, name, input, scale = 1.0, scales_vec = None): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + if scales_vec is None: + components.append('component name={0}_gradient_scale type=ScaleGradientComponent dim={1} scale={2}'.format(name, input['dimension'], scale)) + else: + components.append('component name={0}_gradient_scale type=ScaleGradientComponent scales={2}'.format(name, scales_vec)) + + component_nodes.append('component-node name={0}_gradient_scale component={0}_gradient_scale input={1}'.format(name, input['descriptor'])) + + return {'descriptor': '{0}_gradient_scale'.format(name), + 'dimension': input['dimension']} + +def 
AddFixedScaleLayer(config_lines, name, input, + scale = 1.0, scales_vec = None): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + if scales_vec is None: + components.append('component name={0}-fixed-scale type=FixedScaleComponent dim={1} scale={2}'.format(name, input['dimension'], scale)) + else: + components.append('component name={0}-fixed-scale type=FixedScaleComponent scales={2}'.format(name, scales_vec)) + + component_nodes.append('component-node name={0}-fixed-scale component={0}-fixed-scale input={1}'.format(name, input['descriptor'])) + + return {'descriptor': '{0}-fixed-scale'.format(name), + 'dimension': input['dimension']} + def AddLdaLayer(config_lines, name, input, lda_file): return AddFixedAffineLayer(config_lines, name, input, lda_file) @@ -257,7 +303,9 @@ def AddFinalLayer(config_lines, input, output_dim, include_log_softmax = True, add_final_sigmoid = False, name_affix = None, - objective_type = "linear"): + objective_type = "linear", + objective_scale = 1.0, + objective_scales_vec = None): components = config_lines['components'] component_nodes = config_lines['component-nodes'] @@ -283,6 +331,9 @@ def AddFinalLayer(config_lines, input, output_dim, prev_layer_output = AddSigmoidLayer(config_lines, final_node_prefix, prev_layer_output) # we use the same name_affix as a prefix in for affine/scale nodes but as a # suffix for output node + if (objective_scale != 1.0 or objective_scales_vec is not None): + prev_layer_output = AddGradientScaleLayer(config_lines, final_node_prefix, prev_layer_output, objective_scale, objective_scales_vec) + AddOutputLayer(config_lines, prev_layer_output, label_delay, suffix = name_affix, objective_type = objective_type) def AddLstmLayer(config_lines, @@ -485,3 +536,82 @@ def AddBLstmLayer(config_lines, 'dimension':output_dim } +# this is a bit like a struct, initialized from a string, which describes how to +# set up the statistics-pooling and statistics-extraction 
components. +# An example string is 'mean(-99:3:9::99)', which means, compute the mean of +# data within a window of -99 to +99, with distinct means computed every 9 frames +# (we round to get the appropriate one), and with the input extracted on multiples +# of 3 frames (so this will force the input to this layer to be evaluated +# every 3 frames). Another example string is 'mean+stddev(-99:3:9:99)', +# which will also cause the standard deviation to be computed. +class StatisticsConfig: + # e.g. c = StatisticsConfig('mean+stddev(-99:3:9:99)', 400, 'jesus1-forward-output-affine') + def __init__(self, config_string, input): + + self.input_dim = input['dimension'] + self.input_descriptor = input['descriptor'] + + m = re.search("(mean|mean\+stddev|mean\+count|mean\+stddev\+count)\((-?\d+):(-?\d+):(-?\d+):(-?\d+)\)", + config_string) + if m == None: + raise Exception("Invalid splice-index or statistics-config string: " + config_string) + self.output_stddev = (m.group(1) in ['mean+stddev', 'mean+stddev+count']) + self.output_log_counts = (m.group(1) in ['mean+count', 'mean+stddev+count']) + self.left_context = -int(m.group(2)) + self.input_period = int(m.group(3)) + self.stats_period = int(m.group(4)) + self.right_context = int(m.group(5)) + if not (self.left_context > 0 and self.right_context > 0 and + self.input_period > 0 and self.stats_period > 0 and + self.left_context % self.stats_period == 0 and + self.right_context % self.stats_period == 0 and + self.stats_period % self.input_period == 0): + raise Exception("Invalid configuration of statistics-extraction: " + config_string) + + # OutputDim() returns the output dimension of the node that this produces. 
+ def OutputDim(self): + return (self.input_dim * (2 if self.output_stddev else 1) + + 1 if self.output_log_counts else 0) + + # OutputDims() returns an array of output dimensions, consisting of + # [ input-dim ] if just "mean" was specified, otherwise + # [ input-dim input-dim ] + def OutputDims(self): + output_dims = [ self.input_dim ] + if self.output_stddev: + output_dims.append(self.input_dim) + if self.output_log_counts: + output_dims.append(1) + return output_dims + + # Descriptor() returns the textual form of the descriptor by which the + # output of this node is to be accessed. + def Descriptor(self, name): + return 'Round({0}-pooling-{1}-{2}, {3})'.format(name, self.left_context, self.right_context, + self.stats_period) + + def AddLayer(self, config_lines, name): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + components.append('component name={name}-extraction-{lc}-{rc} type=StatisticsExtractionComponent input-dim={dim} ' + 'input-period={input_period} output-period={output_period} include-variance={var} '.format( + name = name, lc = self.left_context, rc = self.right_context, + dim = self.input_dim, input_period = self.input_period, output_period = self.stats_period, + var = ('true' if self.output_stddev else 'false'))) + component_nodes.append('component-node name={name}-extraction-{lc}-{rc} component={name}-extraction-{lc}-{rc} input={input} '.format( + name = name, lc = self.left_context, rc = self.right_context, input = self.input_descriptor)) + stats_dim = 1 + self.input_dim * (2 if self.output_stddev else 1) + components.append('component name={name}-pooling-{lc}-{rc} type=StatisticsPoolingComponent input-dim={dim} ' + 'input-period={input_period} left-context={lc} right-context={rc} num-log-count-features={count} ' + 'output-stddevs={var} '.format(name = name, lc = self.left_context, rc = self.right_context, + dim = stats_dim, input_period = self.stats_period, + count = 1 if self.output_log_counts 
else 0, + var = ('true' if self.output_stddev else 'false'))) + component_nodes.append('component-node name={name}-pooling-{lc}-{rc} component={name}-pooling-{lc}-{rc} input={name}-extraction-{lc}-{rc} '.format( + name = name, lc = self.left_context, rc = self.right_context)) + + return { 'dimension': self.OutputDim(), + 'descriptor': self.Descriptor(name), + 'dimensions': self.OutputDims() + } diff --git a/egs/wsj/s5/steps/nnet3/compute_output.sh b/egs/wsj/s5/steps/nnet3/compute_output.sh new file mode 100755 index 00000000000..4c32b5cb0ea --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/compute_output.sh @@ -0,0 +1,185 @@ +#!/bin/bash + +# Copyright 2012-2015 Johns Hopkins University (Author: Daniel Povey). +# 2016 Vimal Manohar +# Apache 2.0. + +# This script does decoding with a neural-net. If the neural net was built on +# top of fMLLR transforms from a conventional system, you should provide the +# --transform-dir option. + +# Begin configuration section. +stage=1 +transform_dir= # dir to find fMLLR transforms. +nj=4 # number of jobs. If --transform-dir set, must match that number! +cmd=run.pl +use_gpu=false +frames_per_chunk=50 +ivector_scale=1.0 +iter=final +extra_left_context=0 +extra_right_context=0 +extra_left_context_initial=-1 +extra_right_context_final=-1 +frame_subsampling_factor=1 +feat_type= +compress=false +online_ivector_dir= +post_vec= +output_name= +use_raw_nnet=true +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: $0 [options] " + echo "e.g.: steps/nnet3/compute_output.sh --nj 8 \\" + echo "--online-ivector-dir exp/nnet3/ivectors_test_eval92 \\" + echo " data/test_eval92_hires exp/nnet3/tdnn exp/nnet3/tdnn/output" + echo "main options (for others, see top of script file)" + echo " --transform-dir # directory of previous decoding" + echo " # where we can find transforms for SAT systems." 
+ echo " --config # config containing options" + echo " --nj # number of parallel jobs" + echo " --cmd # Command to run in parallel with" + echo " --iter # Iteration of model to decode; default is final." + exit 1; +fi + +data=$1 +srcdir=$2 +dir=$3 + +if ! $use_raw_nnet; then + [ ! -f $srcdir/$iter.mdl ] && echo "$0: no such file $srcdir/$iter.mdl" && exit 1 + prog=nnet3-am-compute + model="$srcdir/$iter.mdl" +else + [ ! -f $srcdir/$iter.raw ] && echo "$0: no such file $srcdir/$iter.raw" && exit 1 + prog=nnet3-compute + model="nnet3-copy $srcdir/$iter.raw - |" +fi + +mkdir -p $dir/log +echo "rename-node old-name=$output_name new-name=output" > $dir/edits.config + +if [ ! -z "$output_name" ]; then + model="$model nnet3-copy --edits-config=$dir/edits.config - - |" +else + output_name=output +fi + +[ ! -z "$online_ivector_dir" ] && \ + extra_files="$online_ivector_dir/ivector_online.scp $online_ivector_dir/ivector_period" + +for f in $data/feats.scp $extra_files; do + [ ! -f $f ] && echo "$0: no such file $f" && exit 1; +done + +sdata=$data/split$nj; +cmvn_opts=`cat $srcdir/cmvn_opts` || exit 1; + +[[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1; +echo $nj > $dir/num_jobs + + +## Set up features. +if [ -z "$feat_type" ]; then + if [ -f $srcdir/final.mat ]; then feat_type=lda; else feat_type=raw; fi + echo "$0: feature type is $feat_type" +fi + +splice_opts=`cat $srcdir/splice_opts 2>/dev/null` + +case $feat_type in + raw) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- |";; + lda) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $srcdir/final.mat ark:- ark:- |" + ;; + *) echo "$0: invalid feature type $feat_type" && exit 1; +esac +if [ ! -z "$transform_dir" ]; then + echo "$0: using transforms from $transform_dir" + [ ! 
-s $transform_dir/num_jobs ] && \ + echo "$0: expected $transform_dir/num_jobs to contain the number of jobs." && exit 1; + nj_orig=$(cat $transform_dir/num_jobs) + + if [ $feat_type == "raw" ]; then trans=raw_trans; + else trans=trans; fi + if [ $feat_type == "lda" ] && \ + ! cmp $transform_dir/../final.mat $srcdir/final.mat && \ + ! cmp $transform_dir/final.mat $srcdir/final.mat; then + echo "$0: LDA transforms differ between $srcdir and $transform_dir" + exit 1; + fi + if [ ! -f $transform_dir/$trans.1 ]; then + echo "$0: expected $transform_dir/$trans.1 to exist (--transform-dir option)" + exit 1; + fi + if [ $nj -ne $nj_orig ]; then + # Copy the transforms into an archive with an index. + for n in $(seq $nj_orig); do cat $transform_dir/$trans.$n; done | \ + copy-feats ark:- ark,scp:$dir/$trans.ark,$dir/$trans.scp || exit 1; + feats="$feats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk scp:$dir/$trans.scp ark:- ark:- |" + else + feats="$feats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark:$transform_dir/$trans.JOB ark:- ark:- |" + fi +elif grep 'transform-feats --utt2spk' $srcdir/log/train.1.log >&/dev/null; then + echo "$0: **WARNING**: you seem to be using a neural net system trained with transforms," + echo " but you are not providing the --transform-dir option in test time." +fi +## + +if [ ! -z "$online_ivector_dir" ]; then + ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1; + ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period" +fi + +frame_subsampling_opt= +if [ $frame_subsampling_factor -ne 1 ]; then + # e.g. for 'chain' systems + frame_subsampling_opt="--frame-subsampling-factor=$frame_subsampling_factor" +fi + +if ! $use_raw_nnet; then + output_wspecifier="ark:| copy-feats --compress=$compress ark:- ark:- | gzip -c > $dir/log_likes.JOB.gz" +else + output_wspecifier="ark:| copy-feats --compress=$compress ark:- ark:- | gzip -c > $dir/nnet_output.JOB.gz" + + if [ ! 
-z $post_vec ]; then + if [ $stage -le 1 ]; then + copy-vector --binary=false $post_vec - | \ + awk '{for (i = 2; i < NF; i++) { sum += i; }; + printf ("["); + for (i = 2; i < NF; i++) { printf " "log(i/sum); }; + print (" ]");}' > $dir/log_priors.vec + fi + + output_wspecifier="ark:| matrix-add-offset ark:- 'vector-scale --scale=-1.0 $dir/log_priors.vec - |' ark:- | copy-feats --compress=$compress ark:- ark:- | gzip -c > $dir/log_likes.JOB.gz" + fi +fi + +gpu_opt="--use-gpu=no" +gpu_queue_opt= + +if $use_gpu; then + gpu_queue_opt="--gpu 1" + gpu_opt="--use-gpu=yes" +fi + +if [ $stage -le 2 ]; then + $cmd $gpu_queue_opt JOB=1:$nj $dir/log/compute_output.JOB.log \ + $prog $gpu_opt $ivector_opts $frame_subsampling_opt \ + --frames-per-chunk=$frames_per_chunk \ + --extra-left-context=$extra_left_context \ + --extra-right-context=$extra_right_context \ + --extra-left-context-initial=$extra_left_context_initial \ + --extra-right-context-final=$extra_right_context_final \ + "$model" "$feats" "$output_wspecifier" || exit 1; +fi + +exit 0; + diff --git a/egs/wsj/s5/steps/nnet3/get_egs.sh b/egs/wsj/s5/steps/nnet3/get_egs.sh index d72a3d23fe5..4f8b692488a 100755 --- a/egs/wsj/s5/steps/nnet3/get_egs.sh +++ b/egs/wsj/s5/steps/nnet3/get_egs.sh @@ -12,6 +12,8 @@ # right, and this ends up getting shared. This is at the expense of slightly # higher disk I/O while training. +set -o pipefail +trap "" PIPE # Begin configuration section. cmd=run.pl diff --git a/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py b/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py new file mode 100755 index 00000000000..30449c81e81 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py @@ -0,0 +1,999 @@ +#!/usr/bin/env python + +# Copyright 2016 Vijayaditya Peddinti +# 2016 Vimal Manohar +# Apache 2.0. 
+ +from __future__ import print_function +import os +import argparse +import sys +import logging +import shlex +import random +import math +import glob + +sys.path.insert(0, 'steps') +import libs.data as data_lib +import libs.common as common_lib + +logger = logging.getLogger('libs') +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ] %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) +logger.info('Getting egs for training') + + +def get_args(): + # we add compulsary arguments as named arguments for readability + parser = argparse.ArgumentParser( + description="""Generates training examples used to train the 'nnet3' + network (and also the validation examples used for diagnostics), + and puts them in separate archives.""", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("--cmd", type=str, default="run.pl", + help="Specifies the script to launch jobs." + " e.g. 
queue.pl for launching on SGE cluster run.pl" + " for launching on local machine") + # feat options + parser.add_argument("--feat.dir", type=str, dest='feat_dir', required=True, + help="Directory with features used for training " + "the neural network.") + parser.add_argument("--feat.online-ivector-dir", type=str, + dest='online_ivector_dir', + default=None, action=common_lib.NullstrToNoneAction, + help="directory with the ivectors extracted in an " + "online fashion.") + parser.add_argument("--feat.cmvn-opts", type=str, dest='cmvn_opts', + default=None, action=common_lib.NullstrToNoneAction, + help="A string specifying '--norm-means' and " + "'--norm-vars' values") + parser.add_argument("--feat.apply-cmvn-sliding", type=str, + dest='apply_cmvn_sliding', + default=False, action=common_lib.StrToBoolAction, + help="Apply CMVN sliding, instead of per-utteance " + "or speakers") + + # egs extraction options + parser.add_argument("--frames-per-eg", type=int, default=8, + help="""Number of frames of labels per example. + more->less disk space and less time preparing egs, but + more I/O during training. + note: the script may reduce this if + reduce-frames-per-eg is true.""") + parser.add_argument("--left-context", type=int, default=4, + help="""Amount of left-context per eg (i.e. 
extra + frames of input features not present in the output + supervision).""") + parser.add_argument("--right-context", type=int, default=4, + help="Amount of right-context per eg") + parser.add_argument("--valid-left-context", type=int, default=None, + help="""Amount of left-context for validation egs, + typically used in recurrent architectures to ensure + matched condition with training egs""") + parser.add_argument("--valid-right-context", type=int, default=None, + help="""Amount of right-context for validation egs, + typically used in recurrent architectures to ensure + matched condition with training egs""") + parser.add_argument("--compress-input", type=str, default=True, + action=common_lib.StrToBoolAction, + choices=["true", "false"], + help="If false, disables compression. Might be " + "necessary to check if results will be affected.") + parser.add_argument("--input-compress-format", type=int, default=0, + help="Format used for compressing the input features") + + parser.add_argument("--reduce-frames-per-eg", type=str, default=True, + action=common_lib.StrToBoolAction, + choices=["true", "false"], + help="""If true, this script may reduce the + frames-per-eg if there is only one archive and even + with the reduced frames-per-eg, the number of + samples-per-iter that would result is less than or + equal to the user-specified value.""") + + parser.add_argument("--num-utts-subset", type=int, default=300, + help="Number of utterances in validation and training" + " subsets used for shrinkage and diagnostics") + parser.add_argument("--num-utts-subset-valid", type=int, + help="Number of utterances in validation" + " subset used for diagnostics") + parser.add_argument("--num-utts-subset-train", type=int, + help="Number of utterances in training" + " subset used for shrinkage and diagnostics") + parser.add_argument("--num-train-egs-combine", type=int, default=10000, + help="Training examples for combination weights at the" + " very end.") + 
parser.add_argument("--num-valid-egs-combine", type=int, default=0, + help="Validation examples for combination weights at " + "the very end.") + parser.add_argument("--num-egs-diagnostic", type=int, default=4000, + help="Numer of frames for 'compute-probs' jobs") + + parser.add_argument("--samples-per-iter", type=int, default=400000, + help="""This is the target number of egs in each + archive of egs (prior to merging egs). We probably + should have called it egs_per_iter. This is just a + guideline; it will pick a number that divides the + number of samples in the entire data.""") + + parser.add_argument("--stage", type=int, default=0, + help="Stage to start running script from") + parser.add_argument("--num-jobs", type=int, default=6, + help="""This should be set to the maximum number of + jobs you are comfortable to run in parallel; you can + increase it if your disk speed is greater and you have + more machines.""") + parser.add_argument("--srand", type=int, default=0, + help="Rand seed for nnet3-copy-egs and " + "nnet3-shuffle-egs") + parser.add_argument("--generate-egs-scp", type=str, + default=False, action=common_lib.StrToBoolAction, + help="Generate scp files in addition to archives") + + parser.add_argument("--targets-parameters", type=str, action='append', + required=True, dest='targets_para_array', + help="""Parameters for targets. Each set of parameters + corresponds to a separate output node of the neural + network. The targets can be sparse or dense. + The parameters used are: + --targets-rspecifier= + # rspecifier for the targets, can be alignment or + # matrix. + --num-targets= + # targets dimension. required for sparse feats. 
+ --target-type=""") + + parser.add_argument("--dir", type=str, required=True, + help="Directory to store the examples") + + print(' '.join(sys.argv)) + print(sys.argv) + + args = parser.parse_args() + + args = process_args(args) + + return args + + +def process_args(args): + # process the options + if args.num_utts_subset_valid is None: + args.num_utts_subset_valid = args.num_utts_subset + + if args.num_utts_subset_train is None: + args.num_utts_subset_train = args.num_utts_subset + + if args.valid_left_context is None: + args.valid_left_context = args.left_context + if args.valid_right_context is None: + args.valid_right_context = args.right_context + + if (args.left_context < 0 or args.right_context < 0 + or args.valid_left_context < 0 or args.valid_right_context < 0): + raise Exception( + "--{,valid-}{left,right}-context should be non-negative") + + return args + + +def check_for_required_files(feat_dir, targets_scps, online_ivector_dir=None): + required_files = ['{0}/feats.scp'.format(feat_dir), + '{0}/cmvn.scp'.format(feat_dir)] + if online_ivector_dir is not None: + required_files.append('{0}/ivector_online.scp'.format( + online_ivector_dir)) + required_files.append('{0}/ivector_period'.format( + online_ivector_dir)) + required_files.extend(targets_scps) + + for file in required_files: + if not os.path.isfile(file): + raise Exception('Expected {0} to exist.'.format(file)) + + +def parse_targets_parameters_array(para_array): + targets_parser = argparse.ArgumentParser() + targets_parser.add_argument("--output-name", type=str, required=True, + help="Name of the output. e.g. 
output-xent") + targets_parser.add_argument("--dim", type=int, default=-1, + help="Target dimension (required for sparse " + "targets") + targets_parser.add_argument("--target-type", type=str, default="dense", + choices=["dense", "sparse"], + help="Dense for matrix format") + targets_parser.add_argument("--targets-scp", type=str, required=True, + help="Scp file of targets; can be posteriors " + "or matrices") + targets_parser.add_argument("--compress", type=str, default=True, + action=common_lib.StrToBoolAction, + help="Specifies whether the output must be " + "compressed") + targets_parser.add_argument("--compress-format", type=int, default=0, + help="Format for compressing target") + targets_parser.add_argument("--deriv-weights-scp", type=str, default="", + help="Per-frame deriv weights for this output") + targets_parser.add_argument("--scp2ark-cmd", type=str, default="", + help="""The command that is used to convert + targets scp to archive. e.g. An scp of + alignments can be converted to posteriors using + ali-to-post""") + + targets_parameters = [targets_parser.parse_args(shlex.split(x)) + for x in para_array] + + for t in targets_parameters: + if not os.path.isfile(t.targets_scp): + raise Exception("Expected {0} to exist.".format(t.targets_scp)) + + if t.target_type == "dense": + dim = common_lib.get_feat_dim_from_scp(t.targets_scp) + if t.dim != -1 and t.dim != dim: + raise Exception('Mismatch in --dim provided and feat dim for ' + 'file {0}; {1} vs {2}'.format(t.targets_scp, + t.dim, dim)) + t.dim = -dim + + return targets_parameters + + +def sample_utts(feat_dir, num_utts_subset, min_duration, exclude_list=None): + utt2durs_dict = data_lib.get_utt2dur(feat_dir) + utt2durs = utt2durs_dict.items() + utt2uniq, uniq2utt = data_lib.get_utt2uniq(feat_dir) + if num_utts_subset is None: + num_utts_subset = len(utt2durs) + if exclude_list is not None: + num_utts_subset = num_utts_subset - len(exclude_list) + + random.shuffle(utt2durs) + sampled_utts = [] + + index 
= 0 + num_trials = 0 + while (len(sampled_utts) < num_utts_subset + and num_trials <= len(utt2durs)): + if utt2durs[index][-1] >= min_duration: + if utt2uniq is not None: + uniq_id = utt2uniq[utt2durs[index][0]] + utts2add = uniq2utt[uniq_id] + else: + utts2add = [utt2durs[index][0]] + exclude_utt = False + if exclude_list is not None: + for utt in utts2add: + if utt in exclude_list: + exclude_utt = True + break + if not exclude_utt: + for utt in utts2add: + sampled_utts.append(utt) + + else: + logger.info("Skipping utterance %s of length %f", + utt2uniq[utt2durs[index][0]], utt2durs[index][1]) + index = index + 1 + num_trials = num_trials + 1 + if exclude_list is not None: + assert(len(set(exclude_list).intersection(sampled_utts)) == 0) + if len(sampled_utts) < num_utts_subset: + raise Exception( + """Number of utterances which have duration of at least {md} + seconds is really low (required={rl}, available={al}). Please + check your data.""".format( + md=min_duration, al=len(sampled_utts), rl=num_utts_subset)) + + sampled_utts_durs = [] + for utt in sampled_utts: + sampled_utts_durs.append([utt, utt2durs_dict[utt]]) + return sampled_utts, sampled_utts_durs + + +def write_list(listd, file_name): + file_handle = open(file_name, 'w') + assert(type(listd) == list) + for item in listd: + file_handle.write(str(item)+"\n") + file_handle.close() + + +def get_max_open_files(): + stdout, stderr = common_lib.run_kaldi_command("ulimit -n") + return int(stdout) + + +def get_feat_ivector_strings(dir, feat_dir, split_feat_dir, + cmvn_opt_string, ivector_dir=None, + apply_cmvn_sliding=False): + + if not apply_cmvn_sliding: + train_feats = ("ark,s,cs:utils/filter_scp.pl --exclude " + "{dir}/valid_uttlist {sdir}/JOB/feats.scp | " + "apply-cmvn {cmvn} --utt2spk=ark:{sdir}/JOB/utt2spk " + "scp:{sdir}/JOB/cmvn.scp scp:- ark:- |".format( + dir=dir, sdir=split_feat_dir, + cmvn=cmvn_opt_string)) + valid_feats = ("ark,s,cs:utils/filter_scp.pl {dir}/valid_uttlist " + "{fdir}/feats.scp | " 
+ "apply-cmvn {cmvn} --utt2spk=ark:{fdir}/utt2spk " + "scp:{fdir}/cmvn.scp scp:- ark:- |".format( + dir=dir, fdir=feat_dir, cmvn=cmvn_opt_string)) + train_subset_feats = ("ark,s,cs:utils/filter_scp.pl " + "{dir}/train_subset_uttlist {fdir}/feats.scp | " + "apply-cmvn {cmvn} --utt2spk=ark:{fdir}/utt2spk " + "scp:{fdir}/cmvn.scp scp:- ark:- |".format( + dir=dir, fdir=feat_dir, + cmvn=cmvn_opt_string)) + + def feats_subset_func(subset_list): + return ("ark,s,cs:utils/filter_scp.pl {subset_list} " + "{fdir}/feats.scp | " + "apply-cmvn {cmvn} --utt2spk=ark:{fdir}/utt2spk " + "scp:{fdir}/cmvn.scp scp:- ark:- |".format( + dir=dir, subset_list=subset_list, + fdir=feat_dir, cmvn=cmvn_opt_string)) + + else: + train_feats = ("ark,s,cs:utils/filter_scp.pl --exclude " + "{dir}/valid_uttlist {sdir}/JOB/feats.scp | " + "apply-cmvn-sliding scp:{sdir}/JOB/cmvn.scp scp:- " + "ark:- |".format(dir=dir, sdir=split_feat_dir, + cmvn=cmvn_opt_string)) + + def feats_subset_func(subset_list): + return ("ark,s,cs:utils/filter_scp.pl {subset_list} " + "{fdir}/feats.scp | " + "apply-cmvn-sliding {cmvn} scp:{fdir}/cmvn.scp scp:- " + "ark:- |".format(dir=dir, subset_list=subset_list, + fdir=feat_dir, cmvn=cmvn_opt_string)) + + train_subset_feats = feats_subset_func( + "{0}/train_subset_uttlist".format(dir)) + valid_feats = feats_subset_func("{0}/valid_uttlist".format(dir)) + + if ivector_dir is not None: + ivector_period = common_lib.GetIvectorPeriod(ivector_dir) + ivector_opt = ("--ivectors='ark,s,cs:utils/filter_scp.pl " + "{sdir}/JOB/utt2spk {idir}/ivector_online.scp | " + "subsample-feats --n=-{period} scp:- ark:- |'".format( + sdir=split_feat_dir, idir=ivector_dir, + period=ivector_period)) + valid_ivector_opt = ("--ivectors='ark,s,cs:utils/filter_scp.pl " + "{dir}/valid_uttlist {idir}/ivector_online.scp | " + "subsample-feats --n=-{period} " + "scp:- ark:- |'".format( + dir=dir, idir=ivector_dir, + period=ivector_period)) + train_subset_ivector_opt = ( + 
"--ivectors='ark,s,cs:utils/filter_scp.pl " + "{dir}/train_subset_uttlist {idir}/ivector_online.scp | " + "subsample-feats --n=-{period} scp:- ark:- |'".format( + dir=dir, idir=ivector_dir, period=ivector_period)) + else: + ivector_opt = '' + valid_ivector_opt = '' + train_subset_ivector_opt = '' + + return {'train_feats': train_feats, + 'valid_feats': valid_feats, + 'train_subset_feats': train_subset_feats, + 'feats_subset_func': feats_subset_func, + 'ivector_opts': ivector_opt, + 'valid_ivector_opts': valid_ivector_opt, + 'train_subset_ivector_opts': train_subset_ivector_opt, + 'feat_dim': common_lib.get_feat_dim(feat_dir), + 'ivector_dim': common_lib.get_ivector_dim(ivector_dir)} + + +def get_egs_options(targets_parameters, frames_per_eg, + left_context, right_context, + valid_left_context, valid_right_context, + compress_input, + input_compress_format=0, length_tolerance=0): + + train_egs_opts = [] + train_egs_opts.append("--left-context={0}".format(left_context)) + train_egs_opts.append("--right-context={0}".format(right_context)) + train_egs_opts.append("--num-frames={0}".format(frames_per_eg)) + train_egs_opts.append("--compress-input={0}".format(compress_input)) + train_egs_opts.append("--input-compress-format={0}".format( + input_compress_format)) + train_egs_opts.append("--compress-targets={0}".format( + ':'.join(["true" if t.compress else "false" + for t in targets_parameters]))) + train_egs_opts.append("--targets-compress-formats={0}".format( + ':'.join([str(t.compress_format) + for t in targets_parameters]))) + train_egs_opts.append("--length-tolerance={0}".format(length_tolerance)) + train_egs_opts.append("--output-names={0}".format( + ':'.join([t.output_name + for t in targets_parameters]))) + train_egs_opts.append("--output-dims={0}".format( + ':'.join([str(t.dim) + for t in targets_parameters]))) + + valid_egs_opts = ( + "--left-context={vlc} --right-context={vrc} " + "--num-frames={n} --compress-input={comp} " + "--input-compress-format={icf} 
--compress-targets={ct} " + "--targets-compress-formats={tcf} --length-tolerance={tol} " + "--output-names={names} --output-dims={dims}".format( + vlc=valid_left_context, vrc=valid_right_context, n=frames_per_eg, + comp=compress_input, icf=input_compress_format, + ct=':'.join(["true" if t.compress else "false" + for t in targets_parameters]), + tcf=':'.join([str(t.compress_format) + for t in targets_parameters]), + tol=length_tolerance, + names=':'.join([t.output_name + for t in targets_parameters]), + dims=':'.join([str(t.dim) for t in targets_parameters]))) + + return {'train_egs_opts': " ".join(train_egs_opts), + 'valid_egs_opts': valid_egs_opts} + + +def get_targets_list(targets_parameters, subset_list): + targets_list = [] + for t in targets_parameters: + rspecifier = "ark,s,cs:" if t.scp2ark_cmd != "" else "scp,s,cs:" + rspecifier += get_subset_rspecifier(t.targets_scp, subset_list) + rspecifier += t.scp2ark_cmd + deriv_weights_rspecifier = "" + if t.deriv_weights_scp != "": + deriv_weights_rspecifier = "scp,s,cs:{0}".format( + get_subset_rspecifier(t.deriv_weights_scp, subset_list)) + this_targets = '''"{rspecifier}" "{dw}"'''.format( + rspecifier=rspecifier, dw=deriv_weights_rspecifier) + + targets_list.append(this_targets) + return " ".join(targets_list) + + +def get_subset_rspecifier(scp_file, subset_list): + if scp_file == "": + return "" + return "utils/filter_scp.pl {subset} {scp} |".format(subset=subset_list, + scp=scp_file) + + +def split_scp(scp_file, num_jobs): + out_scps = ["{0}.{1}".format(scp_file, n) for n in range(1, num_jobs + 1)] + common_lib.run_kaldi_command("utils/split_scp.pl {scp} {oscps}".format( + scp=scp_file, + oscps=' '.join(out_scps))) + return out_scps + + +def generate_valid_train_subset_egs(dir, targets_parameters, + feat_ivector_strings, egs_opts, + num_train_egs_combine, + num_valid_egs_combine, + num_egs_diagnostic, cmd, + num_jobs=1, + generate_egs_scp=False): + + if generate_egs_scp: + valid_combine_output = 
("ark,scp:{0}/valid_combine.egs," + "{0}/valid_combine.egs.scp".format(dir)) + valid_diagnostic_output = ("ark,scp:{0}/valid_diagnostic.egs," + "{0}/valid_diagnostic.egs.scp".format(dir)) + train_combine_output = ("ark,scp:{0}/train_combine.egs," + "{0}/train_combine.egs.scp".format(dir)) + train_diagnostic_output = ("ark,scp:{0}/train_diagnostic.egs," + "{0}/train_diagnostic.egs.scp".format(dir)) + else: + valid_combine_output = "ark:{0}/valid_combine.egs".format(dir) + valid_diagnostic_output = "ark:{0}/valid_diagnostic.egs".format(dir) + train_combine_output = "ark:{0}/train_combine.egs".format(dir) + train_diagnostic_output = "ark:{0}/train_diagnostic.egs".format(dir) + + wait_pids = [] + + logger.info("Creating validation and train subset examples.") + + split_scp('{0}/valid_uttlist'.format(dir), num_jobs) + split_scp('{0}/train_subset_uttlist'.format(dir), num_jobs) + + valid_pid = common_lib.run_kaldi_command( + """{cmd} JOB=1:{nj} {dir}/log/create_valid_subset.JOB.log \ + nnet3-get-egs-multiple-targets {v_iv_opt} {v_egs_opt} "{v_feats}" \ + {targets} ark,scp:{dir}/valid_all.JOB.egs,""" + """{dir}/valid_all.JOB.egs.scp""".format( + cmd=cmd, nj=num_jobs, dir=dir, + v_egs_opt=egs_opts['valid_egs_opts'], + v_iv_opt=feat_ivector_strings['valid_ivector_opts'], + v_feats=feat_ivector_strings['feats_subset_func']( + '{dir}/valid_uttlist.JOB'.format(dir=dir)), + targets=get_targets_list( + targets_parameters, + '{dir}/valid_uttlist.JOB'.format(dir=dir))), + wait=False) + + train_pid = common_lib.run_kaldi_command( + """{cmd} JOB=1:{nj} {dir}/log/create_train_subset.JOB.log \ + nnet3-get-egs-multiple-targets {t_iv_opt} {v_egs_opt} "{t_feats}" \ + {targets} ark,scp:{dir}/train_subset_all.JOB.egs,""" + """{dir}/train_subset_all.JOB.egs.scp""".format( + cmd=cmd, nj=num_jobs, dir=dir, + v_egs_opt=egs_opts['valid_egs_opts'], + t_iv_opt=feat_ivector_strings['train_subset_ivector_opts'], + t_feats=feat_ivector_strings['feats_subset_func']( + 
'{dir}/train_subset_uttlist.JOB'.format(dir=dir)), + targets=get_targets_list( + targets_parameters, + '{dir}/train_subset_uttlist.JOB'.format(dir=dir))), + wait=False) + + wait_pids.append(valid_pid) + wait_pids.append(train_pid) + + for pid in wait_pids: + stdout, stderr = pid.communicate() + if pid.returncode != 0: + raise Exception(stderr) + + valid_egs_all = ' '.join( + ['{dir}/valid_all.{n}.egs.scp'.format(dir=dir, n=n) + for n in range(1, num_jobs + 1)]) + train_subset_egs_all = ' '.join( + ['{dir}/train_subset_all.{n}.egs.scp'.format(dir=dir, n=n) + for n in range(1, num_jobs + 1)]) + + wait_pids = [] + logger.info("... Getting subsets of validation examples for diagnostics " + " and combination.") + pid = common_lib.run_kaldi_command( + """{cmd} {dir}/log/create_valid_subset_combine.log \ + cat {valid_egs_all} \| nnet3-subset-egs --n={nve_combine} \ + scp:- {valid_combine_output}""".format( + cmd=cmd, dir=dir, valid_egs_all=valid_egs_all, + nve_combine=num_valid_egs_combine, + valid_combine_output=valid_combine_output), + wait=False) + wait_pids.append(pid) + + pid = common_lib.run_kaldi_command( + """{cmd} {dir}/log/create_valid_subset_diagnostic.log \ + cat {valid_egs_all} \| nnet3-subset-egs --n={ne_diagnostic} \ + scp:- {valid_diagnostic_output}""".format( + cmd=cmd, dir=dir, valid_egs_all=valid_egs_all, + ne_diagnostic=num_egs_diagnostic, + valid_diagnostic_output=valid_diagnostic_output), + wait=False) + wait_pids.append(pid) + + pid = common_lib.run_kaldi_command( + """{cmd} {dir}/log/create_train_subset_combine.log \ + cat {train_subset_egs_all} \| \ + nnet3-subset-egs --n={nte_combine} \ + scp:- {train_combine_output}""".format( + cmd=cmd, dir=dir, train_subset_egs_all=train_subset_egs_all, + nte_combine=num_train_egs_combine, + train_combine_output=train_combine_output), + wait=False) + wait_pids.append(pid) + + pid = common_lib.run_kaldi_command( + """{cmd} {dir}/log/create_train_subset_diagnostic.log \ + cat {train_subset_egs_all} \| \ + 
nnet3-subset-egs --n={ne_diagnostic} \ + scp:- {train_diagnostic_output}""".format( + cmd=cmd, dir=dir, train_subset_egs_all=train_subset_egs_all, + ne_diagnostic=num_egs_diagnostic, + train_diagnostic_output=train_diagnostic_output), + wait=False) + wait_pids.append(pid) + + for pid in wait_pids: + stdout, stderr = pid.communicate() + if pid.returncode != 0: + raise Exception(stderr) + + common_lib.run_kaldi_command( + """cat {dir}/valid_combine.egs {dir}/train_combine.egs > \ + {dir}/combine.egs""".format(dir=dir)) + + if generate_egs_scp: + common_lib.run_kaldi_command( + """cat {dir}/valid_combine.egs.scp {dir}/train_combine.egs.scp > \ + {dir}/combine.egs.scp""".format(dir=dir)) + common_lib.run_kaldi_command( + "rm {dir}/valid_combine.egs.scp {dir}/train_combine.egs.scp" + "".format(dir=dir)) + + # perform checks + for file_name in ('{0}/combine.egs {0}/train_diagnostic.egs ' + '{0}/valid_diagnostic.egs'.format(dir).split()): + if os.path.getsize(file_name) == 0: + raise Exception("No examples in {0}".format(file_name)) + + # clean-up + for x in ('{0}/valid_all.*.egs {0}/train_subset_all.*.egs ' + '{0}/valid_all.*.egs.scp {0}/train_subset_all.*.egs.scp ' + '{0}/train_combine.egs ' + '{0}/valid_combine.egs'.format(dir).split()): + for file_name in glob.glob(x): + os.remove(file_name) + + +def generate_training_examples_internal(dir, targets_parameters, feat_dir, + train_feats_string, + train_egs_opts_string, + ivector_opts, + num_jobs, frames_per_eg, + samples_per_iter, cmd, srand=0, + reduce_frames_per_eg=True, + only_shuffle=False, + dry_run=False, + generate_egs_scp=False): + + # The examples will go round-robin to egs_list. Note: we omit the + # 'normalization.fst' argument while creating temporary egs: the phase of + # egs preparation that involves the normalization FST is quite + # CPU-intensive and it's more convenient to do it later, in the 'shuffle' + # stage. 
Otherwise to make it efficient we need to use a large 'nj', like + # 40, and in that case there can be too many small files to deal with, + # because the total number of files is the product of 'nj' by + # 'num_archives_intermediate', which might be quite large. + num_frames = data_lib.get_num_frames(feat_dir) + num_archives = (num_frames) / (frames_per_eg * samples_per_iter) + 1 + + reduced = False + while (reduce_frames_per_eg and frames_per_eg > 1 + and num_frames / ((frames_per_eg-1)*samples_per_iter) == 0): + frames_per_eg -= 1 + num_archives = 1 + reduced = True + + if reduced: + logger.info("Reduced frames-per-eg to {0} " + "because amount of data is small".format(frames_per_eg)) + + max_open_files = get_max_open_files() + num_archives_intermediate = num_archives + archives_multiple = 1 + while (num_archives_intermediate+4) > max_open_files: + archives_multiple = archives_multiple + 1 + num_archives_intermediate = int(math.ceil(float(num_archives) + / archives_multiple)) + num_archives = num_archives_intermediate * archives_multiple + egs_per_archive = num_frames/(frames_per_eg * num_archives) + + if egs_per_archive > samples_per_iter: + raise Exception( + """egs_per_archive({epa}) > samples_per_iter({fpi}). 
+ This is an error in the logic for determining + egs_per_archive""".format(epa=egs_per_archive, + fpi=samples_per_iter)) + + if dry_run: + if generate_egs_scp: + for i in range(1, num_archives_intermediate + 1): + for j in range(1, archives_multiple + 1): + archive_index = (i-1) * archives_multiple + j + common_lib.force_symlink( + "egs.{0}.ark".format(archive_index), + "{dir}/egs.{i}.{j}.ark".format(dir=dir, i=i, j=j)) + cleanup(dir, archives_multiple, generate_egs_scp) + return {'num_frames': num_frames, + 'num_archives': num_archives, + 'egs_per_archive': egs_per_archive} + + logger.info("Splitting a total of {nf} frames into {na} archives, " + "each with {epa} egs.".format(nf=num_frames, na=num_archives, + epa=egs_per_archive)) + + if os.path.isdir('{0}/storage'.format(dir)): + # this is a striped directory, so create the softlinks + data_lib.create_data_links(["{dir}/egs.{x}.ark".format(dir=dir, x=x) + for x in range(1, num_archives + 1)]) + for x in range(1, num_archives_intermediate + 1): + data_lib.create_data_links( + ["{dir}/egs_orig.{y}.{x}.ark".format(dir=dir, x=x, y=y) + for y in range(1, num_jobs + 1)]) + + split_feat_dir = "{0}/split{1}".format(feat_dir, num_jobs) + egs_list = ' '.join( + ['ark:{dir}/egs_orig.JOB.{ark_num}.ark'.format(dir=dir, ark_num=x) + for x in range(1, num_archives_intermediate + 1)]) + + if not only_shuffle: + common_lib.run_kaldi_command( + """{cmd} JOB=1:{nj} {dir}/log/get_egs.JOB.log \ + nnet3-get-egs-multiple-targets {iv_opts} {egs_opts} \ + "{feats}" {targets} ark:- \| \ + nnet3-copy-egs --random=true --srand=$[JOB+{srand}] \ + ark:- {egs_list}""".format( + cmd=cmd, nj=num_jobs, dir=dir, srand=srand, + iv_opts=ivector_opts, egs_opts=train_egs_opts_string, + feats=train_feats_string, + targets=get_targets_list(targets_parameters, + '{sdir}/JOB/utt2spk'.format( + sdir=split_feat_dir)), + egs_list=egs_list)) + + logger.info("Recombining and shuffling order of archives on disk") + egs_list = ' 
'.join(['{dir}/egs_orig.{n}.JOB.ark'.format(dir=dir, n=x) + for x in range(1, num_jobs + 1)]) + + if archives_multiple == 1: + # there are no intermediate archives so just shuffle egs across + # jobs and dump them into a single output + + if generate_egs_scp: + output_archive = ("ark,scp:{dir}/egs.JOB.ark," + "{dir}/egs.JOB.scp".format(dir=dir)) + else: + output_archive = "ark:{dir}/egs.JOB.ark".format(dir=dir) + + common_lib.run_kaldi_command( + """{cmd} --max-jobs-run {msjr} JOB=1:{nai} \ + {dir}/log/shuffle.JOB.log \ + nnet3-shuffle-egs --srand=$[JOB+{srand}] \ + "ark:cat {egs_list}|" {output_archive}""".format( + cmd=cmd, msjr=num_jobs, + nai=num_archives_intermediate, srand=srand, + dir=dir, egs_list=egs_list, + output_archive=output_archive)) + + if generate_egs_scp: + out_egs_handle = open("{0}/egs.scp".format(dir), 'w') + for i in range(1, num_archives_intermediate + 1): + for line in open("{0}/egs.{1}.scp".format(dir, i)): + print (line.strip(), file=out_egs_handle) + out_egs_handle.close() + else: + # there are intermediate archives so we shuffle egs across jobs + # and split them into archives_multiple output archives + if generate_egs_scp: + output_archives = ' '.join( + ["ark,scp:{dir}/egs.JOB.{ark_num}.ark," + "{dir}/egs.JOB.{ark_num}.scp".format( + dir=dir, ark_num=x) + for x in range(1, archives_multiple + 1)]) + else: + output_archives = ' '.join( + ["ark:{dir}/egs.JOB.{ark_num}.ark".format( + dir=dir, ark_num=x) + for x in range(1, archives_multiple + 1)]) + # archives were created as egs.x.y.ark + # linking them to egs.i.ark format which is expected by the training + # scripts + for i in range(1, num_archives_intermediate + 1): + for j in range(1, archives_multiple + 1): + archive_index = (i-1) * archives_multiple + j + common_lib.force_symlink( + "egs.{0}.ark".format(archive_index), + "{dir}/egs.{i}.{j}.ark".format(dir=dir, i=i, j=j)) + + common_lib.run_kaldi_command( + """{cmd} --max-jobs-run {msjr} JOB=1:{nai} \ + {dir}/log/shuffle.JOB.log \ + 
nnet3-shuffle-egs --srand=$[JOB+{srand}] \ + "ark:cat {egs_list}|" ark:- \| \ + nnet3-copy-egs ark:- {oarks}""".format( + cmd=cmd, msjr=num_jobs, + nai=num_archives_intermediate, srand=srand, + dir=dir, egs_list=egs_list, oarks=output_archives)) + + if generate_egs_scp: + out_egs_handle = open("{0}/egs.scp".format(dir), 'w') + for i in range(1, num_archives_intermediate + 1): + for j in range(1, archives_multiple + 1): + for line in open("{0}/egs.{1}.{2}.scp".format(dir, i, j)): + print (line.strip(), file=out_egs_handle) + out_egs_handle.close() + + cleanup(dir, archives_multiple, generate_egs_scp) + return {'num_frames': num_frames, + 'num_archives': num_archives, + 'egs_per_archive': egs_per_archive} + + +def cleanup(dir, archives_multiple, generate_egs_scp=False): + logger.info("Removing temporary archives in {0}.".format(dir)) + for file_name in glob.glob("{0}/egs_orig*".format(dir)): + real_path = os.path.realpath(file_name) + data_lib.try_to_delete(real_path) + data_lib.try_to_delete(file_name) + + if archives_multiple > 1 and not generate_egs_scp: + # there will be some extra soft links we want to delete + for file_name in glob.glob('{0}/egs.*.*.ark'.format(dir)): + os.remove(file_name) + + +def create_directory(dir): + import errno + try: + os.makedirs(dir) + except OSError, e: + if e.errno == errno.EEXIST: + pass + + +def generate_training_examples(dir, targets_parameters, feat_dir, + feat_ivector_strings, egs_opts, + frame_shift, frames_per_eg, samples_per_iter, + cmd, num_jobs, srand=0, + only_shuffle=False, dry_run=False, + generate_egs_scp=False): + + # generate the training options string with the given chunk_width + train_egs_opts = egs_opts['train_egs_opts'] + # generate the feature vector string with the utt list for the + # current chunk width + train_feats = feat_ivector_strings['train_feats'] + + if os.path.isdir('{0}/storage'.format(dir)): + real_paths = [os.path.realpath(x).strip("/") + for x in glob.glob('{0}/storage/*'.format(dir))] + 
common_lib.run_kaldi_command( + """utils/create_split_dir.pl {target_dirs} \ + {dir}/storage""".format( + target_dirs=" ".join(real_paths), dir=dir)) + + info = generate_training_examples_internal( + dir=dir, targets_parameters=targets_parameters, + feat_dir=feat_dir, train_feats_string=train_feats, + train_egs_opts_string=train_egs_opts, + ivector_opts=feat_ivector_strings['ivector_opts'], + num_jobs=num_jobs, frames_per_eg=frames_per_eg, + samples_per_iter=samples_per_iter, cmd=cmd, + srand=srand, + only_shuffle=only_shuffle, + dry_run=dry_run, + generate_egs_scp=generate_egs_scp) + + return info + + +def write_egs_info(info, info_dir): + for x in ['num_frames', 'num_archives', 'egs_per_archive', + 'feat_dim', 'ivector_dim', + 'left_context', 'right_context', 'frames_per_eg']: + write_list([info['{0}'.format(x)]], '{0}/{1}'.format(info_dir, x)) + + +def generate_egs(egs_dir, feat_dir, targets_para_array, + online_ivector_dir=None, + frames_per_eg=8, + left_context=4, + right_context=4, + valid_left_context=None, + valid_right_context=None, + cmd="run.pl", stage=0, + cmvn_opts=None, apply_cmvn_sliding=False, + compress_input=True, + input_compress_format=0, + num_utts_subset_train=300, + num_utts_subset_valid=300, + num_train_egs_combine=1000, + num_valid_egs_combine=0, + num_egs_diagnostic=4000, + samples_per_iter=400000, + num_jobs=6, + srand=0, + generate_egs_scp=False): + + for directory in '{0}/log {0}/info'.format(egs_dir).split(): + create_directory(directory) + + print (cmvn_opts if cmvn_opts is not None else '', + file=open('{0}/cmvn_opts'.format(egs_dir), 'w')) + print ("true" if apply_cmvn_sliding else "false", + file=open('{0}/apply_cmvn_sliding'.format(egs_dir), 'w')) + + targets_parameters = parse_targets_parameters_array(targets_para_array) + + # Check files + check_for_required_files(feat_dir, + [t.targets_scp for t in targets_parameters], + online_ivector_dir) + + frame_shift = data_lib.get_frame_shift(feat_dir) + min_duration = frames_per_eg * 
frame_shift + valid_utts = sample_utts(feat_dir, num_utts_subset_valid, min_duration)[0] + train_subset_utts = sample_utts(feat_dir, num_utts_subset_train, + min_duration, exclude_list=valid_utts)[0] + train_utts, train_utts_durs = sample_utts(feat_dir, None, -1, + exclude_list=valid_utts) + + write_list(valid_utts, '{0}/valid_uttlist'.format(egs_dir)) + write_list(train_subset_utts, '{0}/train_subset_uttlist'.format(egs_dir)) + write_list(train_utts, '{0}/train_uttlist'.format(egs_dir)) + + # split the training data into parts for individual jobs + # we will use the same number of jobs as that used for alignment + split_feat_dir = common_lib.split_data(feat_dir, num_jobs) + feat_ivector_strings = get_feat_ivector_strings( + dir=egs_dir, feat_dir=feat_dir, split_feat_dir=split_feat_dir, + cmvn_opt_string=cmvn_opts, + ivector_dir=online_ivector_dir, + apply_cmvn_sliding=apply_cmvn_sliding) + + egs_opts = get_egs_options(targets_parameters=targets_parameters, + frames_per_eg=frames_per_eg, + left_context=left_context, + right_context=right_context, + valid_left_context=valid_left_context, + valid_right_context=valid_right_context, + compress_input=compress_input, + input_compress_format=input_compress_format) + + if stage <= 2: + logger.info("Generating validation and training subset examples") + + generate_valid_train_subset_egs( + dir=egs_dir, + targets_parameters=targets_parameters, + feat_ivector_strings=feat_ivector_strings, + egs_opts=egs_opts, + num_train_egs_combine=num_train_egs_combine, + num_valid_egs_combine=num_valid_egs_combine, + num_egs_diagnostic=num_egs_diagnostic, + cmd=cmd, + num_jobs=num_jobs, + generate_egs_scp=generate_egs_scp) + + logger.info("Generating training examples on disk.") + info = generate_training_examples( + dir=egs_dir, + targets_parameters=targets_parameters, + feat_dir=feat_dir, + feat_ivector_strings=feat_ivector_strings, + egs_opts=egs_opts, + frame_shift=frame_shift, + frames_per_eg=frames_per_eg, + 
samples_per_iter=samples_per_iter, + cmd=cmd, + num_jobs=num_jobs, + srand=srand, + only_shuffle=True if stage > 3 else False, + dry_run=True if stage > 4 else False, + generate_egs_scp=generate_egs_scp) + + info['feat_dim'] = feat_ivector_strings['feat_dim'] + info['ivector_dim'] = feat_ivector_strings['ivector_dim'] + info['left_context'] = left_context + info['right_context'] = right_context + info['frames_per_eg'] = frames_per_eg + + write_egs_info(info, '{dir}/info'.format(dir=egs_dir)) + + +def main(): + args = get_args() + generate_egs(args.dir, args.feat_dir, args.targets_para_array, + online_ivector_dir=args.online_ivector_dir, + frames_per_eg=args.frames_per_eg, + left_context=args.left_context, + right_context=args.right_context, + valid_left_context=args.valid_left_context, + valid_right_context=args.valid_right_context, + cmd=args.cmd, stage=args.stage, + cmvn_opts=args.cmvn_opts, + apply_cmvn_sliding=args.apply_cmvn_sliding, + compress_input=args.compress_input, + input_compress_format=args.input_compress_format, + num_utts_subset_train=args.num_utts_subset_train, + num_utts_subset_valid=args.num_utts_subset_valid, + num_train_egs_combine=args.num_train_egs_combine, + num_valid_egs_combine=args.num_valid_egs_combine, + num_egs_diagnostic=args.num_egs_diagnostic, + samples_per_iter=args.samples_per_iter, + num_jobs=args.num_jobs, + srand=args.srand, + generate_egs_scp=args.generate_egs_scp) + + +if __name__ == "__main__": + main() diff --git a/egs/wsj/s5/steps/nnet3/get_egs_targets.sh b/egs/wsj/s5/steps/nnet3/get_egs_targets.sh index b8fcbfd51fa..9bf7a1ac853 100755 --- a/egs/wsj/s5/steps/nnet3/get_egs_targets.sh +++ b/egs/wsj/s5/steps/nnet3/get_egs_targets.sh @@ -24,6 +24,7 @@ feat_type=raw # set it to 'lda' to use LDA features. target_type=sparse # dense to have dense targets, # sparse to have posteriors targets num_targets= # required for target-type=sparse with raw nnet +deriv_weights_scp= frames_per_eg=8 # number of frames of labels per example. 
more->less disk space and # less time preparing egs, but more I/O during training. # Note: may in general be a comma-separated string of alternative @@ -38,6 +39,12 @@ compress=true # set this to false to disable compression (e.g. if you want to # results are affected). num_utts_subset=300 # number of utterances in validation and training # subsets used for shrinkage and diagnostics. +num_utts_subset_valid= # number of utterances in validation + # subsets used for shrinkage and diagnostics + # if provided, overrides num-utts-subset +num_utts_subset_train= # number of utterances in training + # subsets used for shrinkage and diagnostics. + # if provided, overrides num-utts-subset num_valid_frames_combine=0 # #valid frames for combination weights at the very end. num_train_frames_combine=60000 # # train frames for the above. num_frames_diagnostic=10000 # number of frames for "compute_prob" jobs @@ -53,7 +60,7 @@ stage=0 nj=6 # This should be set to the maximum number of jobs you are # comfortable to run in parallel; you can increase it if your disk # speed is greater and you have more machines. -srand=0 +srand=0 # rand seed for nnet3-copy-egs and nnet3-shuffle-egs online_ivector_dir= # can be used if we are including speaker information as iVectors. cmvn_opts= # can be used for specifying CMVN options, if feature type is not lda (if lda, # it doesn't make sense to use different options than were used as input to the @@ -113,9 +120,18 @@ utils/split_data.sh $data $nj mkdir -p $dir/log $dir/info +[ -z "$num_utts_subset_valid" ] && num_utts_subset_valid=$num_utts_subset +[ -z "$num_utts_subset_train" ] && num_utts_subset_train=$num_utts_subset + +num_utts=$(cat $data/utt2spk | wc -l) +if ! [ $num_utts -gt $[$num_utts_subset_valid*4] ]; then + echo "$0: number of utterances $num_utts in your training data is too small versus --num-utts-subset=$num_utts_subset" + echo "... you probably have so little data that it doesn't make sense to train a neural net." 
+ exit 1 +fi # Get list of validation utterances. -awk '{print $1}' $data/utt2spk | utils/shuffle_list.pl | head -$num_utts_subset | sort \ +awk '{print $1}' $data/utt2spk | utils/shuffle_list.pl | head -$num_utts_subset_valid | sort \ > $dir/valid_uttlist || exit 1; if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. @@ -130,7 +146,7 @@ if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. fi awk '{print $1}' $data/utt2spk | utils/filter_scp.pl --exclude $dir/valid_uttlist | \ - utils/shuffle_list.pl | head -$num_utts_subset | sort > $dir/train_subset_uttlist || exit 1; + utils/shuffle_list.pl | head -$num_utts_subset_train > $dir/train_subset_uttlist || exit 1; if [ ! -z "$transform_dir" ] && [ -f $transform_dir/trans.1 ] && [ $feat_type != "raw" ]; then echo "$0: using transforms from $transform_dir" @@ -147,15 +163,33 @@ if [ -f $transform_dir/raw_trans.1 ] && [ $feat_type == "raw" ]; then fi fi +nj_subset=$nj + +if [ $nj_subset -gt `cat $dir/train_subset_uttlist | wc -l` ]; then + nj_subset=`cat $dir/train_subset_uttlist | wc -l` +fi + +if [ $nj_subset -gt `cat $dir/valid_uttlist | wc -l` ]; then + nj_subset=`cat $dir/valid_uttlist | wc -l` +fi + +valid_uttlist_all= +train_subset_uttlist_all= +for n in `seq $nj_subset`; do + valid_uttlist_all="$valid_uttlist_all $dir/valid_uttlist.$n" + train_subset_uttlist_all="$train_subset_uttlist_all $dir/train_subset_uttlist.$n" +done +utils/split_scp.pl $dir/valid_uttlist $valid_uttlist_all +utils/split_scp.pl $dir/train_subset_uttlist $train_subset_uttlist_all ## Set up features. 
echo "$0: feature type is $feat_type" case $feat_type in raw) feats="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $sdata/JOB/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:- ark:- |" - valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |" - train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |" + valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist.JOB $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |" + train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist.JOB $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |" echo $cmvn_opts >$dir/cmvn_opts # caution: the top-level nnet training script should copy this to its own dir now. ;; lda) @@ -166,8 +200,8 @@ case $feat_type in echo "You cannot supply --cmvn-opts option if feature type is LDA." 
&& exit 1; cmvn_opts=$(cat $dir/cmvn_opts) feats="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $sdata/JOB/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" - valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" - train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" + valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist.JOB $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" + train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist.JOB $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" ;; *) echo "$0: invalid feature type --feat-type '$feat_type'" && exit 1; esac @@ -183,6 +217,7 @@ if [ ! 
-z "$online_ivector_dir" ]; then ivector_dim=$(feat-to-dim scp:$online_ivector_dir/ivector_online.scp -) || exit 1 echo $ivector_dim > $dir/info/ivector_dim ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1; + ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period" else ivector_opts="" @@ -264,6 +299,10 @@ egs_opts="--left-context=$left_context --right-context=$right_context --compress [ $left_context_initial -ge 0 ] && egs_opts="$egs_opts --left-context-initial=$left_context_initial" [ $right_context_final -ge 0 ] && egs_opts="$egs_opts --right-context-final=$right_context_final" +[ ! -z "$deriv_weights_scp" ] && egs_opts="$egs_opts --deriv-weights-rspecifier=scp:$deriv_weights_scp" + +[ ! -z "$deriv_weights_scp" ] && valid_egs_opts="$valid_egs_opts --deriv-weights-rspecifier=scp:$deriv_weights_scp" + echo $left_context > $dir/info/left_context echo $right_context > $dir/info/right_context echo $left_context_initial > $dir/info/left_context_initial @@ -288,15 +327,15 @@ case $target_type in "dense") get_egs_program="nnet3-get-egs-dense-targets --num-targets=$num_targets" - targets="ark:utils/filter_scp.pl --exclude $dir/valid_uttlist $targets_scp_split | copy-feats scp:- ark:- |" - valid_targets="ark:utils/filter_scp.pl $dir/valid_uttlist $targets_scp | copy-feats scp:- ark:- |" - train_subset_targets="ark:utils/filter_scp.pl $dir/train_subset_uttlist $targets_scp | copy-feats scp:- ark:- |" + targets="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $targets_scp_split | copy-feats scp:- ark:- |" + valid_targets="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist.JOB $targets_scp | copy-feats scp:- ark:- |" + train_subset_targets="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist.JOB $targets_scp | copy-feats scp:- ark:- |" ;; "sparse") get_egs_program="nnet3-get-egs --num-pdfs=$num_targets" - targets="ark:utils/filter_scp.pl --exclude $dir/valid_uttlist $targets_scp_split | 
ali-to-post scp:- ark:- |" - valid_targets="ark:utils/filter_scp.pl $dir/valid_uttlist $targets_scp | ali-to-post scp:- ark:- |" - train_subset_targets="ark:utils/filter_scp.pl $dir/train_subset_uttlist $targets_scp | ali-to-post scp:- ark:- |" + targets="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $targets_scp_split | ali-to-post scp:- ark:- |" + valid_targets="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist.JOB $targets_scp | ali-to-post scp:- ark:- |" + train_subset_targets="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist.JOB $targets_scp | ali-to-post scp:- ark:- |" ;; default) echo "$0: Unknown --target-type $target_type. Choices are dense and sparse" @@ -306,31 +345,44 @@ esac if [ $stage -le 3 ]; then echo "$0: Getting validation and training subset examples." rm -f $dir/.error 2>/dev/null - $cmd $dir/log/create_valid_subset.log \ + $cmd JOB=1:$nj_subset $dir/log/create_valid_subset.JOB.log \ $get_egs_program \ $ivector_opts $egs_opts "$valid_feats" \ "$valid_targets" \ - "ark:$dir/valid_all.egs" || touch $dir/.error & - $cmd $dir/log/create_train_subset.log \ + "ark:$dir/valid_all.JOB.egs" || touch $dir/.error & + $cmd JOB=1:$nj_subset $dir/log/create_train_subset.JOB.log \ $get_egs_program \ $ivector_opts $egs_opts "$train_subset_feats" \ "$train_subset_targets" \ - "ark:$dir/train_subset_all.egs" || touch $dir/.error & + "ark:$dir/train_subset_all.JOB.egs" || touch $dir/.error & wait; + + valid_egs_all= + train_subset_egs_all= + for n in `seq $nj_subset`; do + valid_egs_all="$valid_egs_all $dir/valid_all.$n.egs" + train_subset_egs_all="$train_subset_egs_all $dir/train_subset_all.$n.egs" + done + [ -f $dir/.error ] && echo "Error detected while creating train/valid egs" && exit 1 echo "... Getting subsets of validation examples for diagnostics and combination." 
$cmd $dir/log/create_valid_subset_combine.log \ - nnet3-subset-egs --n=$[$num_valid_frames_combine/$frames_per_eg_principal] ark:$dir/valid_all.egs \ + cat $valid_egs_all \| \ + nnet3-subset-egs --n=$[$num_valid_frames_combine/$frames_per_eg_principal] ark:- \ ark:$dir/valid_combine.egs || touch $dir/.error & $cmd $dir/log/create_valid_subset_diagnostic.log \ - nnet3-subset-egs --n=$[$num_frames_diagnostic/$frames_per_eg_principal] ark:$dir/valid_all.egs \ + cat $valid_egs_all \| \ + nnet3-subset-egs --n=$[$num_frames_diagnostic/$frames_per_eg_principal] ark:- ark:$dir/valid_diagnostic.egs || touch $dir/.error & $cmd $dir/log/create_train_subset_combine.log \ - nnet3-subset-egs --n=$[$num_train_frames_combine/$frames_per_eg_principal] ark:$dir/train_subset_all.egs \ + cat $train_subset_egs_all \| \ + nnet3-subset-egs --n=$[$num_train_frames_combine/$frames_per_eg_principal] ark:- \ ark:$dir/train_combine.egs || touch $dir/.error & $cmd $dir/log/create_train_subset_diagnostic.log \ - nnet3-subset-egs --n=$[$num_frames_diagnostic/$frames_per_eg_principal] ark:$dir/train_subset_all.egs \ + cat $train_subset_egs_all \| \ + nnet3-subset-egs --n=$[$num_frames_diagnostic/$frames_per_eg_principal] ark:- \ + nnet3-subset-egs --n=$num_frames_diagnostic ark:- \ ark:$dir/train_diagnostic.egs || touch $dir/.error & wait sleep 5 # wait for file system to sync. @@ -339,7 +391,7 @@ if [ $stage -le 3 ]; then for f in $dir/{combine,train_diagnostic,valid_diagnostic}.egs; do [ ! -s $f ] && echo "No examples in file $f" && exit 1; done - rm -f $dir/valid_all.egs $dir/train_subset_all.egs $dir/{train,valid}_combine.egs + rm $dir/valid_all.*.egs $dir/train_subset_all.*.egs $dir/{train,valid}_combine.egs fi if [ $stage -le 4 ]; then @@ -394,6 +446,8 @@ if [ $stage -le 5 ]; then fi +wait + if [ $stage -le 6 ]; then echo "$0: removing temporary archives" for x in $(seq $nj); do @@ -407,9 +461,11 @@ if [ $stage -le 6 ]; then # there are some extra soft links that we should delete. 
for f in $dir/egs.*.*.ark; do rm $f; done fi - echo "$0: removing temporary" + echo "$0: removing temporary stuff" # Ignore errors below because trans.* might not exist. rm -f $dir/trans.{ark,scp} $dir/targets.*.scp 2>/dev/null fi +wait + echo "$0: Finished preparing training examples" diff --git a/egs/wsj/s5/steps/nnet3/lstm/make_configs.py b/egs/wsj/s5/steps/nnet3/lstm/make_configs.py index b80a8d4045b..578433e4fe6 100755 --- a/egs/wsj/s5/steps/nnet3/lstm/make_configs.py +++ b/egs/wsj/s5/steps/nnet3/lstm/make_configs.py @@ -58,6 +58,18 @@ def GetArgs(): parser.add_argument("--max-change-per-component-final", type=float, help="Enforces per-component max change for the final affine layer. " "if 0 it would not be enforced.", default=1.5) + parser.add_argument("--add-lda", type=str, action=nnet3_train_lib.StrToBoolAction, + help="If \"true\" an LDA matrix computed from the input features " + "(spliced according to the first set of splice-indexes) will be used as " + "the first Affine layer. This affine layer's parameters are fixed during training. " + "This variable needs to be set to \"false\" when using dense-targets.", + default=True, choices = ["false", "true"]) + parser.add_argument("--add-final-sigmoid", type=str, action=nnet3_train_lib.StrToBoolAction, + help="add a sigmoid layer as the final layer. Applicable only if skip-final-softmax is true.", + choices=['true', 'false'], default = False) + parser.add_argument("--objective-type", type=str, default="linear", + choices = ["linear", "quadratic", "xent"], + help = "the type of objective; i.e. quadratic or linear or cross-entropy per dim") # LSTM options parser.add_argument("--num-lstm-layers", type=int, @@ -219,7 +231,9 @@ def ParseLstmDelayString(lstm_delay): raise ValueError("invalid --lstm-delay argument, too-short element: " + lstm_delay) elif len(indexes) == 2 and indexes[0] * indexes[1] >= 0: - raise ValueError('Warning: ' + str(indexes) + ' is not a standard BLSTM mode. 
There should be a negative delay for the forward, and a postive delay for the backward.') + raise ValueError('Warning: ' + str(indexes) + + ' is not a standard BLSTM mode. ' + + 'There should be a negative delay for the forward, and a postive delay for the backward.') if len(indexes) == 2 and indexes[0] > 0: # always a negative delay followed by a postive delay indexes[0], indexes[1] = indexes[1], indexes[0] lstm_delay_array.append(indexes) @@ -229,29 +243,35 @@ def ParseLstmDelayString(lstm_delay): return lstm_delay_array -def MakeConfigs(config_dir, feat_dim, ivector_dim, num_targets, +def MakeConfigs(config_dir, feat_dim, ivector_dim, num_targets, add_lda, splice_indexes, lstm_delay, cell_dim, hidden_dim, recurrent_projection_dim, non_recurrent_projection_dim, num_lstm_layers, num_hidden_layers, norm_based_clipping, clipping_threshold, zeroing_threshold, zeroing_interval, ng_per_element_scale_options, ng_affine_options, - label_delay, include_log_softmax, xent_regularize, + label_delay, include_log_softmax, add_final_sigmoid, + objective_type, xent_regularize, self_repair_scale_nonlinearity, self_repair_scale_clipgradient, max_change_per_component, max_change_per_component_final): config_lines = {'components':[], 'component-nodes':[]} config_files={} - prev_layer_output = nodes.AddInputLayer(config_lines, feat_dim, splice_indexes[0], ivector_dim) + prev_layer_output = nodes.AddInputLayer(config_lines, feat_dim, splice_indexes[0], + ivector_dim) # Add the init config lines for estimating the preconditioning matrices init_config_lines = copy.deepcopy(config_lines) init_config_lines['components'].insert(0, '# Config file for initializing neural network prior to') init_config_lines['components'].insert(0, '# preconditioning matrix computation') - nodes.AddOutputLayer(init_config_lines, prev_layer_output) + nodes.AddOutputLayer(init_config_lines, prev_layer_output, label_delay = label_delay, objective_type = objective_type) config_files[config_dir + '/init.config'] = 
init_config_lines - prev_layer_output = nodes.AddLdaLayer(config_lines, "L0", prev_layer_output, config_dir + '/lda.mat') + # add_lda needs to be set "false" when using dense targets, + # or if the task is not a simple classification task + # (e.g. regression, multi-task) + if add_lda: + prev_layer_output = nodes.AddLdaLayer(config_lines, "L0", prev_layer_output, args.config_dir + '/lda.mat') for i in range(num_lstm_layers): if len(lstm_delay[i]) == 2: # add a bi-directional LSTM layer @@ -286,7 +306,7 @@ def MakeConfigs(config_dir, feat_dim, ivector_dim, num_targets, max_change_per_component = max_change_per_component) # make the intermediate config file for layerwise discriminative # training - nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, ng_affine_options, max_change_per_component = max_change_per_component_final, label_delay = label_delay, include_log_softmax = include_log_softmax) + nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, ng_affine_options, max_change_per_component = max_change_per_component_final, label_delay = label_delay, include_log_softmax = include_log_softmax, add_final_sigmoid = add_final_sigmoid, objective_type = objective_type) if xent_regularize != 0.0: @@ -301,10 +321,11 @@ def MakeConfigs(config_dir, feat_dim, ivector_dim, num_targets, for i in range(num_lstm_layers, num_hidden_layers): prev_layer_output = nodes.AddAffRelNormLayer(config_lines, "L{0}".format(i+1), prev_layer_output, hidden_dim, - ng_affine_options, self_repair_scale = self_repair_scale_nonlinearity, max_change_per_component = max_change_per_component) + ng_affine_options, self_repair_scale = self_repair_scale_nonlinearity, + max_change_per_component = max_change_per_component) # make the intermediate config file for layerwise discriminative # training - nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, ng_affine_options, max_change_per_component = max_change_per_component_final, label_delay = label_delay, 
include_log_softmax = include_log_softmax) + nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, ng_affine_options, max_change_per_component = max_change_per_component_final, label_delay = label_delay, include_log_softmax = include_log_softmax, add_final_sigmoid = add_final_sigmoid, objective_type = objective_type) if xent_regularize != 0.0: nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, @@ -333,24 +354,30 @@ def ProcessSpliceIndexes(config_dir, splice_indexes, label_delay, num_lstm_layer if (num_hidden_layers < num_lstm_layers): raise Exception("num-lstm-layers : number of lstm layers has to be greater than number of layers, decided based on splice-indexes") - # write the files used by other scripts like steps/nnet3/get_egs.sh - f = open(config_dir + "/vars", "w") - print('model_left_context=' + str(left_context), file=f) - print('model_right_context=' + str(right_context), file=f) - print('num_hidden_layers=' + str(num_hidden_layers), file=f) - # print('initial_right_context=' + str(splice_array[0][-1]), file=f) - f.close() - return [left_context, right_context, num_hidden_layers, splice_indexes] def Main(): args = GetArgs() - [left_context, right_context, num_hidden_layers, splice_indexes] = ProcessSpliceIndexes(args.config_dir, args.splice_indexes, args.label_delay, args.num_lstm_layers) + [left_context, right_context, + num_hidden_layers, splice_indexes] = ProcessSpliceIndexes(args.config_dir, args.splice_indexes, + args.label_delay, args.num_lstm_layers) + + # write the files used by other scripts like steps/nnet3/get_egs.sh + f = open(args.config_dir + "/vars", "w") + print('model_left_context=' + str(left_context), file=f) + print('model_right_context=' + str(right_context), file=f) + print('num_hidden_layers=' + str(num_hidden_layers), file=f) + print('num_targets=' + str(args.num_targets), file=f) + print('objective_type=' + str(args.objective_type), file=f) + print('add_lda=' + ("true" if args.add_lda else "false"), 
file=f) + print('include_log_softmax=' + ("true" if args.include_log_softmax else "false"), file=f) + f.close() MakeConfigs(config_dir = args.config_dir, feat_dim = args.feat_dim, ivector_dim = args.ivector_dim, num_targets = args.num_targets, + add_lda = args.add_lda, splice_indexes = splice_indexes, lstm_delay = args.lstm_delay, cell_dim = args.cell_dim, hidden_dim = args.hidden_dim, @@ -366,6 +393,8 @@ def Main(): ng_affine_options = args.ng_affine_options, label_delay = args.label_delay, include_log_softmax = args.include_log_softmax, + add_final_sigmoid = args.add_final_sigmoid, + objective_type = args.objective_type, xent_regularize = args.xent_regularize, self_repair_scale_nonlinearity = args.self_repair_scale_nonlinearity, self_repair_scale_clipgradient = args.self_repair_scale_clipgradient, diff --git a/egs/wsj/s5/steps/nnet3/make_jesus_configs.py b/egs/wsj/s5/steps/nnet3/make_jesus_configs.py index b442ce9715b..f88afd6b190 100755 --- a/egs/wsj/s5/steps/nnet3/make_jesus_configs.py +++ b/egs/wsj/s5/steps/nnet3/make_jesus_configs.py @@ -141,73 +141,6 @@ printable_name, old_val, new_val, args.num_jesus_blocks)) setattr(args, name, new_val); -# this is a bit like a struct, initialized from a string, which describes how to -# set up the statistics-pooling and statistics-extraction components. -# An example string is 'mean(-99:3:9::99)', which means, compute the mean of -# data within a window of -99 to +99, with distinct means computed every 9 frames -# (we round to get the appropriate one), and with the input extracted on multiples -# of 3 frames (so this will force the input to this layer to be evaluated -# every 3 frames). Another example string is 'mean+stddev(-99:3:9:99)', -# which will also cause the standard deviation to be computed. -class StatisticsConfig: - # e.g. 
c = StatisticsConfig('mean+stddev(-99:3:9:99)', 400, 'jesus1-forward-output-affine') - def __init__(self, config_string, input_dim, input_name): - self.input_dim = input_dim - self.input_name = input_name - - m = re.search("(mean|mean\+stddev)\((-?\d+):(-?\d+):(-?\d+):(-?\d+)\)", - config_string) - if m == None: - sys.exit("Invalid splice-index or statistics-config string: " + config_string) - self.output_stddev = (m.group(1) != 'mean') - self.left_context = -int(m.group(2)) - self.input_period = int(m.group(3)) - self.stats_period = int(m.group(4)) - self.right_context = int(m.group(5)) - if not (self.left_context > 0 and self.right_context > 0 and - self.input_period > 0 and self.stats_period > 0 and - self.left_context % self.stats_period == 0 and - self.right_context % self.stats_period == 0 and - self.stats_period % self.input_period == 0): - sys.exit("Invalid configuration of statistics-extraction: " + config_string) - - # OutputDim() returns the output dimension of the node that this produces. - def OutputDim(self): - return self.input_dim * (2 if self.output_stddev else 1) - - # OutputDims() returns an array of output dimensions, consisting of - # [ input-dim ] if just "mean" was specified, otherwise - # [ input-dim input-dim ] - def OutputDims(self): - return [ self.input_dim, self.input_dim ] if self.output_stddev else [ self.input_dim ] - - # Descriptor() returns the textual form of the descriptor by which the - # output of this node is to be accessed. - def Descriptor(self): - return 'Round({0}-pooling-{1}-{2}, {3})'.format(self.input_name, self.left_context, self.right_context, - self.stats_period) - - # This function writes the configuration lines need to compute the specified - # statistics, to the file f. 
- def WriteConfigs(self, f): - print('component name={0}-extraction-{1}-{2} type=StatisticsExtractionComponent input-dim={3} ' - 'input-period={4} output-period={5} include-variance={6} '.format( - self.input_name, self.left_context, self.right_context, - self.input_dim, self.input_period, self.stats_period, - ('true' if self.output_stddev else 'false')), file=f) - print('component-node name={0}-extraction-{1}-{2} component={0}-extraction-{1}-{2} input={0} '.format( - self.input_name, self.left_context, self.right_context), file=f) - stats_dim = 1 + self.input_dim * (2 if self.output_stddev else 1) - print('component name={0}-pooling-{1}-{2} type=StatisticsPoolingComponent input-dim={3} ' - 'input-period={4} left-context={1} right-context={2} num-log-count-features=0 ' - 'output-stddevs={5} '.format(self.input_name, self.left_context, self.right_context, - stats_dim, self.stats_period, - ('true' if self.output_stddev else 'false')), - file=f) - print('component-node name={0}-pooling-{1}-{2} component={0}-pooling-{1}-{2} input={0}-extraction-{1}-{2} '.format( - self.input_name, self.left_context, self.right_context), file=f) - - ## Work out splice_array diff --git a/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py b/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py new file mode 100644 index 00000000000..9bc6da53705 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/multilingual/allocate_multilingual_examples.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python + +# This script generates egs.Archive.scp and ranges.* used for generating egs.Archive.scp +# for multilingual setup. +# Also this script generates outputs.*.scp and weight.*.scp, where each line +# corresponds to language-id and weight for the same example in egs.*.scp. +# weight.*.scp used to scale the output's posterior during training. +# ranges.*.scp is generated w.r.t frequency distribution of remaining examples +# in each language. +# +# You call this script as (e.g.) 
+# +# allocate_multilingual_examples.py [opts] num-of-languages example-scp-lists multilingual-egs-dir +# +# allocate_multilingual_examples.py --num-jobs 10 --samples-per-iter 10000 --minibatch-size 512 +# --lang2weight exp/multi/lang2weight 2 "exp/lang1/egs.scp exp/lang2/egs.scp" +# exp/multi/egs +# +# This script outputs specific ranges.* files to the temp directory (exp/multi/egs/temp) +# that will enable you to creat egs.*.scp files for multilingual training. +# exp/multi/egs/temp/ranges.* contains something like the following: +# e.g. +# lang1 0 0 256 +# lang2 1 256 256 +# +# where each line can be interpreted as follows: +# +# +# note that is the zero-based line number in egs.scp for +# that language. +# num-examples is multiple of actual minibatch-size. +# +# +# egs.1.scp is generated using ranges.1.scp as following: +# "num_examples" consecutive examples starting from line "local-scp-line" from +# egs.scp file for language "source-lang" is copied to egs.1.scp. +# +# + +from __future__ import print_function +import re, os, argparse, sys, math, warnings, random, io, imp + +import logging + +sys.path.insert(0, 'steps') +import libs.common as common_lib + +logger = logging.getLogger('libs') +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [%(filename)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ] %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def GetArgs(): + + parser = argparse.ArgumentParser(description="Writes ranges.*, outputs.* and weights.* files " + "in preparation for dumping egs for multilingual training.", + epilog="Called by steps/nnet3/multilingual/get_egs.sh") + parser.add_argument("--samples-per-iter", type=int, default=40000, + help="The target number of egs in each archive of egs, " + "(prior to merging egs). 
"); + parser.add_argument("--num-jobs", type=int, default=20, + help="This can be used for better randomness in distributing languages across archives." + ", where egs.job.archive.scp generated randomly and examples are combined " + " across all jobs as eg.archive.scp.") + parser.add_argument("--random-lang", type=str, action=common_lib.StrToBoolAction, + help="If true, the lang-id in ranges.* selected" + " w.r.t frequency distribution of remaining examples in each language," + " otherwise it is selected sequentially.", + default=True, choices = ["false", "true"]) + parser.add_argument("--max-archives", type=int, default=1000, + help="max number of archives used to generate egs.*.scp"); + parser.add_argument("--seed", type=int, default=1, + help="Seed for random number generator") + + parser.add_argument("--minibatch-size", type=int, default=512, + help="The minibatch size used to generate scp files per job. " + "It should be multiple of actual minibatch size."); + + parser.add_argument("--prefix", type=str, default="", + help="Adds a prefix to the range files. This is used to distinguish between the train " + "and diagnostic files.") + + parser.add_argument("--lang2weight", type=str, + help="lang2weight file contains the weight per language to scale output posterior for that language.(format is: " + " )"); +# now the positional arguments + parser.add_argument("num_langs", type=int, + help="num of languages used in multilingual training setup."); + parser.add_argument("egs_scp_lists", type=str, + help="list of egs.scp files per input language." + "e.g. exp/lang1/egs/egs.scp exp/lang2/egs/egs.scp"); + + parser.add_argument("egs_dir", + help="Name of egs directory e.g. exp/multilingual_a/egs"); + + + print(' '.join(sys.argv)) + + args = parser.parse_args() + + return args + + +# Returns a random language number w.r.t +# amount of examples in each language. 
+# It works based on sampling from a +# discrete distribution, where it returns i +# with prob(i) as (num_egs in lang(i)/ tot_egs). +# tot_egs is sum of lang_len. +def RandomLang(lang_len, tot_egs, random_selection): + assert(tot_egs > 0) + rand_int = random.randint(0, tot_egs - 1) + count = 0 + for l in range(len(lang_len)): + if random_selection: + if rand_int > count and rand_int <= (count + lang_len[l]): + rand_lang = l + break + else: + count += lang_len[l] + else: + if (lang_len[l] > 0): + rand_lang = l + break + assert(rand_lang >= 0 and rand_lang < len(lang_len)) + return rand_lang + +# Read lang2weight file and return lang2weight array +# where lang2weight[i] is weight for language i. +def ReadLang2weight(lang2w_file): + f = open(lang2w_file, "r"); + if f is None: + raise Exception("Error opening lang2weight file " + str(lang2w_file)) + lang2w = [] + for line in f: + a = line.split() + if len(a) != 2: + raise Exception("bad line in lang2weight file " + line) + lang2w.append(int(a[1])) + f.close() + return lang2w + +# struct to keep archives correspond to each job +class ArchiveToJob(): + def __init__(self, job_id, archives_for_job): + self.job_id = job_id + self.archives = archives_for_job + +def Main(): + args = GetArgs() + random.seed(args.seed) + num_langs = args.num_langs + rand_select = args.random_lang + + # read egs.scp for input languages + scp_lists = args.egs_scp_lists.split(); + assert(len(scp_lists) == num_langs); + + scp_files = [open(scp_lists[lang], 'r') for lang in range(num_langs)] + + # computes lang2len, where lang2len[i] shows number of + # examples for language i. + lang2len = [0] * num_langs + for lang in range(num_langs): + lang2len[lang] = sum(1 for line in open(scp_lists[lang])) + logger.info("Number of examples for language {0} is {1}".format(lang, lang2len[lang])) + + # If weights are not provided, the scaling weights + # are one. 
+    if args.lang2weight is None:
+        lang2weight = [ 1.0 ] * num_langs
+    else:
+        lang2weight = ReadLang2weight(args.lang2weight)
+        assert(len(lang2weight) == num_langs)
+
+    if not os.path.exists(args.egs_dir + "/temp"):
+        os.makedirs(args.egs_dir + "/temp")
+
+    num_lang_file = open(args.egs_dir + "/info/" + args.prefix + "num_lang", "w");
+    print("{0}".format(num_langs), file = num_lang_file)
+
+
+    # Each element of all_egs (one per num_archive * num_jobs) is
+    # an array of 3-tuples (lang-id, local-start-egs-line, num-egs)
+    all_egs = []
+    lang_len = lang2len[:]
+    tot_num_egs = sum(lang2len[i] for i in range(len(lang2len))) # total num of egs in all languages
+    num_archives = max(1, min(args.max_archives, tot_num_egs / args.samples_per_iter))
+
+
+    num_arch_file = open(args.egs_dir + "/info/" + args.prefix + "num_archives", "w");
+    print("{0}".format(num_archives), file = num_arch_file)
+    num_arch_file.close()
+
+    this_num_egs_per_archive = tot_num_egs / (num_archives * args.num_jobs) # num of egs per archive
+    for job_index in range(args.num_jobs):
+        for archive_index in range(num_archives):
+            # Temporary scp.job_index.archive_index files to store egs.scp correspond to each archive.
+            logger.debug("Processing archive {0} for job {1}".format(archive_index + 1, job_index + 1))
+            archfile = open(args.egs_dir + "/temp/" + args.prefix + "scp." + str(job_index + 1) + "." 
+ str(archive_index + 1), "w") + + this_egs = [] # this will be array of 2-tuples (lang-id start-frame num-frames) + + num_egs = 0 + while num_egs <= this_num_egs_per_archive: + rem_egs = sum(lang_len[i] for i in range(len(lang_len))) + if rem_egs > 0: + lang_id = RandomLang(lang_len, rem_egs, rand_select) + start_egs = lang2len[lang_id] - lang_len[lang_id] + this_egs.append((lang_id, start_egs, args.minibatch_size)) + for scpline in range(args.minibatch_size): + lines = scp_files[lang_id].readline().splitlines() + try: + print("{0} {1}".format(lines[0], lang_id), file=archfile) + except Exception: + logger.error("Failure to read from file %s, got %s", + scp_files[lang_id].name, lines) + raise + + lang_len[lang_id] = lang_len[lang_id] - args.minibatch_size + num_egs = num_egs + args.minibatch_size; + # If the num of remaining egs in each lang is less than minibatch_size, + # they are discarded. + if lang_len[lang_id] < args.minibatch_size: + lang_len[lang_id] = 0 + logger.debug("Run out of data for language {0}".format(lang_id)) + else: + logger.debug("Run out of data for all languages.") + break + all_egs.append(this_egs) + archfile.close() + + # combine examples across all jobs correspond to each archive. + for archive in range(num_archives): + logger.debug("Processing archive {0} by combining all jobs.".format(archive + 1)) + this_ranges = [] + f = open(args.egs_dir + "/temp/" + args.prefix + "ranges." + str(archive + 1), "w") + o = open(args.egs_dir + "/" + args.prefix + "output." + str(archive + 1), "w") + w = open(args.egs_dir + "/" + args.prefix + "weight." + str(archive + 1), "w") + scp_per_archive_file = open(args.egs_dir + "/" + args.prefix + "egs." + str(archive + 1), "w") + + # check files befor writing. + if f is None: + raise Exception("Error opening file " + args.egs_dir + "/temp/" + args.prefix + "ranges." + str(job + 1)) + if o is None: + raise Exception("Error opening file " + args.egs_dir + "/" + args.prefix + "output." 
+ str(job + 1)) + if w is None: + raise Exception("Error opening file " + args.egs_dir + "/" + args.prefix + "weight." + str(job + 1)) + if scp_per_archive_file is None: + raise Exception("Error opening file " + args.egs_dir + "/" + args.prefix + "egs." + str(archive + 1), "w") + + for job in range(args.num_jobs): + # combine egs.job.archive.scp across all jobs. + scp = args.egs_dir + "/temp/" + args.prefix + "scp." + str(job + 1) + "." + str(archive + 1) + with open(scp, "r") as scpfile: + for line in scpfile: + try: + scp_line = line.splitlines()[0].split() + print("{0} {1}".format(scp_line[0], scp_line[1]), file=scp_per_archive_file) + print("{0} output-{1}".format(scp_line[0], scp_line[2]), file=o) + print("{0} {1}".format(scp_line[0], lang2weight[int(scp_line[2])]), file=w) + except Exception: + logger.error("Failed processing line %s in scp %s", line, + scpfile.name) + raise + os.remove(scp) + + # combine ranges.* across all jobs for archive + for (lang_id, start_eg_line, num_egs) in all_egs[num_archives * job + archive]: + this_ranges.append((lang_id, start_eg_line, num_egs)) + + # write ranges.archive + for (lang_id, start_eg_line, num_egs) in this_ranges: + print("{0} {1} {2}".format(lang_id, start_eg_line, num_egs), file=f) + + scp_per_archive_file.close() + f.close() + o.close() + w.close() + print("allocate_multilingual_examples.py finished generating " + args.prefix + "egs.*.scp and " + args.prefix + "ranges.* and " + args.prefix + "output.*" + args.prefix + "weight.* files") + +if __name__ == "__main__": + Main() diff --git a/egs/wsj/s5/steps/nnet3/multilingual/get_egs.sh b/egs/wsj/s5/steps/nnet3/multilingual/get_egs.sh new file mode 100755 index 00000000000..58ef965de3e --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/multilingual/get_egs.sh @@ -0,0 +1,133 @@ +#!/bin/bash +# +# This script uses separate input egs directory for each language as input, +# to generate egs.*.scp files in multilingual egs directory +# where the scp line points to the original 
archive for each egs directory.
+# $megs/egs.*.scp is randomized w.r.t language id.
+#
+# Also this script generates egs.JOB.scp, output.JOB.scp and weight.JOB.scp,
+# where output file contains language-id for each example
+# and weight file contains weights for scaling output posterior
+# for each example w.r.t input language.
+#
+
+set -e
+set -o pipefail
+set -u
+
+# Begin configuration section.
+cmd=run.pl
+minibatch_size=512      # multiple of minibatch used during training.
+num_jobs=10             # This can be set to max number of jobs to run in parallel;
+                        # Helps for better randomness across languages
+                        # per archive.
+samples_per_iter=400000 # this is the target number of egs in each archive of egs
+                        # (prior to merging egs). We probably should have called
+                        # it egs_per_iter. This is just a guideline; it will pick
+                        # a number that divides the number of samples in the
+                        # entire data.
+stage=0
+
+echo "$0 $@" # Print the command line for logging
+
+if [ -f path.sh ]; then . ./path.sh; fi
+. parse_options.sh || exit 1;
+
+num_langs=$1
+shift 1
+args=("$@")
+megs_dir=${args[-1]} # multilingual directory
+mkdir -p $megs_dir
+mkdir -p $megs_dir/info
+
+if [ ${#args[@]} != $[$num_langs+1] ]; then
+  echo "$0: Number of input example dirs provided is not compatible with num_langs $num_langs."
+  echo "Usage:$0 [opts] ... 
" + echo "Usage:$0 [opts] 2 exp/lang1/egs exp/lang2/egs exp/multi/egs" + exit 1; +fi + +required_files="egs.scp combine.egs.scp train_diagnostic.egs.scp valid_diagnostic.egs.scp" +train_scp_list= +train_diagnostic_scp_list= +valid_diagnostic_scp_list= +combine_scp_list= + +# copy paramters from $egs_dir[0]/info +# into multilingual dir egs_dir/info + +params_to_check="feat_dim ivector_dim left_context right_context frames_per_eg" +for param in $params_to_check; do + cat ${args[0]}/info/$param > $megs_dir/info/$param || exit 1; +done + +for lang in $(seq 0 $[$num_langs-1]);do + multi_egs_dir[$lang]=${args[$lang]} + echo "arg[$lang] = ${args[$lang]}" + for f in $required_files; do + if [ ! -f ${multi_egs_dir[$lang]}/$f ]; then + echo "$0: no such a file ${multi_egs_dir[$lang]}/$f." && exit 1; + fi + done + train_scp_list="$train_scp_list ${args[$lang]}/egs.scp" + train_diagnostic_scp_list="$train_diagnostic_scp_list ${args[$lang]}/train_diagnostic.egs.scp" + valid_diagnostic_scp_list="$valid_diagnostic_scp_list ${args[$lang]}/valid_diagnostic.egs.scp" + combine_scp_list="$combine_scp_list ${args[$lang]}/combine.egs.scp" + + # check parameter dimension to be the same in all egs dirs + for f in $params_to_check; do + f1=`cat $megs_dir/info/$param`; + f2=`cat ${multi_egs_dir[$lang]}/info/$f`; + if [ $f1 != $f1 ]; then + echo "$0: mismatch in dimension for $f parameter in ${multi_egs_dir[$lang]}." + exit 1; + fi + done +done + +cp ${multi_egs_dir[$lang]}/cmvn_opts $megs_dir + +if [ $stage -le 0 ]; then + echo "$0: allocating multilingual examples for training." + # Generate egs.*.scp for multilingual setup. 
+ $cmd $megs_dir/log/allocate_multilingual_examples_train.log \ + python steps/nnet3/multilingual/allocate_multilingual_examples.py \ + --minibatch-size $minibatch_size \ + --samples-per-iter $samples_per_iter \ + $num_langs "$train_scp_list" $megs_dir || exit 1; +fi + +if [ $stage -le 1 ]; then + echo "$0: combine combine.egs.scp examples from all langs in $megs_dir/combine.egs.scp." + # Generate combine.egs.scp for multilingual setup. + $cmd $megs_dir/log/allocate_multilingual_examples_combine.log \ + python steps/nnet3/multilingual/allocate_multilingual_examples.py \ + --random-lang false \ + --max-archives 1 --num-jobs 1 \ + --minibatch-size $minibatch_size \ + --prefix "combine." \ + $num_langs "$combine_scp_list" $megs_dir || exit 1; + + echo "$0: combine train_diagnostic.egs.scp examples from all langs in $megs_dir/train_diagnostic.egs.scp." + # Generate train_diagnostic.egs.scp for multilingual setup. + $cmd $megs_dir/log/allocate_multilingual_examples_train_diagnostic.log \ + python steps/nnet3/multilingual/allocate_multilingual_examples.py \ + --random-lang false \ + --max-archives 1 --num-jobs 1 \ + --minibatch-size $minibatch_size \ + --prefix "train_diagnostic." \ + $num_langs "$train_diagnostic_scp_list" $megs_dir || exit 1; + + + echo "$0: combine valid_diagnostic.egs.scp examples from all langs in $megs_dir/valid_diagnostic.egs.scp." + # Generate valid_diagnostic.egs.scp for multilingual setup. + $cmd $megs_dir/log/allocate_multilingual_examples_valid_diagnostic.log \ + python steps/nnet3/multilingual/allocate_multilingual_examples.py \ + --random-lang false --max-archives 1 --num-jobs 1\ + --minibatch-size $minibatch_size \ + --prefix "valid_diagnostic." 
\ + $num_langs "$valid_diagnostic_scp_list" $megs_dir || exit 1; + +fi + diff --git a/egs/wsj/s5/steps/nnet3/train_dnn.py b/egs/wsj/s5/steps/nnet3/train_dnn.py index 2f324512114..8388afd0188 100755 --- a/egs/wsj/s5/steps/nnet3/train_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_dnn.py @@ -232,7 +232,7 @@ def train(args, run_opts, background_process_handler): [egs_left_context, egs_right_context, frames_per_eg_str, num_archives] = ( - common_train_lib.verify_egs_dir(egs_dir, feat_dim, + common_train_lib.verify_egs_dir(egs_dir, feat_dim, ivector_dim, ivector_id, left_context, right_context)) assert(str(args.frames_per_eg) == frames_per_eg_str) @@ -292,6 +292,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): args.initial_effective_lrate, args.final_effective_lrate) + if args.dropout_schedule is not None: + dropout_schedule = common_train_lib.parse_dropout_option( + num_archives_to_process, args.dropout_schedule) + logger.info("Training will run for {0} epochs = " "{1} iterations".format(args.num_epochs, num_iters)) @@ -411,14 +415,14 @@ def main(): polling_time=args.background_polling_time) train(args, run_opts, background_process_handler) background_process_handler.ensure_processes_are_done() - except Exception as e: + except Exception: if args.email is not None: message = ("Training session for experiment {dir} " "died due to an error.".format(dir=args.dir)) common_lib.send_mail(message, message, args.email) - traceback.print_exc() background_process_handler.stop() - raise e + logger.error("Training session failed; traceback = ", exc_info=True) + raise SystemExit(1) if __name__ == "__main__": diff --git a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py index a10b7eb604a..2c7bc882597 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py @@ -53,6 +53,9 @@ def get_args(): parser.add_argument("--egs.frames-per-eg", type=int, dest='frames_per_eg', default=8, help="Number of 
output labels per example") + parser.add_argument("--egs.extra-copy-cmd", type=str, + dest='extra_egs_copy_cmd', default = "", + help="""Modify egs before passing it to training"""); # trainer options parser.add_argument("--trainer.prior-subset-size", type=int, @@ -247,7 +250,7 @@ def train(args, run_opts, background_process_handler): [egs_left_context, egs_right_context, frames_per_eg_str, num_archives] = ( - common_train_lib.verify_egs_dir(egs_dir, feat_dim, + common_train_lib.verify_egs_dir(egs_dir, feat_dim, ivector_dim, ivector_id, left_context, right_context)) assert(str(args.frames_per_eg) == frames_per_eg_str) @@ -296,6 +299,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): args.initial_effective_lrate, args.final_effective_lrate) + if args.dropout_schedule is not None: + dropout_schedule = common_train_lib.parse_dropout_option( + num_archives_to_process, args.dropout_schedule) + logger.info("Training will run for {0} epochs = " "{1} iterations".format(args.num_epochs, num_iters)) @@ -333,7 +340,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): shuffle_buffer_size=args.shuffle_buffer_size, run_opts=run_opts, get_raw_nnet_from_am=False, - background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if args.cleanup: # do a clean up everythin but the last 2 models, under certain @@ -365,6 +373,7 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): minibatch_size_str=args.minibatch_size, run_opts=run_opts, background_process_handler=background_process_handler, get_raw_nnet_from_am=False, + extra_egs_copy_cmd=args.extra_egs_copy_cmd, sum_to_one_penalty=args.combine_sum_to_one_penalty) if include_log_softmax and args.stage <= num_iters + 1: @@ -375,7 +384,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): num_archives=num_archives, left_context=left_context, right_context=right_context, 
prior_subset_size=args.prior_subset_size, run_opts=run_opts, - get_raw_nnet_from_am=False) + get_raw_nnet_from_am=False, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if args.cleanup: logger.info("Cleaning up the experiment directory " diff --git a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py index 272485b898a..ec6442ba787 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py @@ -65,6 +65,9 @@ def get_args(): should halve --trainer.samples-per-iter. May be a comma-separated list of alternatives: first width is the 'principal' chunk-width, used preferentially""") + parser.add_argument("--egs.extra-copy-cmd", type=str, + dest='extra_egs_copy_cmd', default = "", + help="""Modify egs before passing it to training"""); # trainer options parser.add_argument("--trainer.samples-per-iter", type=int, @@ -237,12 +240,18 @@ def train(args, run_opts, background_process_handler): # discriminative pretraining num_hidden_layers = variables['num_hidden_layers'] add_lda = common_lib.str_to_bool(variables['add_lda']) - include_log_softmax = common_lib.str_to_bool( - variables['include_log_softmax']) except KeyError as e: raise Exception("KeyError {0}: Variables need to be defined in " "{1}".format(str(e), '{0}/configs'.format(args.dir))) + try: + include_log_softmax = common_lib.str_to_bool( + variables['include_log_softmax']) + except KeyError as e: + logger.warning("KeyError {0}: Using default include-log-softmax value " + "as False.".format(str(e))) + include_log_softmax = False + left_context = args.chunk_left_context + model_left_context right_context = args.chunk_right_context + model_right_context left_context_initial = (args.chunk_left_context_initial + model_left_context if @@ -363,6 +372,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): args.initial_effective_lrate, args.final_effective_lrate) + if args.dropout_schedule is not None: + dropout_schedule = 
common_train_lib.parse_dropout_option( + num_archives_to_process, args.dropout_schedule) + min_deriv_time = None max_deriv_time_relative = None if args.deriv_truncate_margin is not None: @@ -421,7 +434,11 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): shuffle_buffer_size=args.shuffle_buffer_size, run_opts=run_opts, get_raw_nnet_from_am=False, - background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=args.extra_egs_copy_cmd, + use_multitask_egs=args.use_multitask_egs, + rename_multitask_outputs=args.rename_multitask_outputs, + compute_per_dim_accuracy=args.compute_per_dim_accuracy) if args.cleanup: # do a clean up everythin but the last 2 models, under certain @@ -446,6 +463,9 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): if args.stage <= num_iters: logger.info("Doing final combination to produce final.raw") + common_lib.run_kaldi_command( + "cp {dir}/{num_iters}.raw {dir}/pre_combine.raw" + "".format(dir=args.dir, num_iters=num_iters)) train_lib.common.combine_models( dir=args.dir, num_iters=num_iters, models_to_combine=models_to_combine, egs_dir=egs_dir, @@ -454,6 +474,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): run_opts=run_opts, chunk_width=args.chunk_width, background_process_handler=background_process_handler, get_raw_nnet_from_am=False, + extra_egs_copy_cmd=args.extra_egs_copy_cmd, + compute_per_dim_accuracy=args.compute_per_dim_accuracy, sum_to_one_penalty=args.combine_sum_to_one_penalty) if include_log_softmax and args.stage <= num_iters + 1: @@ -464,7 +486,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): num_archives=num_archives, left_context=left_context, right_context=right_context, prior_subset_size=args.prior_subset_size, run_opts=run_opts, - get_raw_nnet_from_am=False) + get_raw_nnet_from_am=False, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if args.cleanup: 
logger.info("Cleaning up the experiment directory " diff --git a/egs/wsj/s5/steps/nnet3/train_rnn.py b/egs/wsj/s5/steps/nnet3/train_rnn.py index 6636513e03d..3aa9301b5f4 100755 --- a/egs/wsj/s5/steps/nnet3/train_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_rnn.py @@ -167,6 +167,7 @@ def process_args(args): "directory which is the output of " "make_configs.py script") + if args.transform_dir is None: args.transform_dir = args.ali_dir @@ -296,7 +297,7 @@ def train(args, run_opts, background_process_handler): [egs_left_context, egs_right_context, frames_per_eg_str, num_archives] = ( - common_train_lib.verify_egs_dir(egs_dir, feat_dim, + common_train_lib.verify_egs_dir(egs_dir, feat_dim, ivector_dim, ivector_id, left_context, right_context, left_context_initial, right_context_final)) @@ -359,6 +360,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): args.initial_effective_lrate, args.final_effective_lrate) + if args.dropout_schedule is not None: + dropout_schedule = common_train_lib.parse_dropout_option( + num_archives_to_process, args.dropout_schedule) + min_deriv_time = None max_deriv_time_relative = None if args.deriv_truncate_margin is not None: @@ -407,8 +412,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): minibatch_size_str=args.num_chunk_per_minibatch, num_hidden_layers=num_hidden_layers, add_layers_period=args.add_layers_period, - left_context=left_context, - right_context=right_context, + max_left_context = left_context, + max_right_context = right_context, min_deriv_time=min_deriv_time, max_deriv_time_relative=max_deriv_time_relative, momentum=args.momentum, diff --git a/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py b/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py index 7e876bda1ed..27eadeef05c 100755 --- a/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py +++ b/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py @@ -30,6 +30,9 @@ def get_args(): help='Filename of input xconfig file') parser.add_argument('--config-dir', 
required=True, help='Directory to write config files and variables') + parser.add_argument('--nnet-edits', type=str, default=None, + action=common_lib.NullstrToNoneAction, + help="Edit network before getting nnet3-info") print(' '.join(sys.argv)) @@ -209,13 +212,19 @@ def write_config_files(config_dir, all_layers): raise -def add_back_compatibility_info(config_dir): +def add_back_compatibility_info(config_dir, nnet_edits=None): """This will be removed when python script refactoring is done.""" common_lib.run_kaldi_command("nnet3-init {0}/ref.config " "{0}/ref.raw".format(config_dir)) - out, err = common_lib.run_kaldi_command("nnet3-info {0}/ref.raw | " - "head -4".format(config_dir)) + model = "{0}/ref.raw".format(config_dir) + if nnet_edits is not None: + model = """nnet3-copy --edits='{0}' {1} - |""".format(nnet_edits, + model) + + print("""nnet3-info "{0}" | head -4""".format(model), file=sys.stderr) + out, err = common_lib.run_kaldi_command("""nnet3-info "{0}" | """ + """head -4""".format(model)) # out looks like this # left-context: 7 # right-context: 0 @@ -284,7 +293,7 @@ def main(): write_expanded_xconfig_files(args.config_dir, all_layers) write_config_files(args.config_dir, all_layers) check_model_contexts(args.config_dir) - add_back_compatibility_info(args.config_dir) + add_back_compatibility_info(args.config_dir, args.nnet_edits) if __name__ == '__main__': diff --git a/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh b/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh index 53026b840bd..cf51a5cf091 100755 --- a/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh +++ b/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh @@ -172,8 +172,8 @@ if [ $sub_speaker_frames -gt 0 ]; then feat-to-len scp:$data/feats.scp ark,t:- > $dir/utt_counts || exit 1; fi if ! [ $(wc -l <$dir/utt_counts) -eq $(wc -l <$data/feats.scp) ]; then - echo "$0: error getting per-utterance counts." - exit 0; + echo "$0: error getting per-utterance counts. 
Number of lines in $dir/utt_counts differs from $data/feats.scp" + exit 1; fi cat $data/spk2utt | python -c " import sys @@ -229,8 +229,8 @@ if [ $stage -le 2 ]; then if [ ! -z "$ali_or_decode_dir" ]; then $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ gmm-global-get-post --n=$num_gselect --min-post=$min_post $srcdir/final.dubm "$gmm_feats" ark:- \| \ - weight-post ark:- "ark,s,cs:gunzip -c $dir/weights.gz|" ark:- \| \ - ivector-extract --acoustic-weight=$posterior_scale --compute-objf-change=true \ + weight-post --length-tolerance=1 ark:- "ark,s,cs:gunzip -c $dir/weights.gz|" ark:- \| \ + ivector-extract --length-tolerance=1 --acoustic-weight=$posterior_scale --compute-objf-change=true \ --max-count=$max_count --spk2utt=ark:$this_sdata/JOB/spk2utt \ $srcdir/final.ie "$feats" ark,s,cs:- ark,t:$dir/ivectors_spk.JOB.ark || exit 1; else diff --git a/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh b/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh index f4d908e9446..74db006906f 100755 --- a/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh +++ b/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh @@ -42,6 +42,9 @@ max_count=0 # The use of this option (e.g. --max-count 100) can make # posterior-scaling, so assuming the posterior-scale is 0.1, # --max-count 100 starts having effect after 1000 frames, or # 10 seconds of data. +weights= +use_most_recent_ivector=true +max_remembered_frames=1000 # End configuration section. @@ -89,6 +92,8 @@ splice_opts=$(cat $srcdir/splice_opts) # involved in online decoding. We need to create a config file for iVector # extraction. 
+absdir=$(readlink -f $dir) + ieconf=$dir/conf/ivector_extractor.conf echo -n >$ieconf cp $srcdir/online_cmvn.conf $dir/conf/ || exit 1; @@ -103,12 +108,19 @@ echo "--ivector-extractor=$srcdir/final.ie" >>$ieconf echo "--num-gselect=$num_gselect" >>$ieconf echo "--min-post=$min_post" >>$ieconf echo "--posterior-scale=$posterior_scale" >>$ieconf -echo "--max-remembered-frames=1000" >>$ieconf # the default +echo "--max-remembered-frames=$max_remembered_frames" >>$ieconf # the default echo "--max-count=$max_count" >>$ieconf +echo "--use-most-recent-ivector=$use_most_recent_ivector" >>$use_most_recent_ivector +if [ ! -z "$weights" ]; then + if [ -f $weights ] && gunzip -c $weights > /dev/null; then + cp -f $weights $absdir/weights.gz || exit 1 + else + echo "Could not open file $weights" + exit 1 + fi +fi -absdir=$(readlink -f $dir) - for n in $(seq $nj); do # This will do nothing unless the directory $dir/storage exists; # it can be used to distribute the data among multiple machines. @@ -117,10 +129,21 @@ done if [ $stage -le 0 ]; then echo "$0: extracting iVectors" - $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ - ivector-extract-online2 --config=$ieconf ark:$sdata/JOB/spk2utt scp:$sdata/JOB/feats.scp ark:- \| \ - copy-feats --compress=$compress ark:- \ + if [ ! 
-z "$weights" ]; then + $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ + ivector-extract-online2 --config=$ieconf \ + --frame-weights-rspecifier="ark:gunzip -c $absdir/weights.gz |" \ + --length-tolerance=1 \ + ark:$sdata/JOB/spk2utt scp:$sdata/JOB/feats.scp ark:- \| \ + copy-feats --compress=$compress ark:- \ + ark,scp:$absdir/ivector_online.JOB.ark,$absdir/ivector_online.JOB.scp || exit 1; + else + $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ + ivector-extract-online2 --config=$ieconf \ + ark:$sdata/JOB/spk2utt scp:$sdata/JOB/feats.scp ark:- \| \ + copy-feats --compress=$compress ark:- \ ark,scp:$absdir/ivector_online.JOB.ark,$absdir/ivector_online.JOB.scp || exit 1; + fi fi if [ $stage -le 1 ]; then diff --git a/egs/wsj/s5/steps/resolve_ctm_overlaps.py b/egs/wsj/s5/steps/resolve_ctm_overlaps.py new file mode 100755 index 00000000000..aaee767e7e4 --- /dev/null +++ b/egs/wsj/s5/steps/resolve_ctm_overlaps.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python +# Copyright 2014 Johns Hopkins University (Authors: Daniel Povey, Vijayaditya Peddinti). +# 2016 Vimal Manohar +# Apache 2.0. + +# Script to combine ctms with overlapping segments + +import sys, math, numpy as np, argparse +break_threshold = 0.01 + +def ReadSegments(segments_file): + segments = {} + for line in open(segments_file).readlines(): + parts = line.strip().split() + segments[parts[0]] = (parts[1], float(parts[2]), float(parts[3])) + return segments + +#def get_breaks(ctm, prev_end): +# breaks = [] +# for i in xrange(0, len(ctm)): +# if ctm[i][2] - prev_end > break_threshold: +# breaks.append([i, ctm[i][2]]) +# prev_end = ctm[i][2] + ctm[i][3] +# return np.array(breaks) + +# Resolve overlaps within segments of the same recording +def ResolveOverlaps(ctms, segments): + total_ctm = [] + if len(ctms) == 0: + raise Exception('Something wrong with the input ctms') + + next_utt = ctms[0][0][0] + for ctm_index in range(len(ctms) - 1): + # Assumption here is that the segments are written in consecutive order? 
+ cur_ctm = ctms[ctm_index] + next_ctm = ctms[ctm_index + 1] + + cur_utt = next_utt + next_utt = next_ctm[0][0] + if (next_utt not in segments): + raise Exception('Could not find utterance %s in segments' % next_utt) + + if len(cur_ctm) > 0: + assert(cur_utt == cur_ctm[0][0]) + + assert(next_utt > cur_utt) + if (cur_utt not in segments): + raise Exception('Could not find utterance %s in segments' % cur_utt) + + # length of this segment + window_length = segments[cur_utt][2] - segments[cur_utt][1] + + # overlap of this segment with the next segment + # Note: It is possible for this to be negative when there is actually + # no overlap between consecutive segments. + overlap = segments[cur_utt][2] - segments[next_utt][1] + + # find the breaks after overlap starts + index = len(cur_ctm) + + for i in xrange(len(cur_ctm)): + if (cur_ctm[i][2] + cur_ctm[i][3]/2.0 > (window_length - overlap/2.0)): + # if midpoint of a hypothesis word is beyond the midpoint of the + # overlap region + index = i + break + + # Ignore the hypotheses beyond this midpoint. They will be considered as + # part of the next segment. + total_ctm += cur_ctm[:index] + + # Ignore the hypotheses of the next utterance that overlaps with the + # current utterance + index = -1 + for i in xrange(len(next_ctm)): + if (next_ctm[i][2] + next_ctm[i][3]/2.0 > (overlap/2.0)): + index = i + break + + if index >= 0: + ctms[ctm_index + 1] = next_ctm[index:] + else: + ctms[ctm_index + 1] = [] + + # merge the last ctm entirely + total_ctm += ctms[-1] + + return total_ctm + +def ReadCtm(ctm_file_lines, segments): + ctms = {} + for key in [ x[0] for x in segments.values() ]: + ctms[key] = [] + + ctm = [] + prev_utt = ctm_file_lines[0].split()[0] + for line in ctm_file_lines: + parts = line.split() + if (prev_utt == parts[0]): + ctm.append([parts[0], parts[1], float(parts[2]), + float(parts[3])] + parts[4:]) + else: + # New utterance. 
Append the previous utterance's CTM + # into the list for the utterance's recording + ctms[segments[ctm[0][0]][0]].append(ctm) + + assert(parts[0] > prev_utt) + + prev_utt = parts[0] + ctm = [] + ctm.append([parts[0], parts[1], float(parts[2]), + float(parts[3])] + parts[4:]) + + # append the last ctm + ctms[segments[ctm[0][0]][0]].append(ctm) + return ctms + +def WriteCtm(ctm_lines, out_file): + for line in ctm_lines: + out_file.write("{0} {1} {2} {3} {4}\n".format(line[0], line[1], line[2], line[3], " ".join(line[4:]))) + +if __name__ == "__main__": + usage = """ Python script to resolve overlaps in ctms """ + parser = argparse.ArgumentParser(usage) + parser.add_argument('segments', type=str, help = 'use segments to resolve overlaps') + parser.add_argument('ctm_in', type=str, help='input_ctm_file') + parser.add_argument('ctm_out', type=str, help='output_ctm_file') + params = parser.parse_args() + + if params.ctm_in == "-": + params.ctm_in = sys.stdin + else: + params.ctm_in = open(params.ctm_in) + if params.ctm_out == "-": + params.ctm_out = sys.stdout + else: + params.ctm_out = open(params.ctm_out, 'w') + + segments = ReadSegments(params.segments) + + # Read CTMs into a dictionary indexed by the recording + ctms = ReadCtm(params.ctm_in.readlines(), segments) + + for key in sorted(ctms.keys()): + # Process CTMs in the sorted order of recordings + ctm_reco = ctms[key] + ctm_reco = ResolveOverlaps(ctm_reco, segments) + WriteCtm(ctm_reco, params.ctm_out) + params.ctm_out.close() diff --git a/egs/wsj/s5/steps/resolve_ctm_overlaps.py.old b/egs/wsj/s5/steps/resolve_ctm_overlaps.py.old new file mode 100755 index 00000000000..aaee767e7e4 --- /dev/null +++ b/egs/wsj/s5/steps/resolve_ctm_overlaps.py.old @@ -0,0 +1,149 @@ +#!/usr/bin/env python +# Copyright 2014 Johns Hopkins University (Authors: Daniel Povey, Vijayaditya Peddinti). +# 2016 Vimal Manohar +# Apache 2.0. 
+ +# Script to combine ctms with overlapping segments + +import sys, math, numpy as np, argparse +break_threshold = 0.01 + +def ReadSegments(segments_file): + segments = {} + for line in open(segments_file).readlines(): + parts = line.strip().split() + segments[parts[0]] = (parts[1], float(parts[2]), float(parts[3])) + return segments + +#def get_breaks(ctm, prev_end): +# breaks = [] +# for i in xrange(0, len(ctm)): +# if ctm[i][2] - prev_end > break_threshold: +# breaks.append([i, ctm[i][2]]) +# prev_end = ctm[i][2] + ctm[i][3] +# return np.array(breaks) + +# Resolve overlaps within segments of the same recording +def ResolveOverlaps(ctms, segments): + total_ctm = [] + if len(ctms) == 0: + raise Exception('Something wrong with the input ctms') + + next_utt = ctms[0][0][0] + for ctm_index in range(len(ctms) - 1): + # Assumption here is that the segments are written in consecutive order? + cur_ctm = ctms[ctm_index] + next_ctm = ctms[ctm_index + 1] + + cur_utt = next_utt + next_utt = next_ctm[0][0] + if (next_utt not in segments): + raise Exception('Could not find utterance %s in segments' % next_utt) + + if len(cur_ctm) > 0: + assert(cur_utt == cur_ctm[0][0]) + + assert(next_utt > cur_utt) + if (cur_utt not in segments): + raise Exception('Could not find utterance %s in segments' % cur_utt) + + # length of this segment + window_length = segments[cur_utt][2] - segments[cur_utt][1] + + # overlap of this segment with the next segment + # Note: It is possible for this to be negative when there is actually + # no overlap between consecutive segments. + overlap = segments[cur_utt][2] - segments[next_utt][1] + + # find the breaks after overlap starts + index = len(cur_ctm) + + for i in xrange(len(cur_ctm)): + if (cur_ctm[i][2] + cur_ctm[i][3]/2.0 > (window_length - overlap/2.0)): + # if midpoint of a hypothesis word is beyond the midpoint of the + # overlap region + index = i + break + + # Ignore the hypotheses beyond this midpoint. 
They will be considered as + # part of the next segment. + total_ctm += cur_ctm[:index] + + # Ignore the hypotheses of the next utterance that overlaps with the + # current utterance + index = -1 + for i in xrange(len(next_ctm)): + if (next_ctm[i][2] + next_ctm[i][3]/2.0 > (overlap/2.0)): + index = i + break + + if index >= 0: + ctms[ctm_index + 1] = next_ctm[index:] + else: + ctms[ctm_index + 1] = [] + + # merge the last ctm entirely + total_ctm += ctms[-1] + + return total_ctm + +def ReadCtm(ctm_file_lines, segments): + ctms = {} + for key in [ x[0] for x in segments.values() ]: + ctms[key] = [] + + ctm = [] + prev_utt = ctm_file_lines[0].split()[0] + for line in ctm_file_lines: + parts = line.split() + if (prev_utt == parts[0]): + ctm.append([parts[0], parts[1], float(parts[2]), + float(parts[3])] + parts[4:]) + else: + # New utterance. Append the previous utterance's CTM + # into the list for the utterance's recording + ctms[segments[ctm[0][0]][0]].append(ctm) + + assert(parts[0] > prev_utt) + + prev_utt = parts[0] + ctm = [] + ctm.append([parts[0], parts[1], float(parts[2]), + float(parts[3])] + parts[4:]) + + # append the last ctm + ctms[segments[ctm[0][0]][0]].append(ctm) + return ctms + +def WriteCtm(ctm_lines, out_file): + for line in ctm_lines: + out_file.write("{0} {1} {2} {3} {4}\n".format(line[0], line[1], line[2], line[3], " ".join(line[4:]))) + +if __name__ == "__main__": + usage = """ Python script to resolve overlaps in ctms """ + parser = argparse.ArgumentParser(usage) + parser.add_argument('segments', type=str, help = 'use segments to resolve overlaps') + parser.add_argument('ctm_in', type=str, help='input_ctm_file') + parser.add_argument('ctm_out', type=str, help='output_ctm_file') + params = parser.parse_args() + + if params.ctm_in == "-": + params.ctm_in = sys.stdin + else: + params.ctm_in = open(params.ctm_in) + if params.ctm_out == "-": + params.ctm_out = sys.stdout + else: + params.ctm_out = open(params.ctm_out, 'w') + + segments = 
ReadSegments(params.segments) + + # Read CTMs into a dictionary indexed by the recording + ctms = ReadCtm(params.ctm_in.readlines(), segments) + + for key in sorted(ctms.keys()): + # Process CTMs in the sorted order of recordings + ctm_reco = ctms[key] + ctm_reco = ResolveOverlaps(ctm_reco, segments) + WriteCtm(ctm_reco, params.ctm_out) + params.ctm_out.close() diff --git a/egs/wsj/s5/steps/segmentation/cluster_segments_aIB.sh b/egs/wsj/s5/steps/segmentation/cluster_segments_aIB.sh new file mode 100755 index 00000000000..7cf151f1ad0 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/cluster_segments_aIB.sh @@ -0,0 +1,154 @@ +#! /bin/bash + +window=2.5 +overlap=0.0 +stage=-1 +cmd=queue.pl +reco_nj=4 +frame_shift=0.01 +utt_nj=18 +min_clusters=10 +clustering_opts="--stopping-threshold=0.5 --max-merge-thresh=0.25 --normalize-by-entropy" + +. path.sh +. utils/parse_options.sh + +set -o pipefail +set -e +set -u + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + exit 1 +fi + +data=$1 +dir=$2 +out_data=$3 + +num_frames=`perl -e "print int($window / $frame_shift + 0.5)"` +num_frames_overlap=`perl -e "print int($overlap/ $frame_shift + 0.5)"` + +data_uniform_seg=$dir/`basename ${data}`_uniform_seg_window${window}_ovlp${overlap} + +mkdir -p ${data_uniform_seg} + +mkdir -p $dir + +#segmentation-cluster-adjacent-segments --verbose=0 'ark:segmentation-copy --keep-label=1 "ark:gunzip -c exp/nnet3_lstm_sad_music/nnet_lstm_1e//segmentation_bn_eval97_whole_bp/orig_segmentation.1.gz |" ark:- | segmentation-split-segments --max-segment-length=250 --overlap-length=0 ark:- ark:- |' scp:data/bn_eval97_bp_hires/feats.scp "ark:| segmentation-post-process --merge-adjacent-segments ark:- ark:- | segmentation-to-segments ark:- ark,t:- /dev/null" 2>&1 | less + +if [ $stage -le 0 ]; then + $cmd $dir/log/get_subsegments.log \ + segmentation-init-from-segments --frame-overlap=0.015 $data/segments ark:- \| \ + segmentation-split-segments --max-segment-length=$num_frames 
--overlap-length=$num_frames_overlap ark:- ark:- \| \ + segmentation-to-segments --frame-overlap=0.0 ark:- ark:/dev/null \ + ${data_uniform_seg}/sub_segments + + utils/data/subsegment_data_dir.sh ${data} ${data_uniform_seg}{/sub_segments,} +fi + +gmm_dir=$dir/gmms +mkdir -p $gmm_dir + +utils/split_data.sh --per-reco ${data_uniform_seg} $reco_nj + +if [ $stage -le 1 ]; then + echo $reco_nj > $gmm_dir/num_jobs + $cmd JOB=1:$reco_nj $gmm_dir/log/train_gmm.JOB.log \ + gmm-global-init-models-from-feats --share-covars=true \ + --spk2utt-rspecifier=ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt \ + --num-gauss-init=64 --num-gauss=64 --num-gauss-fraction=0.001 --max-gauss=512 --min-gauss=64 \ + --num-iters=20 --num-frames=500000 \ + scp:${data_uniform_seg}/split${reco_nj}reco/JOB/feats.scp \ + ark,scp:$gmm_dir/gmm.JOB.ark,$gmm_dir/gmm.JOB.scp + + for n in `seq $reco_nj`; do + cat $gmm_dir/gmm.$n.scp + done > $gmm_dir/gmm.scp + +fi + +post_dir=$gmm_dir/post_`basename $data_uniform_seg` +mkdir -p $post_dir + +if [ $stage -le 2 ]; then + echo $reco_nj > $post_dir/num_jobs + + $cmd JOB=1:$reco_nj $gmm_dir/log/compute_post.JOB.log \ + gmm-global-get-post \ + --utt2spk="ark,t:cut -d ' ' -f 1,2 ${data_uniform_seg}/split${reco_nj}reco/JOB/segments |" \ + scp:$gmm_dir/gmm.scp \ + scp:${data_uniform_seg}/split${reco_nj}reco/JOB/feats.scp \ + "ark:| gzip -c > $post_dir/post.JOB.gz" \ + "ark:| gzip -c > $post_dir/frame_loglikes.JOB.gz" +fi + +if [ $stage -le 3 ]; then + utils/data/get_utt2num_frames.sh --nj $utt_nj --cmd "$cmd" ${data_uniform_seg} + + $cmd JOB=1:$reco_nj $post_dir/log/compute_average_post.JOB.log \ + gmm-global-post-to-feats \ + --utt2spk="ark,t:cut -d ' ' -f 1,2 ${data_uniform_seg}/split${reco_nj}reco/JOB/segments |" \ + scp:$gmm_dir/gmm.scp "ark:gunzip -c $post_dir/post.JOB.gz |" ark:- \| \ + matrix-sum-rows --do-average ark:- "ark:| gzip -c > $post_dir/avg_post.JOB.gz" +fi + +seg_dir=$dir/segmentation_`basename $data_uniform_seg` + +if [ $stage -le 4 ]; 
then + $cmd JOB=1:$reco_nj $seg_dir/log/compute_scores.JOB.log \ + ib-scoring-dense --input-factor=0.0 $clustering_opts \ + --counts-rspecifier="ark,t:utils/filter_scp.pl $data_uniform_seg/split${reco_nj}reco/JOB/utt2spk $data_uniform_seg/utt2num_frames |" \ + "ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt" \ + "ark:gunzip -c $post_dir/avg_post.JOB.gz |" \ + ark,t:$seg_dir/scores.JOB.txt ark:/dev/null +fi + +if [ $stage -le 5 ]; then + threshold=$(for n in `seq $reco_nj`; do + /export/a12/vmanoha1/kaldi-diarization-v2/src/ivectorbin/compute-calibration \ + ark,t:$seg_dir/scores.$n.txt -; done | \ + awk '{i += $1; j++;} END{print i / j}') + echo $threshold > $seg_dir/threshold +fi + +threshold=$(cat $seg_dir/threshold) +if [ $stage -le 6 ]; then + $cmd JOB=1:$reco_nj $seg_dir/log/cluster_segments.JOB.log \ + agglomerative-cluster-ib --input-factor=0.0 --min-clusters=$min_clusters $clustering_opts \ + --max-merge-thresh=$threshold --verbose=3 \ + --counts-rspecifier="ark,t:utils/filter_scp.pl $data_uniform_seg/split${reco_nj}reco/JOB/utt2spk $data_uniform_seg/utt2num_frames |" \ + "ark:gunzip -c $post_dir/avg_post.JOB.gz |" \ + "ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt" \ + ark,t:$seg_dir/utt2cluster_id.JOB +fi + +if [ $stage -le 7 ]; then + $cmd JOB=1:$reco_nj $seg_dir/log/init_segmentation.JOB.log \ + segmentation-init-from-segments --frame-overlap=0.0 --shift-to-zero=false \ + --utt2label-rspecifier=ark,t:${seg_dir}/utt2cluster_id.JOB \ + ${data_uniform_seg}/split${reco_nj}reco/JOB/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- \ + ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt \ + ark:- \| \ + segmentation-post-process --merge-adjacent-segments ark:- ark:- \| \ + segmentation-post-process --max-segment-length=1000 --overlap-length=250 ark:- ark:- \| \ + segmentation-to-segments ark:- ark,t:$seg_dir/utt2spk.JOB $seg_dir/segments.JOB +fi + +if [ $stage -le 8 ]; then + rm -r $out_data || true + 
utils/data/convert_data_dir_to_whole.sh $data $out_data + rm $out_data/{text,cmvn.scp} || true + + for n in `seq $reco_nj`; do + cat $seg_dir/utt2spk.$n + done > $out_data/utt2spk + + for n in `seq $reco_nj`; do + cat $seg_dir/segments.$n + done > $out_data/segments + + utils/utt2spk_to_spk2utt.pl $out_data/utt2spk > $out_data/spk2utt + utils/fix_data_dir.sh $out_data +fi diff --git a/egs/wsj/s5/steps/segmentation/cluster_segments_aIB_change_point.sh b/egs/wsj/s5/steps/segmentation/cluster_segments_aIB_change_point.sh new file mode 100755 index 00000000000..9ca3efb7b9a --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/cluster_segments_aIB_change_point.sh @@ -0,0 +1,161 @@ +#! /bin/bash + +window=2.5 +overlap=0.0 +stage=-1 +cmd=queue.pl +reco_nj=4 +frame_shift=0.01 +frame_overlap=0.0 +utt_nj=18 +min_clusters=10 +clustering_opts="--stopping-threshold=0.5 --max-merge-thresh=0.25 --normalize-by-entropy" + +. path.sh +. utils/parse_options.sh + +set -o pipefail +set -e +set -u + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + exit 1 +fi + +data=$1 +dir=$2 +out_data=$3 + +num_frames=`perl -e "print int($window / $frame_shift + 0.5)"` +num_frames_overlap=`perl -e "print int($overlap/ $frame_shift + 0.5)"` + +data_id=`basename $data` +data_uniform_seg=$dir/${data_id}_uniform_seg_window${window}_ovlp${overlap} + +mkdir -p $dir + +#segmentation-cluster-adjacent-segments --verbose=0 'ark:segmentation-copy --keep-label=1 "ark:gunzip -c exp/nnet3_lstm_sad_music/nnet_lstm_1e//segmentation_bn_eval97_whole_bp/orig_segmentation.1.gz |" ark:- | segmentation-split-segments --max-segment-length=250 --overlap-length=0 ark:- ark:- |' scp:data/bn_eval97_bp_hires/feats.scp "ark:| segmentation-post-process --merge-adjacent-segments ark:- ark:- | segmentation-to-segments ark:- ark,t:- /dev/null" 2>&1 | less + +if [ $stage -le 0 ]; then + rm -r ${data_uniform_seg} || true + mkdir -p ${data_uniform_seg} + + $cmd $dir/log/get_subsegments.log \ + segmentation-init-from-segments 
--frame-overlap=$frame_overlap $data/segments ark:- \| \ + segmentation-split-segments --max-segment-length=$num_frames --overlap-length=$num_frames_overlap ark:- ark:- \| \ + segmentation-cluster-adjacent-segments --verbose=3 ark:- "scp:$data/feats.scp" ark:- \| \ + segmentation-post-process --merge-adjacent-segments ark:- ark:- \| \ + segmentation-to-segments --frame-overlap=0.0 ark:- ark:/dev/null \ + ${data_uniform_seg}/sub_segments + + utils/data/subsegment_data_dir.sh ${data} ${data_uniform_seg}{/sub_segments,} +fi + +gmm_dir=$dir/gmms +mkdir -p $gmm_dir + +utils/split_data.sh --per-reco ${data_uniform_seg} $reco_nj + +if [ $stage -le 1 ]; then + echo $reco_nj > $gmm_dir/num_jobs + $cmd JOB=1:$reco_nj $gmm_dir/log/train_gmm.JOB.log \ + gmm-global-init-models-from-feats --share-covars=true \ + --spk2utt-rspecifier=ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt \ + --num-gauss-init=64 --num-gauss=64 --num-gauss-fraction=0.001 --max-gauss=512 --min-gauss=64 \ + --num-iters=20 --num-frames=500000 \ + scp:${data_uniform_seg}/split${reco_nj}reco/JOB/feats.scp \ + ark,scp:$gmm_dir/gmm.JOB.ark,$gmm_dir/gmm.JOB.scp + + for n in `seq $reco_nj`; do + cat $gmm_dir/gmm.$n.scp + done > $gmm_dir/gmm.scp + +fi + +post_dir=$gmm_dir/post_`basename $data_uniform_seg` +mkdir -p $post_dir + +if [ $stage -le 2 ]; then + echo $reco_nj > $post_dir/num_jobs + + $cmd JOB=1:$reco_nj $gmm_dir/log/compute_post.JOB.log \ + gmm-global-get-post \ + --utt2spk="ark,t:cut -d ' ' -f 1,2 ${data_uniform_seg}/split${reco_nj}reco/JOB/segments |" \ + scp:$gmm_dir/gmm.scp \ + scp:${data_uniform_seg}/split${reco_nj}reco/JOB/feats.scp \ + "ark:| gzip -c > $post_dir/post.JOB.gz" \ + "ark:| gzip -c > $post_dir/frame_loglikes.JOB.gz" +fi + +if [ $stage -le 3 ]; then + $cmd JOB=1:$reco_nj $post_dir/log/compute_average_post.JOB.log \ + gmm-global-post-to-feats \ + --utt2spk="ark,t:cut -d ' ' -f 1,2 ${data_uniform_seg}/split${reco_nj}reco/JOB/segments |" \ + scp:$gmm_dir/gmm.scp "ark:gunzip -c 
$post_dir/post.JOB.gz |" ark:- \| \ + matrix-sum-rows --do-average ark:- "ark:| gzip -c > $post_dir/avg_post.JOB.gz" +fi + +seg_dir=$dir/segmentation_`basename $data_uniform_seg` + +if [ $stage -le 4 ]; then + utils/data/get_utt2num_frames.sh --nj $utt_nj --cmd "$cmd" ${data_uniform_seg} + + $cmd JOB=1:$reco_nj $seg_dir/log/compute_scores.JOB.log \ + ib-scoring-dense --input-factor=0 $clustering_opts \ + --counts-rspecifier="ark,t:utils/filter_scp.pl $data_uniform_seg/split${reco_nj}reco/JOB/utt2spk $data_uniform_seg/utt2num_frames |" \ + "ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt" \ + "ark:gunzip -c $post_dir/avg_post.JOB.gz |" \ + ark,t:$seg_dir/scores.JOB.txt ark:/dev/null +fi + +if [ $stage -le 5 ]; then + $cmd JOB=1:$reco_nj $seg_dir/log/calibrate.JOB.log \ + /export/a12/vmanoha1/kaldi-diarization-v2/src/ivectorbin/compute-calibration \ + ark,t:$seg_dir/scores.JOB.txt $seg_dir/threshold.JOB.txt + + threshold=$(for n in `seq $reco_nj`; do cat $seg_dir/threshold.$n.txt; done | \ + awk '{i += $1; j++;} END{print i / j}') + echo $threshold > $seg_dir/threshold +fi + +threshold=$(cat $seg_dir/threshold) +if [ $stage -le 6 ]; then + $cmd JOB=1:$reco_nj $seg_dir/log/cluster_segments.JOB.log \ + agglomerative-cluster-ib --input-factor=0.0 $clustering_opts \ + --max-merge-thresh=$threshold --verbose=3 \ + --counts-rspecifier="ark,t:utils/filter_scp.pl $data_uniform_seg/split${reco_nj}reco/JOB/utt2spk $data_uniform_seg/utt2num_frames |" \ + "ark:gunzip -c $post_dir/avg_post.JOB.gz |" \ + "ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt" \ + ark,t:$seg_dir/utt2cluster_id.JOB +fi + +if [ $stage -le 7 ]; then + $cmd JOB=1:$reco_nj $seg_dir/log/init_segmentation.JOB.log \ + segmentation-init-from-segments --frame-overlap=0.0 --shift-to-zero=false \ + --utt2label-rspecifier=ark,t:${seg_dir}/utt2cluster_id.JOB \ + ${data_uniform_seg}/split${reco_nj}reco/JOB/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- \ + 
ark,t:${data_uniform_seg}/split${reco_nj}reco/JOB/reco2utt \ + ark:- \| \ + segmentation-post-process --merge-adjacent-segments ark:- ark:- \| \ + segmentation-post-process --max-segment-length=1000 --overlap-length=250 ark:- ark:- \| \ + segmentation-to-segments ark:- ark,t:$seg_dir/utt2spk.JOB $seg_dir/segments.JOB +fi + +if [ $stage -le 8 ]; then + rm -r $out_data || true + utils/data/convert_data_dir_to_whole.sh $data $out_data + rm $out_data/{text,cmvn.scp} || true + + for n in `seq $reco_nj`; do + cat $seg_dir/utt2spk.$n + done > $out_data/utt2spk + + for n in `seq $reco_nj`; do + cat $seg_dir/segments.$n + done > $out_data/segments + + utils/utt2spk_to_spk2utt.pl $out_data/utt2spk > $out_data/spk2utt + utils/fix_data_dir.sh $out_data +fi diff --git a/egs/wsj/s5/steps/segmentation/convert_ali_to_vec.pl b/egs/wsj/s5/steps/segmentation/convert_ali_to_vec.pl new file mode 100755 index 00000000000..c0d1a9eeae2 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/convert_ali_to_vec.pl @@ -0,0 +1,17 @@ +#! /usr/bin/perl + +# Converts a kaldi integer vector in text format to +# a kaldi vector in text format by adding a pair +# of square brackets around the data. +# Assumes the first column to be the utterance id. + +while (<>) { + chomp; + my @F = split; + + printf ("$F[0] [ "); + for (my $i = 1; $i <= $#F; $i++) { + printf ("$F[$i] "); + } + print ("]"); +} diff --git a/egs/wsj/s5/steps/segmentation/convert_rttm_to_utt2spk_and_segments.py b/egs/wsj/s5/steps/segmentation/convert_rttm_to_utt2spk_and_segments.py new file mode 100755 index 00000000000..23dc5a14f09 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/convert_rttm_to_utt2spk_and_segments.py @@ -0,0 +1,79 @@ +#! 
/usr/bin/env python + +"""This script converts an RTTM with +speaker info into kaldi utt2spk and segments""" + +import argparse + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script converts an RTTM with + speaker info into kaldi utt2spk and segments""") + parser.add_argument("--use-reco-id-as-spkr", type=str, + choices=["true", "false"], + help="Use the recording ID based on RTTM and " + "reco2file_and_channel as the speaker") + parser.add_argument("rttm_file", type=str, + help="""Input RTTM file. + The format of the RTTM file is + """ + """ """) + parser.add_argument("reco2file_and_channel", type=str, + help="""Input reco2file_and_channel. + The format is .""") + parser.add_argument("utt2spk", type=str, + help="Output utt2spk file") + parser.add_argument("segments", type=str, + help="Output segments file") + + args = parser.parse_args() + + args.use_reco_id_as_spkr = bool(args.use_reco_id_as_spkr == "true") + + return args + +def main(): + args = get_args() + + file_and_channel2reco = {} + for line in open(args.reco2file_and_channel): + parts = line.strip().split() + file_and_channel2reco[(parts[1], parts[2])] = parts[0] + + utt2spk_writer = open(args.utt2spk, 'w') + segments_writer = open(args.segments, 'w') + for line in open(args.rttm_file): + parts = line.strip().split() + if parts[0] != "SPEAKER": + continue + + file_id = parts[1] + channel = parts[2] + + try: + reco = file_and_channel2reco[(file_id, channel)] + except KeyError as e: + raise Exception("Could not find recording with " + "(file_id, channel) " + "= ({0},{1}) in {2}: {3}\n".format( + file_id, channel, + args.reco2file_and_channel, str(e))) + + start_time = float(parts[3]) + end_time = start_time + float(parts[4]) + + if args.use_reco_id_as_spkr: + spkr = reco + else: + spkr = parts[7] + + st = int(start_time * 100) + end = int(end_time * 100) + utt = "{0}-{1:06d}-{2:06d}".format(spkr, st, end) + + utt2spk_writer.write("{0} {1}\n".format(utt, spkr)) + 
segments_writer.write("{0} {1} {2:7.2f} {3:7.2f}\n".format( + utt, reco, start_time, end_time)) + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/steps/segmentation/convert_utt2spk_and_segments_to_rttm.py b/egs/wsj/s5/steps/segmentation/convert_utt2spk_and_segments_to_rttm.py new file mode 100755 index 00000000000..1443259286b --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/convert_utt2spk_and_segments_to_rttm.py @@ -0,0 +1,65 @@ +#! /usr/bin/env python + +"""This script converts kaldi-style utt2spk and segments to an RTTM""" + +import argparse + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script converts kaldi-style utt2spk and + segments to an RTTM""") + + parser.add_argument("utt2spk", type=str, + help="Input utt2spk file") + parser.add_argument("segments", type=str, + help="Input segments file") + parser.add_argument("reco2file_and_channel", type=str, + help="""Input reco2file_and_channel. + The format is .""") + parser.add_argument("rttm_file", type=str, + help="Output RTTM file") + + args = parser.parse_args() + return args + +def main(): + args = get_args() + + reco2file_and_channel = {} + for line in open(args.reco2file_and_channel): + parts = line.strip().split() + reco2file_and_channel[parts[0]] = (parts[1], parts[2]) + + utt2spk = {} + with open(args.utt2spk, 'r') as utt2spk_reader: + for line in utt2spk_reader: + parts = line.strip().split() + utt2spk[parts[0]] = parts[1] + + with open(args.rttm_file, 'w') as rttm_writer: + for line in open(args.segments, 'r'): + parts = line.strip().split() + + utt = parts[0] + spkr = utt2spk[utt] + + reco = parts[1] + + try: + file_id, channel = reco2file_and_channel[reco] + except KeyError as e: + raise Exception("Could not find recording {0} in {1}: " + "{2}\n".format(reco, + args.reco2file_and_channel, + str(e))) + + start_time = float(parts[2]) + duration = float(parts[3]) - start_time + + rttm_writer.write("SPEAKER {0} {1} {2:7.2f} {3:7.2f} " + " {4} \n".format( + 
file_id, channel, start_time, + duration, spkr)) + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/steps/segmentation/decode_sad.sh b/egs/wsj/s5/steps/segmentation/decode_sad.sh new file mode 100755 index 00000000000..a39e93dd83f --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/decode_sad.sh @@ -0,0 +1,56 @@ +#! /bin/bash + +set -e +set -o pipefail + +cmd=run.pl +acwt=0.1 +beam=8 +max_active=1000 +get_pdfs=false +iter=final + +. path.sh + +. utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + echo " e.g.: $0 " + exit 1 +fi + +graph_dir=$1 +log_likes_dir=$2 +dir=$3 + +mkdir -p $dir +nj=`cat $log_likes_dir/num_jobs` +echo $nj > $dir/num_jobs + +if [ -f $dir/$iter.mdl ]; then + srcdir=$dir +else + srcdir=`dirname $dir` +fi + +for f in $srcdir/$iter.mdl $log_likes_dir/log_likes.1.gz $graph_dir/HCLG.fst; do + if [ ! -f $f ]; then + echo "$0: Could not find file $f" + exit 1 + fi +done + +decoder_opts+=(--acoustic-scale=$acwt --beam=$beam --max-active=$max_active) + +ali="ark:| ali-to-phones --per-frame $srcdir/$iter.mdl ark:- ark:- | gzip -c > $dir/ali.JOB.gz" + +if $get_pdfs; then + ali="ark:| ali-to-pdf $srcdir/$iter.mdl ark:- ark:- | gzip -c > $dir/ali.JOB.gz" +fi + +$cmd JOB=1:$nj $dir/log/decode.JOB.log \ + decode-faster-mapped ${decoder_opts[@]} \ + $srcdir/$iter.mdl \ + $graph_dir/HCLG.fst "ark:gunzip -c $log_likes_dir/log_likes.JOB.gz |" \ + ark:/dev/null "$ali" diff --git a/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh b/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh new file mode 100755 index 00000000000..84287230fba --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh @@ -0,0 +1,109 @@ +#! 
/bin/bash + +set -e +set -o pipefail +set -u + +stage=-1 +segmentation_config=conf/segmentation.conf +cmd=run.pl + +# Viterbi options +min_silence_duration=30 # minimum number of frames for silence +min_speech_duration=30 # minimum number of frames for speech +frame_subsampling_factor=1 +nonsil_transition_probability=0.1 +sil_transition_probability=0.1 +sil_prior=0.5 +speech_prior=0.5 +use_unigram_lm=true + +# Decoding options +acwt=1 +beam=10 +max_active=7000 + +. utils/parse_options.sh + +if [ $# -ne 4 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/babel_bengali_dev10h exp/nnet3_sad_snr/tdnn_b_n4/sad_babel_bengali_dev10h exp/nnet3_sad_snr/tdnn_b_n4/segmentation_babel_bengali_dev10h exp/nnet3_sad_snr/tdnn_b_n4/segmentation_babel_bengali_dev10h/babel_bengali_dev10h.seg" + exit 1 +fi + +data=$1 +sad_likes_dir=$2 +dir=$3 +out_data=$4 + +t=sil${sil_prior}_sp${speech_prior} +lang=$dir/lang_test_${t} + +min_silence_duration=`perl -e "print (int($min_silence_duration / $frame_subsampling_factor))"` +min_speech_duration=`perl -e "print (int($min_speech_duration / $frame_subsampling_factor))"` + +if [ $stage -le 1 ]; then + mkdir -p $lang + + steps/segmentation/internal/prepare_sad_lang.py \ + --phone-transition-parameters="--phone-list=1 --min-duration=$min_silence_duration --end-transition-probability=$sil_transition_probability" \ + --phone-transition-parameters="--phone-list=2 --min-duration=$min_speech_duration --end-transition-probability=$nonsil_transition_probability" $lang + + cp $lang/phones.txt $lang/words.txt +fi + +feat_dim=2 # dummy. We don't need this. 
+if [ $stage -le 2 ]; then + $cmd $dir/log/create_transition_model.log gmm-init-mono \ + $lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 +fi + +if [ $stage -le 3 ]; then + if $use_unigram_lm; then + cat > $lang/word2prior < $lang/G.fst + else + { + echo "1 0.99 1:0.6 2:0.39"; + echo "2 0.01 1:0.5 2:0.49"; + } | \ + steps/segmentation/internal/make_bigram_G_fst.py - - | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst + fi +fi + +graph_dir=$dir/graph_test_${t} + +if [ $stage -le 4 ]; then + $cmd $dir/log/make_vad_graph.log \ + steps/segmentation/internal/make_sad_graph.sh --iter trans \ + $lang $dir $dir/graph_test_${t} || exit 1 + cp $dir/trans.mdl $graph_dir +fi + +if [ $stage -le 5 ]; then + steps/segmentation/decode_sad.sh \ + --acwt $acwt --beam $beam --max-active $max_active --iter trans \ + $graph_dir $sad_likes_dir $dir +fi + +if [ $stage -le 6 ]; then + cat > $lang/phone2sad_map < 8kHz sampling frequency. +do_downsampling=false + +# Segmentation configs +min_silence_duration=30 +min_speech_duration=30 +sil_prior=0.5 +speech_prior=0.5 +segmentation_config=conf/segmentation_speech.conf +convert_data_dir_to_whole=true + +echo $* + +. utils/parse_options.sh + +if [ $# -ne 4 ]; then + echo "Usage: $0 " + echo " e.g.: $0 ~/workspace/egs/ami/s5b/data/sdm1/dev exp/nnet3_sad_snr/nnet_tdnn_j_n4 mfcc_hires_bp data/ami_sdm1_dev" + exit 1 +fi + +src_data_dir=$1 # The input data directory that needs to be segmented. + # Any segments in that will be ignored. 
+sad_nnet_dir=$2 # The SAD neural network +mfcc_dir=$3 # The directory to store the features +data_dir=$4 # The output data directory will be ${data_dir}_seg + +affix=${affix:+_$affix} +feat_affix=${feat_affix:+_$feat_affix} + +data_id=`basename $data_dir` +sad_dir=${sad_nnet_dir}/${sad_name}${affix}_${data_id}_whole${feat_affix} +seg_dir=${sad_nnet_dir}/${segmentation_name}${affix}_${data_id}_whole${feat_affix} + +export PATH="$KALDI_ROOT/tools/sph2pipe_v2.5/:$PATH" +[ ! -z `which sph2pipe` ] + +whole_data_dir=${sad_dir}/${data_id}_whole + +if $convert_data_dir_to_whole; then + if [ $stage -le 0 ]; then + utils/data/convert_data_dir_to_whole.sh $src_data_dir ${whole_data_dir} + + if $do_downsampling; then + freq=`cat $mfcc_config | perl -pe 's/\s*#.*//g' | grep "sample-frequency=" | awk -F'=' '{if (NF == 0) print 16000; else print $2}'` + utils/data/downsample_data_dir.sh $freq $whole_data_dir + fi + + utils/copy_data_dir.sh ${whole_data_dir} ${whole_data_dir}${feat_affix}_hires + fi + + if [ $stage -le 1 ]; then + steps/make_mfcc.sh --mfcc-config $mfcc_config --nj $reco_nj --cmd "$train_cmd" \ + ${whole_data_dir}${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} $mfcc_dir + steps/compute_cmvn_stats.sh ${whole_data_dir}${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} $mfcc_dir + utils/fix_data_dir.sh ${whole_data_dir}${feat_affix}_hires + fi + test_data_dir=${whole_data_dir}${feat_affix}_hires +else + test_data_dir=$src_data_dir +fi + +post_vec=$sad_nnet_dir/post_${output_name}.vec +if [ ! -f $sad_nnet_dir/post_${output_name}.vec ]; then + echo "$0: Could not find $sad_nnet_dir/post_${output_name}.vec. 
See the last stage of local/segmentation/run_train_sad.sh" + exit 1 +fi + +if [ $stage -le 2 ]; then + steps/nnet3/compute_output.sh --nj $reco_nj --cmd "$train_cmd" \ + --post-vec "$post_vec" \ + --iter $iter \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --frames-per-chunk 150 \ + --stage $sad_stage --output-name $output_name \ + --frame-subsampling-factor $frame_subsampling_factor \ + --use-raw-nnet true ${test_data_dir} $sad_nnet_dir $sad_dir +fi + +if [ $stage -le 3 ]; then + steps/segmentation/decode_sad_to_segments.sh \ + --use-unigram-lm false \ + --frame-subsampling-factor $frame_subsampling_factor \ + --min-silence-duration $min_silence_duration \ + --min-speech-duration $min_speech_duration \ + --sil-prior $sil_prior \ + --speech-prior $speech_prior \ + --segmentation-config $segmentation_config --cmd "$train_cmd" \ + ${test_data_dir} $sad_dir $seg_dir ${data_dir}_seg +fi + +# Subsegment data directory +if [ $stage -le 4 ]; then + rm ${data_dir}_seg/feats.scp || true + utils/data/get_reco2num_frames.sh --cmd "$train_cmd" --nj $reco_nj ${test_data_dir} + awk '{print $1" "$2}' ${data_dir}_seg/segments | \ + utils/apply_map.pl -f 2 ${test_data_dir}/reco2num_frames > \ + ${data_dir}_seg/utt2max_frames + + #frame_shift_info=`cat $mfcc_config | steps/segmentation/get_frame_shift_info_from_config.pl` + #utils/data/get_subsegment_feats.sh ${test_data_dir}/feats.scp \ + # $frame_shift_info ${data_dir}_seg/segments | \ + # utils/data/fix_subsegmented_feats.pl ${data_dir}_seg/utt2max_frames > \ + # ${data_dir}_seg/feats.scp + steps/compute_cmvn_stats.sh --fake ${data_dir}_seg + + utils/fix_data_dir.sh ${data_dir}_seg +fi diff --git a/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh new file mode 100755 index 00000000000..7211b6b7084 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/do_segmentation_data_dir_simple.sh @@ -0,0 +1,182 @@ 
+#!/bin/bash + +set -e +set -o pipefail +set -u + +. path.sh +. cmd.sh + +affix= # Affix for the segmentation +nj=32 # works on recordings as against on speakers + +# Feature options (Must match training) +mfcc_config=conf/mfcc_hires_bp.conf +feat_affix=bp # Affix for the type of feature used + +convert_data_dir_to_whole=true + +# Set to true if the test data has > 8kHz sampling frequency. +do_downsampling=false + +stage=-1 +sad_stage=-1 +output_name=output-speech # The output node in the network +sad_name=sad # Base name for the directory storing the computed loglikes +segmentation_name=segmentation # Base name for the directory doing segmentation + +# SAD network config +iter=final # Model iteration to use + +# Contexts must ideally match training for LSTM models, but +# may not necessarily for stats components +extra_left_context=0 # Set to some large value, typically 40 for LSTM (must match training) +extra_right_context=0 + +frame_subsampling_factor=1 # Subsampling at the output + +transition_scale=3.0 +loopscale=0.1 +acwt=1.0 + +# Segmentation configs +segmentation_config=conf/segmentation_speech.conf + +echo $* + +. utils/parse_options.sh + +if [ $# -ne 5 ]; then + echo "Usage: $0 " + echo " e.g.: $0 ~/workspace/egs/ami/s5b/data/sdm1/dev exp/nnet3_sad_snr/nnet_tdnn_j_n4 mfcc_hires_bp data/ami_sdm1_dev" + exit 1 +fi + +src_data_dir=$1 # The input data directory that needs to be segmented. + # Any segments in that will be ignored. +sad_nnet_dir=$2 # The SAD neural network +lang=$3 +mfcc_dir=$4 # The directory to store the features +data_dir=$5 # The output data directory will be ${data_dir}_seg + +affix=${affix:+_$affix} +feat_affix=${feat_affix:+_$feat_affix} + +data_id=`basename $data_dir` +sad_dir=${sad_nnet_dir}/${sad_name}${affix}_${data_id}_whole${feat_affix} +seg_dir=${sad_nnet_dir}/${segmentation_name}${affix}_${data_id}_whole${feat_affix} + +export PATH="$KALDI_ROOT/tools/sph2pipe_v2.5/:$PATH" +[ ! 
-z `which sph2pipe` ] + +test_data_dir=data/${data_id}${feat_affix}_hires + +if $convert_data_dir_to_whole; then + if [ $stage -le 0 ]; then + whole_data_dir=${sad_dir}/${data_id}_whole + utils/data/convert_data_dir_to_whole.sh $src_data_dir ${whole_data_dir} + + if $do_downsampling; then + freq=`cat $mfcc_config | perl -pe 's/\s*#.*//g' | grep "sample-frequency=" | awk -F'=' '{if (NF == 0) print 16000; else print $2}'` + utils/data/downsample_data_dir.sh $freq $whole_data_dir + fi + + rm -r ${test_data_dir} || true + utils/copy_data_dir.sh ${whole_data_dir} $test_data_dir + fi +else + if [ $stage -le 0 ]; then + rm -r ${test_data_dir} || true + utils/copy_data_dir.sh $src_data_dir $test_data_dir + + if $do_downsampling; then + freq=`cat $mfcc_config | perl -pe 's/\s*#.*//g' | grep "sample-frequency=" | awk -F'=' '{if (NF == 0) print 16000; else print $2}'` + utils/data/downsample_data_dir.sh $freq $test_data_dir + fi + fi +fi + +if [ $stage -le 1 ]; then + utils/fix_data_dir.sh $test_data_dir + steps/make_mfcc.sh --mfcc-config $mfcc_config --nj $nj --cmd "$train_cmd" \ + ${test_data_dir} exp/make_hires/${data_id}${feat_affix} $mfcc_dir + steps/compute_cmvn_stats.sh ${test_data_dir} exp/make_hires/${data_id}${feat_affix} $mfcc_dir + utils/fix_data_dir.sh ${test_data_dir} +fi + +post_vec=$sad_nnet_dir/post_${output_name}.vec +if [ ! -f $sad_nnet_dir/post_${output_name}.vec ]; then + echo "$0: Could not find $sad_nnet_dir/post_${output_name}.vec. See the last stage of local/segmentation/run_train_sad.sh" + exit 1 +fi + +create_topo=true +if $create_topo; then + if [ ! 
-f $lang/classes_info.txt ]; then + echo "$0: Could not find $lang/topo or $lang/classes_info.txt" + exit 1 + else + steps/segmentation/internal/prepare_simple_hmm_lang.py \ + $lang/classes_info.txt $lang + fi +fi + +if [ $stage -le 3 ]; then + simple-hmm-init $lang/topo $lang/init.mdl + + $train_cmd $sad_nnet_dir/log/get_final_${output_name}_model.log \ + nnet3-am-init $lang/init.mdl \ + "nnet3-copy --edits='rename-node old-name=$output_name new-name=output' $sad_nnet_dir/$iter.raw - |" - \| \ + nnet3-am-adjust-priors - $sad_nnet_dir/post_${output_name}.vec \ + $sad_nnet_dir/${iter}_${output_name}.mdl +fi +iter=${iter}_${output_name} + +if [ $stage -le 4 ]; then + steps/nnet3/compute_output.sh --nj $nj --cmd "$train_cmd" \ + --iter $iter --use-raw-nnet false \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --frames-per-chunk 150 \ + --stage $sad_stage \ + --frame-subsampling-factor $frame_subsampling_factor \ + ${test_data_dir} $sad_nnet_dir $sad_dir +fi + +graph_dir=${sad_nnet_dir}/graph_${output_name} + +if [ $stage -le 5 ]; then + cp -r $lang $graph_dir + + if [ ! -f $lang/final.mdl ]; then + echo "$0: Could not find $lang/final.mdl!" 
+ echo "$0: Using $lang/init.mdl instead" + cp $lang/init.mdl $graph_dir/final.mdl + else + cp $lang/final.mdl $graph_dir + fi + + $train_cmd $lang/log/make_graph.log \ + make-simple-hmm-graph --transition-scale=$transition_scale \ + --self-loop-scale=$loopscale \ + $graph_dir/final.mdl \| \ + fstdeterminizestar --use-log=true \| \ + fstrmepslocal \| \ + fstminimizeencoded '>' $graph_dir/HCLG.fst +fi + +if [ $stage -le 6 ]; then + steps/segmentation/decode_sad.sh --acwt 1.0 --cmd "$decode_cmd" \ + --iter ${iter} \ + --get-pdfs true $graph_dir $sad_dir $seg_dir +fi + +if [ $stage -le 7 ]; then + steps/segmentation/post_process_sad_to_subsegments.sh \ + --cmd "$train_cmd" --segmentation-config $segmentation_config \ + --frame-subsampling-factor $frame_subsampling_factor \ + ${test_data_dir} $lang/phone2sad_map ${seg_dir} \ + ${seg_dir} ${data_dir}_seg + + cp $src_data_dir/wav.scp ${data_dir}_seg +fi diff --git a/egs/wsj/s5/steps/segmentation/evaluate_segmentation.pl b/egs/wsj/s5/steps/segmentation/evaluate_segmentation.pl new file mode 100755 index 00000000000..06a762d7762 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/evaluate_segmentation.pl @@ -0,0 +1,198 @@ +#!/usr/bin/env perl + +# Copyright 2014 Johns Hopkins University (Author: Sanjeev Khudanpur), Vimal Manohar +# Apache 2.0 + +################################################################################ +# +# This script was written to check the goodness of automatic segmentation tools +# It assumes input in the form of two Kaldi segments files, i.e. a file each of +# whose lines contain four space-separated values: +# +# UtteranceID FileID StartTime EndTime +# +# It computes # missed frames, # false positives and # overlapping frames. 
+#
+################################################################################
+
+if ($#ARGV == 1) {
+  $ReferenceSegmentation = $ARGV[0];
+  $HypothesizedSegmentation = $ARGV[1];
+  printf STDERR ("Comparing reference segmentation\n\t%s\nwith proposed segmentation\n\t%s\n",
+                 $ReferenceSegmentation,
+                 $HypothesizedSegmentation);
+} else {
+  printf STDERR "This program compares the reference segmenation with the proposted segmentation\n";
+  printf STDERR "Usage: $0 reference_segments_filename proposed_segments_filename\n";
+  printf STDERR "e.g. $0 data/dev10h/segments data/dev10h.seg/segments\n";
+  exit (0);
+}
+
+################################################################################
+# First read the reference segmentation, and
+# store the start- and end-times of all segments in each file.
+################################################################################
+
+open (SEGMENTS, "cat $ReferenceSegmentation | sort -k2,2 -k3n,3 -k4n,4 |")
+  || die "Unable to open $ReferenceSegmentation";
+$numLines = 0;
+while ($line=<SEGMENTS>) {
+  chomp $line;
+  @field = split("[ \t]+", $line);
+  unless ($#field == 3) {
+    printf STDERR "Skipping unparseable line in file $ReferenceSegmentation\n\t$line\n";
+    exit (1);
+    next;
+  }
+  $fileID = $field[1];
+  unless (exists $firstSeg{$fileID}) {
+    $firstSeg{$fileID} = $numLines;
+    $actualSpeech{$fileID} = 0.0;
+    $hypothesizedSpeech{$fileID} = 0.0;
+    $foundSpeech{$fileID} = 0.0;
+    $falseAlarm{$fileID} = 0.0;
+    $minStartTime{$fileID} = 0.0;
+    $maxEndTime{$fileID} = 0.0;
+  }
+  $refSegName[$numLines] = $field[0];
+  $refSegStart[$numLines] = $field[2];
+  $refSegEnd[$numLines] = $field[3];
+  $actualSpeech{$fileID} += ($field[3]-$field[2]);
+  $minStartTime{$fileID} = $field[2] if ($minStartTime{$fileID}>$field[2]);
+  $maxEndTime{$fileID} = $field[3] if ($maxEndTime{$fileID}<$field[3]);
+  $lastSeg{$fileID} = $numLines;
+  ++$numLines;
+}
+close(SEGMENTS);
+print STDERR "Read $numLines segments from $ReferenceSegmentation\n";
+
+################################################################################ +# Process hypothesized segments sequentially, and gather speech/nonspeech stats +################################################################################ + +open (SEGMENTS, "cat $HypothesizedSegmentation | sort -k2,2 -k1,1 |") + # Kaldi segments files are sorted by UtteranceID, but we re-sort them here + # so that all segments of a file are read together, sorted by start-time. + || die "Unable to open $HypothesizedSegmentation"; +$numLines = 0; +$totalHypSpeech = 0.0; +$totalFoundSpeech = 0.0; +$totalFalseAlarm = 0.0; +$numShortSegs = 0; +$numLongSegs = 0; +while ($line=) { + chomp $line; + @field = split("[ \t]+", $line); + unless ($#field == 3) { + exit (1); + printf STDERR "Skipping unparseable line in file $HypothesizedSegmentation\n\t$line\n"; + next; + } + $fileID = $field[1]; + $segStart = $field[2]; + $segEnd = $field[3]; + if (exists $firstSeg{$fileID}) { + # This FileID exists in the reference segmentation + # So gather statistics for this UtteranceID + $hypothesizedSpeech{$fileID} += ($segEnd-$segStart); + $totalHypSpeech += ($segEnd-$segStart); + if (($segStart>=$maxEndTime{$fileID}) || ($segEnd<=$minStartTime{$fileID})) { + # This entire segment is a false alarm + $falseAlarm{$fileID} += ($segEnd-$segStart); + $totalFalseAlarm += ($segEnd-$segStart); + } else { + # This segment may overlap one or more reference segments + $p = $firstSeg{$fileID}; + while ($refSegEnd[$p]<=$segStart) { + ++$p; + } + # The overlap, if any, begins at the reference segment p + $q = $lastSeg{$fileID}; + while ($refSegStart[$q]>=$segEnd) { + --$q; + } + # The overlap, if any, ends at the reference segment q + if ($q<$p) { + # This segment sits entirely in the nonspeech region + # between the two reference speech segments q and p + $falseAlarm{$fileID} += ($segEnd-$segStart); + $totalFalseAlarm += ($segEnd-$segStart); + } else { + if (($segEnd-$segStart)<0.20) { + # For diagnosing 
Pascal's VAD segmentation + print STDOUT "Found short speech region $line\n"; + ++$numShortSegs; + } elsif (($segEnd-$segStart)>60.0) { + ++$numLongSegs; + # For diagnosing Pascal's VAD segmentation + print STDOUT "Found long speech region $line\n"; + } + # There is some overlap with segments p through q + for ($s=$p; $s<=$q; ++$s) { + if ($segStart<$refSegStart[$s]) { + # There is a leading false alarm portion before s + $falseAlarm{$fileID} += ($refSegStart[$s]-$segStart); + $totalFalseAlarm += ($refSegStart[$s]-$segStart); + $segStart=$refSegStart[$s]; + } + $speechPortion = ($refSegEnd[$s]<$segEnd) ? + ($refSegEnd[$s]-$segStart) : ($segEnd-$segStart); + $foundSpeech{$fileID} += $speechPortion; + $totalFoundSpeech += $speechPortion; + $segStart=$refSegEnd[$s]; + } + if ($segEnd>$segStart) { + # There is a trailing false alarm portion after q + $falseAlarm{$fileID} += ($segEnd-$segStart); + $totalFalseAlarm += ($segEnd-$segStart); + } + } + } + } else { + # This FileID does not exist in the reference segmentation + # So all this speech counts as a false alarm + exit (1); + printf STDERR ("Unexpected fileID in hypothesized segments: %s", $fileID); + $totalFalseAlarm += ($segEnd-$segStart); + } + ++$numLines; +} +close(SEGMENTS); +print STDERR "Read $numLines segments from $HypothesizedSegmentation\n"; + +################################################################################ +# Now that all hypothesized segments have been processed, compute needed stats +################################################################################ + +$totalActualSpeech = 0.0; +$totalNonSpeechEst = 0.0; # This is just a crude estimate of total nonspeech. 
+foreach $fileID (sort keys %actualSpeech) { + $totalActualSpeech += $actualSpeech{$fileID}; + $totalNonSpeechEst += $maxEndTime{$fileID} - $actualSpeech{$fileID}; + ####################################################################### + # Print file-wise statistics to STDOUT; can pipe to /dev/null is needed + ####################################################################### + printf STDOUT ("%s: %.2f min actual speech, %.2f min hypothesized: %.2f min overlap (%d\%), %.2f min false alarm (~%d\%)\n", + $fileID, + ($actualSpeech{$fileID}/60.0), + ($hypothesizedSpeech{$fileID}/60.0), + ($foundSpeech{$fileID}/60.0), + ($foundSpeech{$fileID}*100/($actualSpeech{$fileID}+0.01)), + ($falseAlarm{$fileID}/60.0), + ($falseAlarm{$fileID}*100/($maxEndTime{$fileID}-$actualSpeech{$fileID}+0.01))); +} + +################################################################################ +# Finally, we have everything needed to report the segmentation statistics. +################################################################################ + +printf STDERR ("------------------------------------------------------------------------\n"); +printf STDERR ("TOTAL: %.2f hrs actual speech, %.2f hrs hypothesized: %.2f hrs overlap (%d\%), %.2f hrs false alarm (~%d\%)\n", + ($totalActualSpeech/3600.0), + ($totalHypSpeech/3600.0), + ($totalFoundSpeech/3600.0), + ($totalFoundSpeech*100/($totalActualSpeech+0.000001)), + ($totalFalseAlarm/3600.0), + ($totalFalseAlarm*100/($totalNonSpeechEst+0.000001))); +printf STDERR ("\t$numShortSegs segments < 0.2 sec and $numLongSegs segments > 60.0 sec\n"); +printf STDERR ("------------------------------------------------------------------------\n"); diff --git a/egs/wsj/s5/steps/segmentation/get_frame_shift_info_from_config.pl b/egs/wsj/s5/steps/segmentation/get_frame_shift_info_from_config.pl new file mode 100755 index 00000000000..79a42aa9852 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/get_frame_shift_info_from_config.pl @@ -0,0 +1,21 @@ 
+#! /usr/bin/perl
+use strict;
+use warnings;
+
+# This script parses a features config file such as conf/mfcc.conf
+# and returns the pair of values frame_shift and frame_overlap in seconds.
+
+my $frame_shift = 0.01;
+my $frame_overlap = 0.015;
+
+while (<>) {
+  if (m/--frame-shift=(\d+)/) {
+    $frame_shift = $1 / 1000;
+  }
+
+  if (m/--frame-length=(\d+)/) {
+    $frame_overlap = $1 / 1000 - $frame_shift;
+  }
+}
+
+print "$frame_shift $frame_overlap\n";
diff --git a/egs/wsj/s5/steps/segmentation/get_reverb_scp.pl b/egs/wsj/s5/steps/segmentation/get_reverb_scp.pl
new file mode 100755
index 00000000000..57f63b517f2
--- /dev/null
+++ b/egs/wsj/s5/steps/segmentation/get_reverb_scp.pl
@@ -0,0 +1,58 @@
+#! /usr/bin/perl
+use strict;
+use warnings;
+
+my $field_begin = -1;
+my $field_end = -1;
+
+if ($ARGV[0] eq "-f") {
+  shift @ARGV;
+  my $field_spec = shift @ARGV;
+  if ($field_spec =~ m/^\d+$/) {
+    $field_begin = $field_spec - 1; $field_end = $field_spec - 1;
+  }
+  if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10)
+    if ($1 ne "") {
+      $field_begin = $1 - 1; # Change to zero-based indexing.
+    }
+    if ($2 ne "") {
+      $field_end = $2 - 1; # Change to zero-based indexing.
+    }
+  }
+  if (!defined $field_begin && !defined $field_end) {
+    die "Bad argument to -f option: $field_spec";
+  }
+}
+
+if (scalar @ARGV != 1 && scalar @ARGV != 2 ) {
+  print "Usage: get_reverb_scp.pl [-f <field-range>] <num-reps> [<prefix>] < input_scp > output_scp\n";
+  exit(1);
+}
+
+my $num_reps = $ARGV[0];
+my $prefix = "rev";
+
+if (scalar @ARGV == 2) {
+  $prefix = $ARGV[1];
+}
+
+while (<STDIN>) {
+  chomp;
+  my @A = split;
+
+  for (my $i = 1; $i <= $num_reps; $i++) {
+    for (my $pos = 0; $pos <= $#A; $pos++) {
+      my $a = $A[$pos];
+      if ( ($field_begin < 0 || $pos >= $field_begin)
+           && ($field_end < 0 || $pos <= $field_end) ) {
+        if ($a =~ m/^(sp[0-9.]+-)(.+)$/) {
+          $a = $1 . "$prefix" . $i . "_" . $2;
+        } else {
+          $a = "$prefix" . $i . "_" . $a;
+        }
+      }
+      print $a . 
" "; + } + print "\n"; + } +} diff --git a/egs/wsj/s5/steps/segmentation/get_sad_map.py b/egs/wsj/s5/steps/segmentation/get_sad_map.py new file mode 100755 index 00000000000..222e6c1a512 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/get_sad_map.py @@ -0,0 +1,132 @@ +#! /usr/bin/env python + +"""This script prints a mapping from phones to speech +activity labels +0 for silence, 1 for speech, 2 for noise and 3 for OOV. +Other labels can be optionally defined. +e.g. If 1, 2 and 3 are silence phones, 4, 5 and 6 are speech phones, +the SAD map would be +1 0 +2 0 +3 0 +4 1 +5 1 +6 1. +The silence and speech are read from the phones/silence.txt and +phones/nonsilence.txt from the lang directory. +An initial SAD map can be provided using --init-sad-map to override +the above default mapping of phones. This is useful to say map + or noise phones to separate SAD labels. +""" + +import argparse +import sys + +sys.path.insert(0, 'steps') +import libs.common as common_lib + + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script prints a mapping from phones to speech + activity labels + 0 for silence, 1 for speech, 2 for noise and 3 for OOV. + Other labels can be optionally defined. + e.g. If 1, 2 and 3 are silence phones, 4, 5 and 6 are speech phones, + the SAD map would be + 1 0 + 2 0 + 3 0 + 4 1 + 5 1 + 6 1. + The silence and speech are read from the phones/silence.txt and + phones/nonsilence.txt from the lang directory. + An initial SAD map can be provided using --init-sad-map to override + the above default mapping of phones. This is useful to say map + or noise phones to separate SAD labels. + """) + + parser.add_argument("--init-sad-map", type=str, action=common_lib.NullstrToNoneAction, + help="""Initial SAD map that will be used to override + the default mapping using phones/silence.txt and + phones/nonsilence.txt. Does not need to specify labels + for all the phones. + e.g. 
+ 3 + 2""") + + noise_group = parser.add_mutually_exclusive_group() + noise_group.add_argument("--noise-phones-file", type=str, + action=common_lib.NullstrToNoneAction, + help="Map noise phones from file to label 2") + noise_group.add_argument("--noise-phones-list", type=str, + action=common_lib.NullstrToNoneAction, + help="A colon-separated list of noise phones to " + "map to label 2") + parser.add_argument("--unk", type=str, action=common_lib.NullstrToNoneAction, + help="""UNK phone, if provided will be mapped to + label 3""") + + parser.add_argument("--map-noise-to-sil", type=str, + action=common_lib.StrToBoolAction, + choices=["true", "false"], default=False, + help="""Map noise phones to silence before writing the + map. i.e. anything with label 2 is mapped to + label 0.""") + parser.add_argument("--map-unk-to-speech", type=str, + action=common_lib.StrToBoolAction, + choices=["true", "false"], default=False, + help="""Map UNK phone to speech before writing the map + i.e. anything with label 3 is mapped to label 1.""") + + parser.add_argument("lang_dir") + + args = parser.parse_args() + + return args + + +def main(): + args = get_args() + + sad_map = {} + + for line in open('{0}/phones/nonsilence.txt'.format(args.lang_dir)): + parts = line.strip().split() + sad_map[parts[0]] = 1 + + for line in open('{0}/phones/silence.txt'.format(args.lang_dir)): + parts = line.strip().split() + sad_map[parts[0]] = 0 + + if args.init_sad_map is not None: + for line in open(args.init_sad_map): + parts = line.strip().split() + try: + sad_map[parts[0]] = int(parts[1]) + except Exception: + raise Exception("Invalid line " + line) + + if args.unk is not None: + sad_map[args.unk] = 3 + + noise_phones = {} + if args.noise_phones_file is not None: + for line in open(args.noise_phones_file): + parts = line.strip().split() + noise_phones[parts[0]] = 1 + + if args.noise_phones_list is not None: + for x in args.noise_phones_list.split(":"): + noise_phones[x] = 1 + + for x, l in 
sad_map.iteritems(): + if l == 2 and args.map_noise_to_sil: + l = 0 + if l == 3 and args.map_unk_to_speech: + l = 1 + print ("{0} {1}".format(x, l)) + +if __name__ == "__main__": + main() diff --git a/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh b/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh new file mode 100755 index 00000000000..0d8939a9b80 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh @@ -0,0 +1,59 @@ +#! /bin/bash + +set -o pipefail +set -e +set -u + +. path.sh + +cmd=run.pl + +. parse_options.sh + +if [ $# -ne 3 ]; then + echo "This script converts the alignment in the alignment directory " + echo "to speech activity segments based on the provided phone-map." + echo "The output is stored in sad_seg.*.ark along with an scp-file " + echo "sad_seg.scp in Segmentation format.\n" + echo "If alignment directory has frame_subsampling_factor, the segments " + echo "are applied that frame-subsampling-factor.\n" + echo "The phone-map file must have two columns: " + echo " \n" + echo "\n" + echo "Usage: $0 " + echo "e.g. : $0 exp/tri3_ali data/lang/phones/sad.map exp/tri3_ali_vad" + exit 1 +fi + +ali_dir=$1 +phone_map=$2 +dir=$3 + +for f in $phone_map $ali_dir/ali.1.gz; do + [ ! 
-f $f ] && echo "$0: Could not find $f" && exit 1 +done + +mkdir -p $dir + +nj=`cat $ali_dir/num_jobs` || exit 1 +echo $nj > $dir/num_jobs + +frame_subsampling_factor=1 +if [ -f $ali_dir/frame_subsampling_factor ]; then + frame_subsampling_factor=`cat $ali_dir/frame_subsampling_factor` +fi + +dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $dir ${PWD}` + +$cmd JOB=1:$nj $dir/log/get_sad.JOB.log \ + segmentation-init-from-ali \ + "ark:gunzip -c ${ali_dir}/ali.JOB.gz | ali-to-phones --per-frame ${ali_dir}/final.mdl ark:- ark:- |" \ + ark:- \| \ + segmentation-copy --label-map=$phone_map \ + --frame-subsampling-factor=$frame_subsampling_factor ark:- ark:- \| \ + segmentation-post-process --merge-adjacent-segments ark:- \ + ark,scp:$dir/sad_seg.JOB.ark,$dir/sad_seg.JOB.scp + +for n in `seq $nj`; do + cat $dir/sad_seg.$n.scp +done | sort -k1,1 > $dir/sad_seg.scp diff --git a/egs/wsj/s5/steps/segmentation/internal/make_G_fst.py b/egs/wsj/s5/steps/segmentation/internal/make_G_fst.py new file mode 100755 index 00000000000..5ad7e867d10 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/make_G_fst.py @@ -0,0 +1,52 @@ +#! 
/usr/bin/env python + +from __future__ import print_function +import argparse, math + +def ParseArgs(): + parser = argparse.ArgumentParser("""Make a simple unigram FST for +decoding for segmentation purpose.""") + + parser.add_argument("--word2prior-map", type=str, required=True, + help = "A file with priors for different words") + parser.add_argument("--end-probability", type=float, default=0.01, + help = "Ending probability") + + args = parser.parse_args() + + return args + +def ReadMap(map_file): + out_map = {} + sum_prob = 0 + for line in open(map_file): + parts = line.strip().split() + if len(parts) == 0: + continue + if len(parts) != 2: + raise Exception("Invalid line {0} in {1}".format(line.strip(), map_file)) + + if parts[0] in out_map: + raise Exception("Duplicate entry of {0} in {1}".format(parts[0], map_file)) + + prob = float(parts[1]) + out_map[parts[0]] = prob + + sum_prob += prob + + return (out_map, sum_prob) + +def Main(): + args = ParseArgs() + + word2prior, sum_prob = ReadMap(args.word2prior_map) + sum_prob += args.end_probability + + for w,p in word2prior.iteritems(): + print ("0 0 {word} {word} {log_p}".format(word = w, + log_p = -math.log(p / sum_prob))) + print ("0 {log_p}".format(word = w, + log_p = -math.log(args.end_probability / sum_prob))) + +if __name__ == '__main__': + Main() diff --git a/egs/wsj/s5/steps/segmentation/internal/make_bigram_G_fst.py b/egs/wsj/s5/steps/segmentation/internal/make_bigram_G_fst.py new file mode 100755 index 00000000000..2431d293c4c --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/make_bigram_G_fst.py @@ -0,0 +1,174 @@ +#! 
/usr/bin/env python + +from __future__ import print_function +import argparse +import logging +import math + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [%(filename)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ] %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script generates a bigram G.fst lang for decoding. + It needs as an input classes_info file with the format: + , + where each pair is :. + destination-class -1 is used to represent final probabilitiy.""") + + parser.add_argument("classes_info", type=argparse.FileType('r'), + help="File with classes_info") + parser.add_argument("out_file", type=argparse.FileType('w'), + help="Output G.fst. Use '-' for stdout") + args = parser.parse_args() + return args + + +class ClassInfo(object): + def __init__(self, class_id): + self.class_id = class_id + self.start_state = -1 + self.initial_prob = 0 + self.transitions = {} + + def __str__(self): + return ("class-id={0},start-state={1}," + "initial-prob={2:.2f},transitions={3}".format( + self.class_id, self.start_state, + self.initial_prob, ' '.join( + ['{0}:{1}'.format(x, y) + for x, y in self.transitions.iteritems()]))) + + +def read_classes_info(file_handle): + classes_info = {} + + num_states = 1 + num_classes = 0 + + for line in file_handle.readlines(): + try: + parts = line.split() + class_id = int(parts[0]) + assert class_id > 0, class_id + if class_id in classes_info: + raise RuntimeError( + "Duplicate class-id {0} in file {1}".format( + class_id, file_handle.name)) + + classes_info[class_id] = ClassInfo(class_id) + class_info = classes_info[class_id] + class_info.initial_prob = float(parts[1]) + class_info.start_state = num_states + num_states += 1 + num_classes += 1 + + total_prob = 0.0 + if len(parts) > 2: + for part in 
parts[2:]: + dest_class, transition_prob = part.split(':') + dest_class = int(dest_class) + total_prob += float(transition_prob) + + if total_prob > 1.0: + raise ValueError("total-probability out of class {0} " + "is {1} > 1.0".format(class_id, + total_prob)) + + if dest_class in class_info.transitions: + logger.error( + "Duplicate transition to class-id {0}" + "in transitions".format(dest_class)) + raise RuntimeError + class_info.transitions[dest_class] = float(transition_prob) + + if -1 in class_info.transitions: + if abs(total_prob - 1.0) > 0.001: + raise ValueError("total-probability out of class {0} " + "is {1} != 1.0".format(class_id, + total_prob)) + else: + class_info.transitions[-1] = 1.0 - total_prob + else: + raise RuntimeError( + "No transitions out of class {0}".format(class_id)) + except Exception: + logger.error("Error processing line %s in file %s", + line, file_handle.name) + raise + + # Final state + classes_info[-1] = ClassInfo(-1) + class_info = classes_info[-1] + class_info.start_state = num_states + + for class_id, class_info in classes_info.iteritems(): + logger.info("For class %d, got class-info %s", class_id, class_info) + + return classes_info, num_classes + + +def print_states_for_class(class_id, classes_info, out_file): + class_info = classes_info[class_id] + + state = class_info.start_state + + # Transition from the FST initial state + print ("0 {end} {logprob}".format( + end=state, logprob=-math.log(class_info.initial_prob)), + file=out_file) + + for dest_class, prob in class_info.transitions.iteritems(): + try: + if dest_class == class_id: # self loop + next_state = state + else: # other transition + next_state = classes_info[dest_class].start_state + + print ("{start} {end} {class_id} {class_id} {logprob}".format( + start=state, end=next_state, class_id=class_id, + logprob=-math.log(prob)), + file=out_file) + + except Exception: + logger.error("Failed to add transition (%d->%d).\n" + "classes_info = %s", class_id, dest_class, + 
class_info) + + print ("{start} {final} {class_id} {class_id}".format( + start=state, final=classes_info[-1].start_state, + class_id=class_id), + file=out_file) + print ("{0}".format(classes_info[-1].start_state), file=out_file) + + +def run(args): + classes_info, num_classes = read_classes_info(args.classes_info) + + for class_id in range(1, num_classes + 1): + print_states_for_class(class_id, classes_info, args.out_file) + + +def main(): + args = get_args() + try: + run(args) + except Exception: + logger.error("Failed to make G.fst") + raise + finally: + for f in [args.classes_info, args.out_file]: + if f is not None: + f.close() + + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/steps/segmentation/internal/make_sad_graph.sh b/egs/wsj/s5/steps/segmentation/internal/make_sad_graph.sh new file mode 100755 index 00000000000..5edb3eb2bb6 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/make_sad_graph.sh @@ -0,0 +1,83 @@ +#!/bin/bash + +# Copyright 2016 Vimal Manohar + +# Begin configuration section. +stage=0 +cmd=run.pl +iter=final # use $iter.mdl from $model_dir +tree=tree +tscale=1.0 # transition scale. +loopscale=0.1 # scale for self-loops. +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: $0 [options] " + echo " e.g.: $0 exp/vad_dev/lang exp/vad_dev exp/vad_dev/graph" + echo "Makes the graph in \$dir, corresponding to the model in \$model_dir" + exit 1; +fi + +lang=$1 +model=$2/$iter.mdl +tree=$2/$tree +dir=$3 + +for f in $lang/G.fst $model $tree; do + if [ ! -f $f ]; then + echo "$0: expected $f to exist" + exit 1; + fi +done + +mkdir -p $dir $lang/tmp + +clg=$lang/tmp/CLG.fst + +if [[ ! -s $clg || $clg -ot $lang/G.fst ]]; then + echo "$0: creating CLG."
+ + fstcomposecontext --context-size=1 --central-position=0 \ + $lang/tmp/ilabels < $lang/G.fst | \ + fstarcsort --sort_type=ilabel > $clg + fstisstochastic $clg || echo "[info]: CLG not stochastic." +fi + +if [[ ! -s $dir/Ha.fst || $dir/Ha.fst -ot $model || $dir/Ha.fst -ot $lang/tmp/ilabels ]]; then + make-h-transducer --disambig-syms-out=$dir/disambig_tid.int \ + --transition-scale=$tscale $lang/tmp/ilabels $tree $model \ + > $dir/Ha.fst || exit 1; +fi + +if [[ ! -s $dir/HCLGa.fst || $dir/HCLGa.fst -ot $dir/Ha.fst || $dir/HCLGa.fst -ot $clg ]]; then + fsttablecompose $dir/Ha.fst $clg | fstdeterminizestar --use-log=true \ + | fstrmsymbols $dir/disambig_tid.int | fstrmepslocal | \ + fstminimizeencoded > $dir/HCLGa.fst || exit 1; + fstisstochastic $dir/HCLGa.fst || echo "HCLGa is not stochastic" +fi + +if [[ ! -s $dir/HCLG.fst || $dir/HCLG.fst -ot $dir/HCLGa.fst ]]; then + add-self-loops --self-loop-scale=$loopscale --reorder=true \ + $model < $dir/HCLGa.fst > $dir/HCLG.fst || exit 1; + + if [ $tscale == 1.0 -a $loopscale == 1.0 ]; then + # No point doing this test if transition-scale not 1, as it is bound to fail. + fstisstochastic $dir/HCLG.fst || echo "[info]: final HCLG is not stochastic." + fi +fi + +# keep a copy of the lexicon and a list of silence phones with HCLG... +# this means we can decode without reference to the $lang directory. + +cp $lang/words.txt $dir/ || exit 1; +cp $lang/phones.txt $dir/ 2> /dev/null # ignore the error if it's not there. + +# to make const fst: +# fstconvert --fst_type=const $dir/HCLG.fst $dir/HCLG_c.fst +am-info --print-args=false $model | grep pdfs | awk '{print $NF}' > $dir/num_pdfs + diff --git a/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh b/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh new file mode 100755 index 00000000000..31f0d09f351 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh @@ -0,0 +1,104 @@ +#! 
/bin/bash + +# Copyright 2015-16 Vimal Manohar +# Apache 2.0. + +set -e +set -o pipefail +set -u + +. path.sh + +cmd=run.pl +stage=-10 + +# General segmentation options +pad_length=50 # Pad speech segments by this many frames on either side +max_blend_length=10 # Maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +max_intersegment_length=50 # Merge nearby speech segments if the silence + # between them is less than this many frames. +post_pad_length=50 # Pad speech segments by this many frames on either side + # after the merging process using max_intersegment_length +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=100 # Overlapping frames when segments are split. + # See the above option. +min_silence_length=30 # Min silence length at which to split very long segments +min_segment_length=20 + +frame_shift=0.01 +frame_overlap=0.016 + +. utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "This script post-processes a speech activity segmentation to create " + echo "a kaldi-style data directory." + echo "See the comments for the kind of post-processing options." + echo "Usage: $0 " + echo " e.g.: $0 data/dev_aspire_whole exp/vad_dev_aspire data/dev_aspire_seg" + exit 1 +fi + +data_dir=$1 +dir=$2 +segmented_data_dir=$3 + +for f in $dir/orig_segmentation.1.gz; do + if [ ! 
-f $f ]; then + echo "$0: Could not find $f" + exit 1 + fi +done + +nj=`cat $dir/num_jobs` || exit 1 + +[ $pad_length -eq -1 ] && pad_length= +[ $post_pad_length -eq -1 ] && post_pad_length= +[ $max_blend_length -eq -1 ] && max_blend_length= + +if [ $stage -le 2 ]; then + # Post-process the orignal SAD segmentation using the following steps: + # 1) blend short speech segments of less than $max_blend_length frames + # into silence + # 2) Remove all silence frames and widen speech segments by padding + # $pad_length frames + # 3) Merge adjacent segments that have an intersegment length of less than + # $max_intersegment_length frames + # 4) Widen speech segments again after merging + # 5) Split segments into segments of $max_segment_length at the point where + # the original segmentation had silence + # 6) Split segments into overlapping segments of max length + # $max_segment_length and overlap $overlap_length + # 7) Convert segmentation to kaldi segments and utt2spk + $cmd JOB=1:$nj $dir/log/post_process_segmentation.JOB.log \ + gunzip -c $dir/orig_segmentation.JOB.gz \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=0 ark:- ark:- \| \ + segmentation-post-process ${max_blend_length:+--max-blend-length=$max_blend_length --blend-short-segments-class=1} ark:- ark:- \| \ + segmentation-post-process --remove-labels=0 ${pad_length:+--pad-label=1 --pad-length=$pad_length} ark:- ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=$max_intersegment_length ark:- ark:- \| \ + segmentation-post-process ${post_pad_length:+--pad-label=1 --pad-length=$post_pad_length} ark:- ark:- \| \ + segmentation-split-segments --alignments="ark,s,cs:gunzip -c $dir/orig_segmentation.JOB.gz | segmentation-to-ali ark:- ark:- |" \ + --max-segment-length=$max_segment_length --min-alignment-chunk-length=$min_silence_length --ali-label=0 ark:- ark:- \| \ + segmentation-post-process --remove-labels=1 
--max-remove-length=$min_segment_length ark:- ark:- \| \ + segmentation-split-segments \ + --max-segment-length=$max_segment_length --overlap-length=$overlap_length ark:- ark:- \| \ + segmentation-to-segments --frame-shift=$frame_shift \ + --frame-overlap=$frame_overlap ark:- \ + ark,t:$dir/utt2spk.JOB $dir/segments.JOB || exit 1 +fi + +for n in `seq $nj`; do + cat $dir/utt2spk.$n +done > $segmented_data_dir/utt2spk + +for n in `seq $nj`; do + cat $dir/segments.$n +done > $segmented_data_dir/segments + +if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ]; then + echo "$0: Segmentation failed to generate segments or utt2spk!" + exit 1 +fi diff --git a/egs/wsj/s5/steps/segmentation/internal/prepare_sad_lang.py b/egs/wsj/s5/steps/segmentation/internal/prepare_sad_lang.py new file mode 100755 index 00000000000..b539286a85b --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/prepare_sad_lang.py @@ -0,0 +1,121 @@ +#! /usr/bin/env python + +from __future__ import print_function +import argparse +import sys +import shlex + +sys.path.insert(0, 'steps') +import libs.common as common_lib + +def GetArgs(): + parser = argparse.ArgumentParser(description="""This script generates a lang +directory for purpose of segmentation. It takes as arguments the list of phones, +the corresponding min durations and end transition probability.""") + + parser.add_argument("--phone-transition-parameters", dest='phone_transition_para_array', + type=str, action='append', required=True, + help="Options to build topology. 
\n" + "--phone-list= # Colon-separated list of phones\n" + "--min-duration= # Min duration for the phones\n" + "--end-transition-probability= # Probability of the end transition after the minimum duration\n") + parser.add_argument("dir", type=str, + help="Output lang directory") + args = parser.parse_args() + return args + + +def ParsePhoneTransitionParameters(para_array): + parser = argparse.ArgumentParser() + + parser.add_argument("--phone-list", type=str, required=True, + help="Colon-separated list of phones") + parser.add_argument("--min-duration", type=int, default=3, + help="Minimum number of states for the phone") + parser.add_argument("--end-transition-probability", type=float, default=0.1, + help="Probability of the end transition after the minimum duration") + + phone_transition_parameters = [ parser.parse_args(shlex.split(x)) for x in para_array ] + + for t in phone_transition_parameters: + if (t.end_transition_probability > 1.0 or + t.end_transition_probability < 0.0): + raise ValueError("Expected --end-transition-probability to be " + "between 0 and 1, got {0} for phones {1}".format( + t.end_transition_probability, t.phone_list)) + if t.min_duration > 100 or t.min_duration < 1: + raise ValueError("Expected --min-duration to be " + "between 1 and 100, got {0} for phones {1}".format( + t.min_duration, t.phone_list)) + + t.phone_list = t.phone_list.split(":") + + return phone_transition_parameters + + +def get_phone_map(phone_transition_parameters): + phone2int = {} + n = 1 + for t in phone_transition_parameters: + for p in t.phone_list: + if p in phone2int: + raise Exception("Phone {0} found in multiple topologies".format(p)) + phone2int[p] = n + n += 1 + + return phone2int + + +def print_duration_constraint_states(min_duration, topo): + for state in range(0, min_duration - 1): + print(" {state} 0" + " {dest_state} 1.0 ".format( + state=state, dest_state=state + 1), + file=topo) + + +def print_topology(phone_transition_parameters, phone2int, args, topo): 
+ for t in phone_transition_parameters: + print ("", file=topo) + print ("", file=topo) + print ("{0}".format(" ".join([str(phone2int[p]) + for p in t.phone_list])), file=topo) + print ("", file=topo) + + print_duration_constraint_states(t.min_duration, topo) + + print(" {state} 0 " + " {state} {self_prob} " + " {next_state} {next_prob} ".format( + state=t.min_duration - 1, next_state=t.min_duration, + self_prob=1 - t.end_transition_probability, + next_prob=t.end_transition_probability), file=topo) + + print(" {state} ".format(state=t.min_duration), + file=topo) # Final state + print ("", file=topo) + + +def main(): + args = GetArgs() + phone_transition_parameters = ParsePhoneTransitionParameters(args.phone_transition_para_array) + + phone2int = get_phone_map(phone_transition_parameters) + + topo = open("{0}/topo".format(args.dir), 'w') + + print ("", file=topo) + + print_topology(phone_transition_parameters, phone2int, args, topo) + + print ("", file=topo) + + phones_file = open("{0}/phones.txt".format(args.dir), 'w') + + print (" 0", file=phones_file) + + for p,n in sorted(list(phone2int.items()), key=lambda x:x[1]): + print ("{0} {1}".format(p, n), file=phones_file) + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/steps/segmentation/internal/prepare_simple_hmm_lang.py b/egs/wsj/s5/steps/segmentation/internal/prepare_simple_hmm_lang.py new file mode 100755 index 00000000000..eae0f142668 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/prepare_simple_hmm_lang.py @@ -0,0 +1,202 @@ +#! 
/usr/bin/env python + +from __future__ import print_function +import argparse +import logging +import os +import sys + +sys.path.insert(0, 'steps') +import libs.common as common_lib + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [%(filename)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ] %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script generates a lang directory for decoding with + simple HMM model. + It needs as an input classes_info file with the + format: + , + where each pair is :. + destination-class -1 is used to represent final probabilitiy.""") + + parser.add_argument("classes_info", type=argparse.FileType('r'), + help="File with classes_info") + parser.add_argument("dir", type=str, + help="Output lang directory") + args = parser.parse_args() + return args + + +class ClassInfo(object): + def __init__(self, class_id): + self.class_id = class_id + self.start_state = -1 + self.num_states = 0 + self.initial_prob = 0 + self.self_loop_prob= 0 + self.transitions = {} + + def __str__(self): + return ("class-id={0},start-state={1},num-states={2}," + "initial-prob={3:.2f},transitions={4}".format( + self.class_id, self.start_state, self.num_states, + self.initial_prob, ' '.join( + ['{0}:{1}'.format(x,y) + for x,y in self.transitions.iteritems()]))) + + +def read_classes_info(file_handle): + classes_info = {} + + num_states = 1 + num_classes = 0 + + for line in file_handle.readlines(): + try: + parts = line.split() + class_id = int(parts[0]) + assert class_id > 0, class_id + if class_id in classes_info: + raise RuntimeError( + "Duplicate class-id {0} in file {1}".format( + class_id, file_handle.name)) + classes_info[class_id] = ClassInfo(class_id) + class_info = classes_info[class_id] + class_info.initial_prob = 
float(parts[1]) + class_info.self_loop_prob = float(parts[2]) + class_info.num_states = int(parts[3]) + class_info.start_state = num_states + num_states += class_info.num_states + num_classes += 1 + + if len(parts) > 4: + for part in parts[4:]: + dest_class, transition_prob = part.split(':') + dest_class = int(dest_class) + if dest_class in class_info.transitions: + logger.error( + "Duplicate transition to class-id {0}" + "in transitions".format(dest_class)) + raise RuntimeError + class_info.transitions[dest_class] = float(transition_prob) + else: + raise RuntimeError( + "No transitions out of class {0}".format(class_id)) + except Exception: + logger.error("Error processing line %s in file %s", + line, file_handle.name) + raise + + # Final state + classes_info[-1] = ClassInfo(-1) + class_info = classes_info[-1] + class_info.num_states = 1 + class_info.start_state = num_states + + for class_id, class_info in classes_info.iteritems(): + logger.info("For class %d, dot class-info %s", class_id, class_info) + + return classes_info, num_classes + + +def print_states_for_class(class_id, classes_info, topo): + class_info = classes_info[class_id] + + assert class_info.num_states > 1, class_info + + for state in range(class_info.start_state, + class_info.start_state + class_info.num_states - 1): + print(" {state} {pdf}" + " {dest_state} 1.0 ".format( + state=state, dest_state=state + 1, + pdf=class_info.class_id - 1), + file=topo) + + state = class_info.start_state + class_info.num_states - 1 + + transitions = [] + + transitions.append(" {next_state} {next_prob}".format( + next_state=state, next_prob=class_info.self_loop_prob)) + + for dest_class, prob in class_info.transitions.iteritems(): + try: + next_state = classes_info[dest_class].start_state + + transitions.append(" {next_state} {next_prob}".format( + next_state=next_state, next_prob=prob)) + except Exception: + logger.error("Failed to add transition (%d->%d).\n" + "classes_info = %s", class_id, dest_class, + 
class_info) + + print(" {state} {pdf} " + "{transitions} ".format( + state=state, pdf=class_id - 1, + transitions=' '.join(transitions)), file=topo) + + +def main(): + try: + args = get_args() + run(args) + except Exception: + logger.error("Failed preparing lang directory") + raise + + +def run(args): + if not os.path.exists(args.dir): + os.makedirs(args.dir) + + classes_info, num_classes = read_classes_info(args.classes_info) + + topo = open("{0}/topo".format(args.dir), 'w') + + print ("", file=topo) + print ("", file=topo) + print ("", file=topo) + print ("1", file=topo) + print ("", file=topo) + + # Print transitions from initial state (initial probs) + transitions = [] + for class_id in range(1, num_classes + 1): + class_info = classes_info[class_id] + transitions.append(" {next_state} {next_prob}".format( + next_state=class_info.start_state, + next_prob=class_info.initial_prob)) + print(" 0 {transitions} ".format( + transitions=' '.join(transitions)), file=topo) + + for class_id in range(1, num_classes + 1): + print_states_for_class(class_id, classes_info, topo) + + print(" {state} ".format( + state=classes_info[-1].start_state), file=topo) + + print ("", file=topo) + print ("", file=topo) + topo.close() + + with open('{0}/phones.txt'.format(args.dir), 'w') as phones_f: + for class_id in range(1, num_classes + 1): + print ("{0} {1}".format(class_id - 1, class_id), file=phones_f) + + common_lib.force_symlink('{0}/phones.txt'.format(args.dir), + '{0}/words.txt'.format(args.dir)) + + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/steps/segmentation/invert_vector.pl b/egs/wsj/s5/steps/segmentation/invert_vector.pl new file mode 100755 index 00000000000..c16243a0b93 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/invert_vector.pl @@ -0,0 +1,20 @@ +#! 
/usr/bin/perl +use strict; +use warnings; + +while () { + chomp; + my @F = split; + my $utt = shift @F; + shift @F; + + print "$utt [ "; + for (my $i = 0; $i < $#F; $i++) { + if ($F[$i] == 0) { + print "1 "; + } else { + print 1.0/$F[$i] . " "; + } + } + print "]\n"; +} diff --git a/egs/wsj/s5/steps/segmentation/make_snr_targets.sh b/egs/wsj/s5/steps/segmentation/make_snr_targets.sh new file mode 100755 index 00000000000..71f603a690e --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/make_snr_targets.sh @@ -0,0 +1,104 @@ +#!/bin/bash + +# Copyright 2015-16 Vimal Manohar +# Apache 2.0 +set -e +set -o pipefail + +nj=4 +cmd=run.pl +stage=0 + +data_id= + +compress=true +target_type=Irm +apply_exp=false + +ali_rspecifier= +silence_phones_str=0 + +ignore_noise_dir=false + +ceiling=inf +floor=-inf + +length_tolerance=2 +transform_matrix= + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 5 ]; then + echo "Usage: $0 [options] --target-type (Irm|Snr) "; + echo " or : $0 [options] --target-type FbankMask "; + echo "e.g.: $0 data/train_clean_fbank data/train_noise_fbank data/train_corrupted_hires exp/make_snr_targets/train snr_targets" + echo "options: " + echo " --nj # number of parallel jobs" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." 
+ exit 1; +fi + +clean_data=$1 +noise_or_noisy_data=$2 +data=$3 +tmpdir=$4 +targets_dir=$5 + +mkdir -p $targets_dir + +[ -z "$data_id" ] && data_id=`basename $data` + +utils/split_data.sh $clean_data $nj + +for n in `seq $nj`; do + utils/subset_data_dir.sh --utt-list $clean_data/split$nj/$n/utt2spk $noise_or_noisy_data $noise_or_noisy_data/subset${nj}/$n +done + +$ignore_noise_dir && utils/split_data.sh $data $nj + +targets_dir=`perl -e '($data,$pwd)= @ARGV; if($data!~m:^/:) { $data = "$pwd/$data"; } print $data; ' $targets_dir ${PWD}` + +for n in `seq $nj`; do + utils/create_data_link.pl $targets_dir/${data_id}.$n.ark +done + +apply_exp_opts= +if $apply_exp; then + apply_exp_opts=" copy-matrix --apply-exp=true ark:- ark:- |" +fi + +copy_feats_opts="copy-feats" +if [ ! -z "$transform_matrix" ]; then + copy_feats_opts="transform-feats $transform_matrix" +fi + +if [ $stage -le 1 ]; then + if ! $ignore_noise_dir; then + $cmd JOB=1:$nj $tmpdir/make_`basename $targets_dir`_${data_id}.JOB.log \ + compute-snr-targets --length-tolerance=$length_tolerance --target-type=$target_type \ + ${ali_rspecifier:+--ali-rspecifier="$ali_rspecifier" --silence-phones=$silence_phones_str} \ + --floor=$floor --ceiling=$ceiling \ + "ark:$copy_feats_opts scp:$clean_data/split$nj/JOB/feats.scp ark:- |" \ + "ark,s,cs:$copy_feats_opts scp:$noise_or_noisy_data/subset$nj/JOB/feats.scp ark:- |" \ + ark:- \|$apply_exp_opts \ + copy-feats --compress=$compress ark:- \ + ark,scp:$targets_dir/${data_id}.JOB.ark,$targets_dir/${data_id}.JOB.scp || exit 1 + else + feat_dim=$(feat-to-dim scp:$data/feats.scp -) || exit 1 + $cmd JOB=1:$nj $tmpdir/make_`basename $targets_dir`_${data_id}.JOB.log \ + compute-snr-targets --length-tolerance=$length_tolerance --target-type=$target_type \ + ${ali_rspecifier:+--ali-rspecifier="$ali_rspecifier" --silence-phones=$silence_phones_str} \ + --floor=$floor --ceiling=$ceiling --binary-targets --target-dim=$feat_dim \ + scp:$data/split$nj/JOB/feats.scp \ + ark:- 
\|$apply_exp_opts \ + copy-feats --compress=$compress ark:- \ + ark,scp:$targets_dir/${data_id}.JOB.ark,$targets_dir/${data_id}.JOB.scp || exit 1 + fi +fi + +for n in `seq $nj`; do + cat $targets_dir/${data_id}.$n.scp +done > $data/`basename $targets_dir`.scp diff --git a/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh b/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh new file mode 100755 index 00000000000..c1006d09678 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh @@ -0,0 +1,130 @@ +#! /bin/bash + +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e -o pipefail -u +. path.sh + +cmd=run.pl +stage=-10 + +segmentation_config=conf/segmentation.conf +nj=18 + +frame_shift=0.01 +weight_threshold=0.5 +ali_suffix=_acwt0.1 + +frame_subsampling_factor=1 + +phone2sad_map= + +. utils/parse_options.sh + +if [ $# -ne 5 ] && [ $# -ne 4 ]; then + echo "This script converts an alignment directory containing per-frame SAD " + echo "labels or per-frame speech probabilities into kaldi-style " + echo "segmented data directory. " + echo "This script first converts the per-frame labels or weights into " + echo "segmentation and then calls " + echo "steps/segmentation/internal/post_process_sad_to_segments.sh, " + echo "which does the actual post-processing step." + echo "Usage: $0 ( |) " + echo " e.g.: $0 data/dev_aspire_whole exp/vad_dev_aspire data/dev_aspire_seg" + exit 1 +fi + +data_dir=$1 +vad_dir= + +if [ $# -eq 5 ]; then + lang=$2 + vad_dir=$3 + shift; shift; shift +else + weights_scp=$2 + shift; shift +fi + +dir=$1 +segmented_data_dir=$2 + +utils/data/get_reco2utt.sh $data_dir + +mkdir -p $dir + +if [ ! 
-z "$vad_dir" ]; then + nj=`cat $vad_dir/num_jobs` || exit 1 + + utils/split_data.sh $data_dir $nj + + for n in `seq $nj`; do + cat $data_dir/split$nj/$n/segments | awk '{print $1" "$2}' | \ + utils/utt2spk_to_spk2utt.pl > $data_dir/split$nj/$n/reco2utt + done + + if [ -z "$phone2sad_map" ]; then + phone2sad_map=$dir/phone2sad_map + + { + cat $lang/phones/silence.int | awk '{print $1" 0"}'; + cat $lang/phones/nonsilence.int | awk '{print $1" 1"}'; + } | sort -k1,1 -n > $dir/phone2sad_map + fi + + frame_shift_subsampled=`perl -e "print ($frame_subsampling_factor * $frame_shift)"` + + if [ $stage -le 0 ]; then + # Convert the original SAD into segmentation + $cmd JOB=1:$nj $dir/log/segmentation.JOB.log \ + segmentation-init-from-ali \ + "ark:gunzip -c $vad_dir/ali${ali_suffix}.JOB.gz |" ark:- \| \ + segmentation-combine-segments ark:- \ + "ark:segmentation-init-from-segments --shift-to-zero=false --frame-shift=$frame_shift_subsampled $data_dir/split$nj/JOB/segments ark:- |" \ + "ark,t:$data_dir/split$nj/JOB/reco2utt" ark:- \| \ + segmentation-copy --label-map=$phone2sad_map \ + --frame-subsampling-factor=$frame_subsampling_factor ark:- \ + "ark:| gzip -c > $dir/orig_segmentation.JOB.gz" + fi +else + utils/split_data.sh $data_dir $nj + + for n in `seq $nj`; do + utils/data/get_reco2utt.sh $data_dir/split$nj/$n + utils/filter_scp.pl $data_dir/split$nj/$n/reco2utt $weights_scp > \ + $dir/weights.$n.scp + done + + $cmd JOB=1:$nj $dir/log/weights_to_segments.JOB.log \ + copy-vector scp:$dir/weights.JOB.scp ark,t:- \| \ + awk -v t=$weight_threshold '{printf $1; for (i=3; i < NF; i++) { if ($i >= t) printf (" 1"); else printf (" 0"); }; print "";}' \| \ + segmentation-init-from-ali \ + ark,t:- ark:- \| segmentation-combine-segments ark:- \ + "ark:segmentation-init-from-segments --shift-to-zero=false --frame-shift=$frame_shift_subsampled $data_dir/split$nj/JOB/segments ark:- |" \ + "ark,t:$data_dir/split$nj/JOB/reco2utt" ark:- \| \ + segmentation-copy 
--frame-subsampling-factor=$frame_subsampling_factor \ + ark:- "ark:| gzip -c > $dir/orig_segmentation.JOB.gz" +fi + +echo $nj > $dir/num_jobs + +if [ $stage -le 1 ]; then + rm -r $segmented_data_dir || true + utils/data/convert_data_dir_to_whole.sh $data_dir $segmented_data_dir || exit 1 + rm $segmented_data_dir/text || true +fi + +steps/segmentation/internal/post_process_segments.sh \ + --stage $stage --cmd "$cmd" \ + --config $segmentation_config --frame-shift $frame_shift \ + $data_dir $dir $segmented_data_dir + +utils/utt2spk_to_spk2utt.pl $segmented_data_dir/utt2spk > $segmented_data_dir/spk2utt || exit 1 +utils/fix_data_dir.sh $segmented_data_dir + +if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ]; then + echo "$0: Segmentation failed to generate segments or utt2spk!" + exit 1 +fi + diff --git a/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh b/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh new file mode 100755 index 00000000000..d5ad48a492f --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh @@ -0,0 +1,92 @@ +#! /bin/bash + +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e -o pipefail -u +. path.sh + +cmd=run.pl +stage=-10 + +segmentation_config=conf/segmentation.conf +nj=18 + +frame_subsampling_factor=1 +frame_shift=0.01 +frame_overlap=0.015 + +. 
utils/parse_options.sh + +if [ $# -ne 5 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/dev_aspire_whole exp/vad_dev_aspire data/dev_aspire_seg" + exit 1 +fi + +data_dir=$1 +phone2sad_map=$2 +vad_dir=$3 +dir=$4 +segmented_data_dir=$5 + +mkdir -p $dir + +nj=`cat $vad_dir/num_jobs` || exit 1 + +utils/split_data.sh $data_dir $nj + +if [ $stage -le 0 ]; then + # Convert the original SAD into segmentation + $cmd JOB=1:$nj $dir/log/segmentation.JOB.log \ + segmentation-init-from-ali \ + "ark:gunzip -c $vad_dir/ali.JOB.gz |" ark:- \| \ + segmentation-copy --label-map=$phone2sad_map \ + --frame-subsampling-factor=$frame_subsampling_factor ark:- \ + "ark:| gzip -c > $dir/orig_segmentation.JOB.gz" +fi + +echo $nj > $dir/num_jobs + +# Create a temporary directory into which we can create the new segments +# file. +if [ $stage -le 1 ]; then + rm -r $segmented_data_dir || true + utils/data/convert_data_dir_to_whole.sh $data_dir $segmented_data_dir || exit 1 + rm $segmented_data_dir/text || true +fi + +if [ $stage -le 2 ]; then + # --frame-overlap is set to 0 to not do any additional padding when writing + # segments. This padding will be done later by the option + # --segment-end-padding to utils/data/subsegment_data_dir.sh. 
+ steps/segmentation/internal/post_process_segments.sh \ + --stage $stage --cmd "$cmd" \ + --config $segmentation_config --frame-shift $frame_shift \ + --frame-overlap 0 \ + $data_dir $dir $segmented_data_dir +fi + +mv $segmented_data_dir/segments $segmented_data_dir/sub_segments +utils/data/subsegment_data_dir.sh --segment-end-padding `perl -e "print $frame_overlap"` \ + $data_dir $segmented_data_dir/sub_segments $segmented_data_dir +utils/fix_data_dir.sh $segmented_data_dir + +utils/data/get_reco2num_frames.sh --nj $nj --cmd "$cmd" ${data_dir} +mv $segmented_data_dir/feats.scp $segmented_data_dir/feats.scp.tmp +cat $segmented_data_dir/segments | awk '{print $1" "$2}' | \ + utils/apply_map.pl -f 2 $data_dir/reco2num_frames > \ + $segmented_data_dir/utt2max_frames +cat $segmented_data_dir/feats.scp.tmp | \ + utils/data/fix_subsegmented_feats.pl $segmented_data_dir/utt2max_frames > \ + $segmented_data_dir/feats.scp + +utils/utt2spk_to_spk2utt.pl $segmented_data_dir/utt2spk > \ + $segmented_data_dir/spk2utt || exit 1 +utils/fix_data_dir.sh $segmented_data_dir + +if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ]; then + echo "$0: Segmentation failed to generate segments or utt2spk!" + exit 1 +fi + diff --git a/egs/wsj/s5/steps/segmentation/quantize_vector.pl b/egs/wsj/s5/steps/segmentation/quantize_vector.pl new file mode 100755 index 00000000000..0bccebade4c --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/quantize_vector.pl @@ -0,0 +1,28 @@ +#!/usr/bin/perl + +# This script convert per-frame speech probabilities into +# 0-1 labels. 
+ +@ARGV <= 1 or die "Usage: quantize_vector.pl [threshold]"; + +my $t = 0.5; + +if (scalar @ARGV == 1) { + $t = $ARGV[0]; +} + +while () { + chomp; + my @F = split; + + my $str = "$F[0]"; + for (my $i = 2; $i < $#F; $i++) { + if ($F[$i] >= $t) { + $str = "$str 1"; + } else { + $str = "$str 0"; + } + } + + print ("$str\n"); +} diff --git a/egs/wsj/s5/steps/segmentation/split_data_on_reco.sh b/egs/wsj/s5/steps/segmentation/split_data_on_reco.sh new file mode 100755 index 00000000000..4c167d99a1e --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/split_data_on_reco.sh @@ -0,0 +1,29 @@ +#! /bin/bash + +set -e + +if [ $# -ne 3 ]; then + echo "Usage: split_data_on_reco.sh " + exit 1 +fi + +ref_data=$1 +data=$2 +nj=$3 + +utils/data/get_reco2utt.sh $ref_data +utils/data/get_reco2utt.sh $data + +utils/split_data.sh --per-reco $ref_data $nj + +for n in `seq $nj`; do + srn=$ref_data/split${nj}reco/$n + dsn=$data/split${nj}reco/$n + + mkdir -p $dsn + + utils/data/get_reco2utt.sh $srn + utils/filter_scp.pl $srn/reco2utt $data/reco2utt > $dsn/reco2utt + utils/spk2utt_to_utt2spk.pl $dsn/reco2utt > $dsn/utt2reco + utils/subset_data_dir.sh --utt-list $dsn/utt2reco $data $dsn +done diff --git a/egs/wsj/s5/steps/segmentation/train_simple_hmm.py b/egs/wsj/s5/steps/segmentation/train_simple_hmm.py new file mode 100755 index 00000000000..9f581b0a520 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/train_simple_hmm.py @@ -0,0 +1,194 @@ +#! /usr/bin/env python + +# Copyright 2016 Vimal Manohar +# Apache 2.0. 
+ +import argparse +import logging +import os +import sys + +sys.path.insert(0, 'steps') +import libs.common as common_lib + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ] %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def get_args(): + """Parse command-line arguments""" + + parser = argparse.ArgumentParser( + """Train a simple HMM model starting from HMM topology.""") + + # Alignment options + parser.add_argument("--align.transition-scale", dest='transition_scale', + type=float, default=10.0, + help="""Transition-probability scale [relative to + acoustics]""") + parser.add_argument("--align.self-loop-scale", dest='self_loop_scale', + type=float, default=1.0, + help="""Scale on self-loop versus non-self-loop log + probs [relative to acoustics]""") + parser.add_argument("--align.beam", dest='beam', + type=float, default=6, + help="""Decoding beam used in alignment""") + + # Training options + parser.add_argument("--training.num-iters", dest='num_iters', + type=int, default=30, + help="""Number of iterations of training""") + parser.add_argument("--training.use-soft-counts", dest='use_soft_counts', + type=str, action=common_lib.StrToBoolAction, + choices=["true", "false"], default=False, + help="""Use soft counts (posteriors) instead of + alignments""") + + # General options + parser.add_argument("--scp2ark-cmd", type=str, + default="copy-int-vector scp:- ark:- |", + help="The command used to convert scp from stdin to " + "write archive to stdout") + parser.add_argument("--cmd", dest='command', type=str, + default="run.pl", + help="Command used to run jobs") + parser.add_argument("--stage", type=int, default=-10, + help="""Stage to run training from""") + + parser.add_argument("--data", type=str, required=True, + help="Data directory; 
primarily used for splitting") + + labels_group = parser.add_mutually_exclusive_group(required=True) + labels_group.add_argument("--labels-scp", type=str, + help="Input labels that must be convert to alignment " + "of class-ids using --scp2ark-cmd") + labels_group.add_argument("--labels-rspecifier", type=str, + help="Input labels rspecifier") + + parser.add_argument("--lang", type=str, required=True, + help="The language directory containing the " + "HMM Topology file topo") + parser.add_argument("--loglikes-dir", type=str, required=True, + help="Directory containing the log-likelihoods") + parser.add_argument("--dir", type=str, required=True, + help="Directory where the intermediate and final " + "models will be written") + + args = parser.parse_args() + + if args.use_soft_counts: + raise NotImplementedError("--use-soft-counts not supported yet!") + + return args + + +def check_files(args): + """Check files required for this script""" + + files = ("{lang}/topo {data}/utt2spk " + "{loglikes_dir}/log_likes.1.gz {loglikes_dir}/num_jobs " + "".format(lang=args.lang, data=args.data, + loglikes_dir=args.loglikes_dir).split()) + + if args.labels_scp is not None: + files.append(args.labels_scp) + + for f in files: + if not os.path.exists(f): + logger.error("Could not find file %s", f) + raise RuntimeError + + +def run(args): + """The function that does it all""" + + check_files(args) + + if args.stage <= -2: + logger.info("Initializing simple HMM model") + common_lib.run_kaldi_command( + """{cmd} {dir}/log/init.log simple-hmm-init {lang}/topo """ + """ {dir}/0.mdl""".format(cmd=args.command, dir=args.dir, + lang=args.lang)) + + num_jobs = common_lib.get_number_of_jobs(args.loglikes_dir) + split_data = common_lib.split_data(args.data, num_jobs) + + if args.labels_rspecifier is not None: + labels_rspecifier = args.labels_rspecifier + else: + labels_rspecifier = ("ark:utils/filter_scp.pl {sdata}/JOB/utt2spk " + "{labels_scp} | {scp2ark_cmd}".format( + sdata=split_data, 
labels_scp=args.labels_scp, + scp2ark_cmd=args.scp2ark_cmd)) + + if args.stage <= -1: + logger.info("Compiling training graphs") + common_lib.run_kaldi_command( + """{cmd} JOB=1:{nj} {dir}/log/compile_graphs.JOB.log """ + """ compile-train-simple-hmm-graphs {dir}/0.mdl """ + """ "{labels_rspecifier}" """ + """ "ark:| gzip -c > {dir}/fsts.JOB.gz" """.format( + cmd=args.command, nj=num_jobs, + dir=args.dir, lang=args.lang, + labels_rspecifier=labels_rspecifier)) + + scale_opts = ("--transition-scale={tscale} --self-loop-scale={loop_scale}" + "".format(tscale=args.transition_scale, + loop_scale=args.self_loop_scale)) + + for iter_ in range(0, args.num_iters): + if args.stage > iter_: + continue + + logger.info("Training iteration %d", iter_) + + common_lib.run_kaldi_command( + """{cmd} JOB=1:{nj} {dir}/log/align.{iter}.JOB.log """ + """ simple-hmm-align-compiled {scale_opts} """ + """ --beam={beam} --retry-beam={retry_beam} {dir}/{iter}.mdl """ + """ "ark:gunzip -c {dir}/fsts.JOB.gz |" """ + """ "ark:gunzip -c {loglikes_dir}/log_likes.JOB.gz |" """ + """ ark:- \| """ + """ simple-hmm-acc-stats-ali {dir}/{iter}.mdl ark:- """ + """ {dir}/{iter}.JOB.acc""".format( + cmd=args.command, nj=num_jobs, dir=args.dir, iter=iter_, + scale_opts=scale_opts, beam=args.beam, + retry_beam=args.beam * 4, loglikes_dir=args.loglikes_dir)) + + common_lib.run_kaldi_command( + """{cmd} {dir}/log/update.{iter}.log """ + """ simple-hmm-est {dir}/{iter}.mdl """ + """ "vector-sum {dir}/{iter}.*.acc - |" """ + """ {dir}/{new_iter}.mdl""".format( + cmd=args.command, dir=args.dir, iter=iter_, + new_iter=iter_ + 1)) + + common_lib.run_kaldi_command( + "rm {dir}/{iter}.*.acc".format(dir=args.dir, iter=iter_)) + # end train loop + + common_lib.force_symlink("{0}.mdl".format(args.num_iters), + "{0}/final.mdl".format(args.dir)) + + logger.info("Done training simple HMM in %s/final.mdl", args.dir) + + +def main(): + try: + args = get_args() + run(args) + except Exception: + logger.error("Failed training 
models") + raise + + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/steps/segmentation/vector_get_max.pl b/egs/wsj/s5/steps/segmentation/vector_get_max.pl new file mode 100644 index 00000000000..abb8ea977a2 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/vector_get_max.pl @@ -0,0 +1,26 @@ +#! /usr/bin/perl + +use warnings; +use strict; + +while (<>) { + chomp; + if (m/^\S+\s+\[.+\]\s*$/) { + my @F = split; + my $utt = shift @F; + shift; + + my $max_id = 0; + my $max = $F[0]; + for (my $i = 1; $i < $#F; $i++) { + if ($F[$i] > $max) { + $max_id = $i; + $max = $F[$i]; + } + } + + print "$utt $max_id\n"; + } else { + die "Invalid line $_\n"; + } +} diff --git a/egs/wsj/s5/utils/copy_data_dir.sh b/egs/wsj/s5/utils/copy_data_dir.sh index 008233daf62..222bc708527 100755 --- a/egs/wsj/s5/utils/copy_data_dir.sh +++ b/egs/wsj/s5/utils/copy_data_dir.sh @@ -83,15 +83,16 @@ fi if [ -f $srcdir/segments ]; then utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments cp $srcdir/wav.scp $destdir - if [ -f $srcdir/reco2file_and_channel ]; then - cp $srcdir/reco2file_and_channel $destdir/ - fi else # no segments->wav indexed by utt. if [ -f $srcdir/wav.scp ]; then utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp fi fi +if [ -f $srcdir/reco2file_and_channel ]; then + cp $srcdir/reco2file_and_channel $destdir/ +fi + if [ -f $srcdir/text ]; then utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text fi diff --git a/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh b/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh new file mode 100755 index 00000000000..f55f60c4774 --- /dev/null +++ b/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh @@ -0,0 +1,108 @@ +#! /bin/bash + +# This scripts converts a data directory into a "whole" data directory +# by removing the segments and using the recordings themselves as +# utterances + +set -o pipefail + +. path.sh + +cmd=run.pl +stage=-1 + +. 
parse_options.sh + +if [ $# -ne 2 ]; then + echo "Usage: convert_data_dir_to_whole.sh " + echo " e.g.: convert_data_dir_to_whole.sh data/dev data/dev_whole" + exit 1 +fi + +data=$1 +dir=$2 + +if [ ! -f $data/segments ]; then + # Data directory already does not contain segments. So just copy it. + utils/copy_data_dir.sh $data $dir + exit 0 +fi + +mkdir -p $dir +cp $data/wav.scp $dir +cp $data/reco2file_and_channel $dir +rm -f $dir/{utt2spk,text} || true + +[ -f $data/stm ] && cp $data/stm $dir +[ -f $data/glm ] && cp $data/glm $dir + +text_files= +[ -f $data/text ] && text_files="$data/text $dir/text" + +# Combine utt2spk and text from the segments into utt2spk and text for the whole +# recording. +cat $data/segments | perl -e ' +if (scalar @ARGV == 4) { + ($utt2spk_in, $utt2spk_out, $text_in, $text_out) = @ARGV; +} elsif (scalar @ARGV == 2) { + ($utt2spk_in, $utt2spk_out) = @ARGV; +} else { + die "Unexpected number of arguments"; +} + +if (defined $text_in) { + open(TI, "<$text_in") || die "Error: fail to open $text_in\n"; + open(TO, ">$text_out") || die "Error: fail to open $text_out\n"; +} +open(UI, "<$utt2spk_in") || die "Error: fail to open $utt2spk_in\n"; +open(UO, ">$utt2spk_out") || die "Error: fail to open $utt2spk_out\n"; + +my %file2utt = (); +while () { + chomp; + my @col = split; + @col >= 4 or die "bad line $_\n"; + + if (! defined $file2utt{$col[1]}) { + $file2utt{$col[1]} = []; + } + push @{$file2utt{$col[1]}}, $col[0]; +} + +my %text = (); +my %utt2spk = (); + +while () { + chomp; + my @col = split; + $utt2spk{$col[0]} = $col[1]; +} + +if (defined $text_in) { + while () { + chomp; + my @col = split; + @col >= 1 or die "bad line $_\n"; + + my $utt = shift @col; + $text{$utt} = join(" ", @col); + } +} + +foreach $file (keys %file2utt) { + my @utts = @{$file2utt{$file}}; + #print STDERR $file . " " . join(" ", @utts) . 
"\n"; + print UO "$file $file\n"; + + if (defined $text_in) { + $text_line = ""; + print TO "$file $text_line\n"; + } +} +' $data/utt2spk $dir/utt2spk $text_files + +sort -u $dir/utt2spk > $dir/utt2spk.tmp +mv $dir/utt2spk.tmp $dir/utt2spk +utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt + +utils/fix_data_dir.sh $dir diff --git a/egs/wsj/s5/utils/data/data_lib.py b/egs/wsj/s5/utils/data/data_lib.py new file mode 100644 index 00000000000..5e58fcac3d5 --- /dev/null +++ b/egs/wsj/s5/utils/data/data_lib.py @@ -0,0 +1,57 @@ +import os + +import libs.common as common_lib + +def get_frame_shift(data_dir): + frame_shift = common_lib.run_kaldi_command("utils/data/get_frame_shift.sh {0}".format(data_dir))[0] + return float(frame_shift.strip()) + +def generate_utt2dur(data_dir): + common_lib.run_kaldi_command("utils/data/get_utt2dur.sh {0}".format(data_dir)) + +def get_utt2dur(data_dir): + GenerateUtt2Dur(data_dir) + utt2dur = {} + for line in open('{0}/utt2dur'.format(data_dir), 'r').readlines(): + parts = line.split() + utt2dur[parts[0]] = float(parts[1]) + return utt2dur + +def get_utt2uniq(data_dir): + utt2uniq_file = '{0}/utt2uniq'.format(data_dir) + if not os.path.exists(utt2uniq_file): + return None, None + utt2uniq = {} + uniq2utt = {} + for line in open(utt2uniq_file, 'r').readlines(): + parts = line.split() + utt2uniq[parts[0]] = parts[1] + if uniq2utt.has_key(parts[1]): + uniq2utt[parts[1]].append(parts[0]) + else: + uniq2utt[parts[1]] = [parts[0]] + return utt2uniq, uniq2utt + +def get_num_frames(data_dir, utts = None): + GenerateUtt2Dur(data_dir) + frame_shift = GetFrameShift(data_dir) + total_duration = 0 + utt2dur = GetUtt2Dur(data_dir) + if utts is None: + utts = utt2dur.keys() + for utt in utts: + total_duration = total_duration + utt2dur[utt] + return int(float(total_duration)/frame_shift) + +def create_data_links(file_names): + # if file_names already exist create_data_link.pl returns with code 1 + # so we just delete them before calling 
create_data_link.pl + for file_name in file_names: + TryToDelete(file_name) + common_lib.run_kaldi_command(" utils/create_data_link.pl {0}".format(" ".join(file_names))) + +def try_to_delete(file_name): + try: + os.remove(file_name) + except OSError: + pass diff --git a/egs/wsj/s5/utils/data/downsample_data_dir.sh b/egs/wsj/s5/utils/data/downsample_data_dir.sh new file mode 100755 index 00000000000..022af67d265 --- /dev/null +++ b/egs/wsj/s5/utils/data/downsample_data_dir.sh @@ -0,0 +1,34 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +if [ $# -ne 2 ]; then + echo "Usage: $0 " + exit 1 +fi + +freq=$1 +dir=$2 + +sox=`which sox` || { echo "Could not find sox in PATH"; exit 1; } + +if [ -f $dir/feats.scp ]; then + mkdir -p $dir/.backup + mv $dir/feats.scp $dir/.backup/ + if [ -f $dir/cmvn.scp ]; then + mv $dir/cmvn.scp $dir/.backup/ + fi + echo "$0: feats.scp already exists. Moving it to $dir/.backup" +fi + +mv $dir/wav.scp $dir/wav.scp.tmp +cat $dir/wav.scp.tmp | python -c "import sys +for line in sys.stdin.readlines(): + splits = line.strip().split() + if splits[-1] == '|': + out_line = line.strip() + ' $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |' + else: + out_line = 'cat {0} {1} | $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |'.format(splits[0], ' '.join(splits[1:])) + print (out_line)" > ${dir}/wav.scp +rm $dir/wav.scp.tmp diff --git a/egs/wsj/s5/utils/data/fix_subsegmented_feats.pl b/egs/wsj/s5/utils/data/fix_subsegmented_feats.pl new file mode 100755 index 00000000000..b0cece46ca8 --- /dev/null +++ b/egs/wsj/s5/utils/data/fix_subsegmented_feats.pl @@ -0,0 +1,79 @@ +#!/usr/bin/env perl + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +use warnings; + +# This script modifies the feats ranges and ensures that they don't +# exceed the max number of frames supplied in utt2max_frames. 
+# utt2max_frames can be computed by using +# steps/segmentation/get_reco2num_frames.sh +# cut -d ' ' -f 1,2 /segments | utils/apply_map.pl -f 2 /reco2num_frames > /utt2max_frames + +(scalar @ARGV == 1) or die "Usage: fix_subsegmented_feats.pl "; + +my $utt2max_frames_file = $ARGV[0]; + +open MAX_FRAMES, $utt2max_frames_file or die "fix_subsegmented_feats.pl: Could not open file $utt2max_frames_file"; + +my %utt2max_frames; + +while () { + chomp; + my @F = split; + + (scalar @F == 2) or die "fix_subsegmented_feats.pl: Invalid line $_ in $utt2max_frames_file"; + + $utt2max_frames{$F[0]} = $F[1]; +} + +while () { + my $line = $_; + + if (m/\[([^][]*)\]\[([^][]*)\]\s*$/) { + print ("fix_subsegmented_feats.pl: this script only supports single indices"); + exit(1); + } + + my $before_range = ""; + my $range = ""; + + if (m/^(.*)\[([^][]*)\]\s*$/) { + $before_range = $1; + $range = $2; + } else { + print; + next; + } + + my @F = split(/ /, $before_range); + my $utt = shift @F; + defined $utt2max_frames{$utt} or die "fix_subsegmented_feats.pl: Could not find key $utt in $utt2max_frames_file.\nError with line $line"; + + if ($range !~ m/^(\d*):(\d*)([,]?.*)$/) { + print STDERR "fix_subsegmented_feats.pl: could not make sense of input line $_"; + exit(1); + } + + my $row_start = $1; + my $row_end = $2; + my $col_range = $3; + + if ($row_end >= $utt2max_frames{$utt}) { + print STDERR "Fixed row_end for $utt from $row_end to $utt2max_frames{$utt}-1\n"; + $row_end = $utt2max_frames{$utt} - 1; + } + + if ($row_start ne "") { + $range = "$row_start:$row_end"; + } else { + $range = ""; + } + + if ($col_range ne "") { + $range .= ",$col_range"; + } + + print ("$utt " . join(" ", @F) . "[" . $range . 
"]\n"); +} diff --git a/egs/wsj/s5/utils/data/get_dct_matrix.py b/egs/wsj/s5/utils/data/get_dct_matrix.py new file mode 100755 index 00000000000..88b28b5dd5c --- /dev/null +++ b/egs/wsj/s5/utils/data/get_dct_matrix.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python + +# we're using python 3.x style print but want it to work in python 2.x, +from __future__ import print_function +import os, argparse, sys, math, warnings + +import numpy as np + +def ComputeLifterCoeffs(Q, dim): + coeffs = np.zeros((dim)) + for i in range(0, dim): + coeffs[i] = 1.0 + 0.5 * Q * math.sin(math.pi * i / Q); + + return coeffs + +def ComputeIDctMatrix(K, N, cepstral_lifter=0): + matrix = np.zeros((K, N)) + # normalizer for X_0 + normalizer = math.sqrt(1.0 / N); + for j in range(0, N): + matrix[0, j] = normalizer; + # normalizer for other elements + normalizer = math.sqrt(2.0 / N); + for k in range(1, K): + for n in range(0, N): + matrix[k, n] = normalizer * math.cos(math.pi/N * (n + 0.5) * k); + + if cepstral_lifter != 0: + lifter_coeffs = ComputeLifterCoeffs(cepstral_lifter, K) + for k in range(0, K): + matrix[k, :] = matrix[k, :] / lifter_coeffs[k]; + + return matrix.T + +def ComputeDctMatrix(K, N, cepstral_lifter=0): + matrix = np.zeros((K, N)) + # normalizer for X_0 + normalizer = math.sqrt(1.0 / N); + for j in range(0, N): + matrix[0, j] = normalizer; + # normalizer for other elements + normalizer = math.sqrt(2.0 / N); + for k in range(1, K): + for n in range(0, N): + matrix[k, n] = normalizer * math.cos(math.pi/N * (n + 0.5) * k); + + if cepstral_lifter != 0: + lifter_coeffs = ComputeLifterCoeffs(cepstral_lifter, K) + for k in range(0, K): + matrix[k, :] = matrix[k, :] * lifter_coeffs[k]; + + return matrix + +def GetArgs(): + parser = argparse.ArgumentParser(description="Write DCT/IDCT matrix") + parser.add_argument("--cepstral-lifter", type=float, + help="Here we need the scaling factor on cepstra in the production of MFCC" + "to cancel out the effect of lifter, e.g. 
22.0", default=22.0) + parser.add_argument("--num-ceps", type=int, + default=13, + help="Number of cepstral dimensions") + parser.add_argument("--num-filters", type=int, + default=23, + help="Number of mel filters") + parser.add_argument("--get-idct-matrix", type=str, default="false", + choices=["true","false"], + help="Get IDCT matrix instead of DCT matrix") + parser.add_argument("--add-zero-column", type=str, default="true", + choices=["true","false"], + help="Add a column to convert the matrix from a linear transform to affine transform") + parser.add_argument("out_file", type=str, + help="Output file") + + args = parser.parse_args() + + return args + +def CheckArgs(args): + if args.num_ceps > args.num_filters: + raise Exception("num-ceps must not be larger than num-filters") + + args.out_file_handle = open(args.out_file, 'w') + + return args + +def Main(): + args = GetArgs() + args = CheckArgs(args) + + if args.get_idct_matrix == "false": + matrix = ComputeDctMatrix(args.num_ceps, args.num_filters, + args.cepstral_lifter) + if args.add_zero_column == "true": + matrix = np.append(matrix, np.zeros((args.num_ceps,1)), 1) + else: + matrix = ComputeIDctMatrix(args.num_ceps, args.num_filters, + args.cepstral_lifter) + + if args.add_zero_column == "true": + matrix = np.append(matrix, np.zeros((args.num_filters,1)), 1) + + print('[ ', file=args.out_file_handle) + np.savetxt(args.out_file_handle, matrix, fmt='%.6e') + print(' ]', file=args.out_file_handle) + +if __name__ == "__main__": + Main() + diff --git a/egs/wsj/s5/utils/data/get_frame_shift.sh b/egs/wsj/s5/utils/data/get_frame_shift.sh index d032c9c17fa..f5a3bac9009 100755 --- a/egs/wsj/s5/utils/data/get_frame_shift.sh +++ b/egs/wsj/s5/utils/data/get_frame_shift.sh @@ -38,23 +38,27 @@ if [ ! -s $dir/utt2dur ]; then utils/data/get_utt2dur.sh $dir 1>&2 fi -if [ ! -f $dir/feats.scp ]; then - echo "$0: $dir/feats.scp does not exist" 1>&2 - exit 1 -fi +if [ ! -f $dir/frame_shift ]; then + if [ ! 
-f $dir/feats.scp ]; then + echo "$0: $dir/feats.scp does not exist" 1>&2 + exit 1 + fi -temp=$(mktemp /tmp/tmp.XXXX) + temp=$(mktemp /tmp/tmp.XXXX) -feat-to-len "scp:head -n 10 $dir/feats.scp|" ark,t:- > $temp + feat-to-len "scp:head -n 10 $dir/feats.scp|" ark,t:- > $temp -if [ -z $temp ]; then - echo "$0: error running feat-to-len" 1>&2 - exit 1 -fi + if [ -z $temp ]; then + echo "$0: error running feat-to-len" 1>&2 + exit 1 + fi -head -n 10 $dir/utt2dur | paste - $temp | \ - awk '{ dur += $2; frames += $4; } END { shift = dur / frames; if (shift > 0.01 && shift < 0.0102) shift = 0.01; print shift; }' || exit 1; + frame_shift=$(head -n 10 $dir/utt2dur | paste - $temp | awk '{ dur += $2; frames += $4; } END { shift = dur / frames; if (shift > 0.01 && shift < 0.0102) shift = 0.01; print shift; }') || exit 1; + + echo $frame_shift > $dir/frame_shift + rm $temp +fi -rm $temp +cat $dir/frame_shift exit 0 diff --git a/egs/wsj/s5/utils/data/get_reco2dur.sh b/egs/wsj/s5/utils/data/get_reco2dur.sh new file mode 100755 index 00000000000..5e925fc3e75 --- /dev/null +++ b/egs/wsj/s5/utils/data/get_reco2dur.sh @@ -0,0 +1,95 @@ +#!/bin/bash + +# Copyright 2016 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 + +# This script operates on a data directory, such as in data/train/, and adds the +# reco2dur file if it does not already exist. The file 'reco2dur' maps from +# utterance to the duration of the utterance in seconds. This script works it +# out from the 'segments' file, or, if not present, from the wav.scp file (it +# first tries interrogating the headers, and if this fails, it reads the wave +# files in entirely.) + +frame_shift=0.01 +cmd=run.pl +nj=4 + +. utils/parse_options.sh +. ./path.sh + +if [ $# != 1 ]; then + echo "Usage: $0 [options] " + echo "e.g.:" + echo " $0 data/train" + echo " Options:" + echo " --frame-shift # frame shift in seconds. Only relevant when we are" + echo " # getting duration from feats.scp (default: 0.01). 
" + exit 1 +fi + +export LC_ALL=C + +data=$1 + +if [ -s $data/reco2dur ] && \ + [ $(cat $data/wav.scp | wc -l) -eq $(cat $data/reco2dur | wc -l) ]; then + echo "$0: $data/reco2dur already exists with the expected length. We won't recompute it." + exit 0; +fi + +# if the wav.scp contains only lines of the form +# utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph | +if cat $data/wav.scp | perl -e ' + while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space. + @A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ && + $A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); } + $utt = $A[0]; $sphere_file = $A[4]; + + if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; } + $sample_rate = -1; $sample_count = -1; + for ($n = 0; $n <= 30; $n++) { + $line = ; + if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; } + if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; } + if ($line =~ m/end_head/) { break; } + } + close(F); + if ($sample_rate == -1 || $sample_count == -1) { + die "could not parse sphere header from $sphere_file"; + } + $duration = $sample_count * 1.0 / $sample_rate; + print "$utt $duration\n"; + } ' > $data/reco2dur; then + echo "$0: successfully obtained utterance lengths from sphere-file headers" +else + echo "$0: could not get utterance lengths from sphere-file headers, using wav-to-duration" + if ! command -v wav-to-duration >/dev/null; then + echo "$0: wav-to-duration is not on your path" + exit 1; + fi + + read_entire_file=false + if cat $data/wav.scp | grep -q 'sox.*speed'; then + read_entire_file=true + echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow." + echo "... It is much faster if you call get_reco2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or " + echo "... perturb_data_dir_speed_3way.sh." + fi + + utils/split_data.sh $data $nj + if ! 
$cmd JOB=1:$nj $data/log/get_wav_duration.JOB.log wav-to-duration --read-entire-file=$read_entire_file scp:$data/split$nj/JOB/wav.scp ark,t:$data/split$nj/JOB/reco2dur 2>&1; then + echo "$0: there was a problem getting the durations; moving $data/reco2dur to $data/.backup/" + mkdir -p $data/.backup/ + mv $data/reco2dur $data/.backup/ + exit 1 + fi + + for n in `seq $nj`; do + cat $data/split$nj/$n/reco2dur + done > $data/reco2dur +fi + +echo "$0: computed $data/reco2dur" + +exit 0 + diff --git a/egs/wsj/s5/utils/data/get_reco2num_frames.sh b/egs/wsj/s5/utils/data/get_reco2num_frames.sh new file mode 100755 index 00000000000..edb16609703 --- /dev/null +++ b/egs/wsj/s5/utils/data/get_reco2num_frames.sh @@ -0,0 +1,28 @@ +#! /bin/bash + +cmd=run.pl +nj=4 + +frame_shift=0.01 +frame_overlap=0.015 + +. utils/parse_options.sh + +if [ $# -ne 1 ]; then + echo "Usage: $0 " + exit 1 +fi + +data=$1 + +if [ -s $data/reco2num_frames ]; then + echo "$0: $data/reco2num_frames already present!" + exit 0; +fi + +utils/data/get_reco2dur.sh --cmd "$cmd" --nj $nj $data +awk -v fs=$frame_shift -v fovlp=$frame_overlap \ + '{print $1" "int( ($2 - fovlp) / fs)}' $data/reco2dur > $data/reco2num_frames + +echo "$0: Computed and wrote $data/reco2num_frames" + diff --git a/egs/wsj/s5/utils/data/get_reco2utt.sh b/egs/wsj/s5/utils/data/get_reco2utt.sh new file mode 100755 index 00000000000..6c30f812cfe --- /dev/null +++ b/egs/wsj/s5/utils/data/get_reco2utt.sh @@ -0,0 +1,21 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +if [ $# -ne 1 ]; then + echo "This script creates a reco2utt file in the data directory, " + echo "which is analogous to spk2utt file but with the first column " + echo "as recording instead of speaker." + echo "Usage: get_reco2utt.sh " + echo " e.g.: get_reco2utt.sh data/train" + exit 1 +fi + +data=$1 + +if [ ! 
-s $data/segments ]; then + utils/data/get_segments_for_data.sh $data > $data/segments +fi + +cut -d ' ' -f 1,2 $data/segments | utils/utt2spk_to_spk2utt.pl > $data/reco2utt diff --git a/egs/wsj/s5/utils/data/get_segments_for_data.sh b/egs/wsj/s5/utils/data/get_segments_for_data.sh index 694acc6a256..7adc4c465d3 100755 --- a/egs/wsj/s5/utils/data/get_segments_for_data.sh +++ b/egs/wsj/s5/utils/data/get_segments_for_data.sh @@ -19,7 +19,7 @@ fi data=$1 -if [ ! -f $data/utt2dur ]; then +if [ ! -s $data/utt2dur ]; then utils/data/get_utt2dur.sh $data 1>&2 || exit 1; fi diff --git a/egs/wsj/s5/utils/data/get_subsegmented_feats.sh b/egs/wsj/s5/utils/data/get_subsegmented_feats.sh new file mode 100755 index 00000000000..6baba68eedd --- /dev/null +++ b/egs/wsj/s5/utils/data/get_subsegmented_feats.sh @@ -0,0 +1,46 @@ +#! /bin/bash + +# Copyright 2016 Johns Hopkins University (Author: Dan Povey) +# 2016 Vimal Manohar +# Apache 2.0. + +if [ $# -ne 4 ]; then + echo "This scripts gets subsegmented_feats (by adding ranges to data/feats.scp) " + echo "for the subsegments file. This is does one part of the " + echo "functionality in subsegment_data_dir.sh, which additionally " + echo "creates a new subsegmented data directory." + echo "Usage: $0 " + echo " e.g.: $0 data/train/feats.scp 0.01 0.015 subsegments" + exit 1 +fi + +feats=$1 +frame_shift=$2 +frame_overlap=$3 +subsegments=$4 + +# The subsegments format is . +# e.g. 'utt_foo-1 utt_foo 7.21 8.93' +# The first awk command replaces this with the format: +# +# e.g. 'utt_foo-1 utt_foo 721 893' +# and the apply_map.pl command replaces 'utt_foo' (the 2nd field) with its corresponding entry +# from the original wav.scp, so we get a line like: +# e.g. 'utt_foo-1 foo-bar.ark:514231 721 892' +# Note: the reason we subtract one from the last time is that it's going to +# represent the 'last' frame, not the 'end' frame [i.e. not one past the last], +# in the matlab-like, but zero-indexed [first:last] notion. 
For instance, a segment with 1 frame +# would have start-time 0.00 and end-time 0.01, which would become the frame range +# [0:0] +# The second awk command turns this into something like +# utt_foo-1 foo-bar.ark:514231[721:892] +# It has to be a bit careful because the format actually allows for more general things +# like pipes that might contain spaces, so it has to be able to produce output like the +# following: +# utt_foo-1 some command|[721:892] +# Lastly, utils/data/normalize_data_range.pl will only do something nontrivial if +# the original data-dir already had data-ranges in square brackets. +awk -v s=$frame_shift -v fovlp=$frame_overlap '{print $1, $2, int(($3/s)+0.5), int(($4-fovlp)/s+0.5);}' <$subsegments| \ + utils/apply_map.pl -f 2 $feats | \ + awk '{p=NF-1; for (n=1;n $data/utt2dur elif [ -f $data/wav.scp ]; then diff --git a/egs/wsj/s5/utils/data/get_utt2num_frames.sh b/egs/wsj/s5/utils/data/get_utt2num_frames.sh new file mode 100755 index 00000000000..ec80e771c83 --- /dev/null +++ b/egs/wsj/s5/utils/data/get_utt2num_frames.sh @@ -0,0 +1,42 @@ +#! /bin/bash + +cmd=run.pl +nj=4 + +frame_shift=0.01 +frame_overlap=0.015 + +. utils/parse_options.sh + +if [ $# -ne 1 ]; then + echo "This script writes a file utt2num_frames with the " + echo "number of frames in each utterance as measured based on the " + echo "duration of the utterances (in utt2dur) and the specified " + echo "frame_shift and frame_overlap." + echo "Usage: $0 " + exit 1 +fi + +data=$1 + +if [ -s $data/utt2num_frames ]; then + echo "$0: $data/utt2num_frames already present!" + exit 0; +fi + +if [ ! 
-f $data/feats.scp ]; then + utils/data/get_utt2dur.sh $data + awk -v fs=$frame_shift -v fovlp=$frame_overlap \ + '{print $1" "int( ($2 - fovlp) / fs)}' $data/utt2dur > $data/utt2num_frames + exit 0 +fi + +utils/split_data.sh --per-utt $data $nj || exit 1 +$cmd JOB=1:$nj $data/log/get_utt2num_frames.JOB.log \ + feat-to-len scp:$data/split${nj}utt/JOB/feats.scp ark,t:$data/split${nj}utt/JOB/utt2num_frames || exit 1 + +for n in `seq $nj`; do + cat $data/split${nj}utt/$n/utt2num_frames +done > $data/utt2num_frames + +echo "$0: Computed and wrote $data/utt2num_frames" diff --git a/egs/wsj/s5/utils/data/modify_speaker_info.sh b/egs/wsj/s5/utils/data/modify_speaker_info.sh index f75e9be5f67..e42f0df551d 100755 --- a/egs/wsj/s5/utils/data/modify_speaker_info.sh +++ b/egs/wsj/s5/utils/data/modify_speaker_info.sh @@ -37,6 +37,7 @@ utts_per_spk_max=-1 seconds_per_spk_max=-1 respect_speaker_info=true +respect_recording_info=true # end configuration section . utils/parse_options.sh @@ -93,10 +94,26 @@ else utt2dur_opt= fi -utils/data/internal/modify_speaker_info.py \ - $utt2dur_opt --respect-speaker-info=$respect_speaker_info \ - --utts-per-spk-max=$utts_per_spk_max --seconds-per-spk-max=$seconds_per_spk_max \ - <$srcdir/utt2spk >$destdir/utt2spk +if ! 
$respect_speaker_info && $respect_recording_info; then + if [ -f $srcdir/segments ]; then + cat $srcdir/segments | awk '{print $1" "$2}' | \ + utils/data/internal/modify_speaker_info.py \ + $utt2dur_opt --respect-speaker-info=true \ + --utts-per-spk-max=$utts_per_spk_max --seconds-per-spk-max=$seconds_per_spk_max \ + >$destdir/utt2spk + else + cat $srcdir/wav.scp | awk '{print $1" "$2}' | \ + utils/data/internal/modify_speaker_info.py \ + $utt2dur_opt --respect-speaker-info=true \ + --utts-per-spk-max=$utts_per_spk_max --seconds-per-spk-max=$seconds_per_spk_max \ + >$destdir/utt2spk + fi +else + utils/data/internal/modify_speaker_info.py \ + $utt2dur_opt --respect-speaker-info=$respect_speaker_info \ + --utts-per-spk-max=$utts_per_spk_max --seconds-per-spk-max=$seconds_per_spk_max \ + <$srcdir/utt2spk >$destdir/utt2spk +fi utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt diff --git a/egs/wsj/s5/utils/data/normalize_data_range.pl b/egs/wsj/s5/utils/data/normalize_data_range.pl index f7936d98a31..61ccfd593f7 100755 --- a/egs/wsj/s5/utils/data/normalize_data_range.pl +++ b/egs/wsj/s5/utils/data/normalize_data_range.pl @@ -45,14 +45,13 @@ sub combine_ranges { # though they are supported at the C++ level. if ($start1 eq "" || $start2 eq "" || $end1 eq "" || $end2 == "") { chop $line; - print("normalize_data_range.pl: could not make sense of line $line\n"); + print STDERR ("normalize_data_range.pl: could not make sense of line $line\n"); exit(1) } if ($start1 + $end2 > $end1) { chop $line; - print("normalize_data_range.pl: could not make sense of line $line " . - "[second $row_or_column range too large vs first range, $start1 + $end2 > $end1]\n"); - exit(1); + print STDERR ("normalize_data_range.pl: could not make sense of line $line " . 
+ "[second $row_or_column range too large vs first range, $start1 + $end2 > $end1]; adjusting end.\n"); } return ($start2+$start1, $end2+$start1); } @@ -72,11 +71,11 @@ sub combine_ranges { # sometimes in scp files, we use the command concat-feats to splice together # two feature matrices. Handling this correctly is complicated and we don't # anticipate needing it, so we just refuse to process this type of data. - print "normalize_data_range.pl: this script cannot [yet] normalize the data ranges " . + print STDERR "normalize_data_range.pl: this script cannot [yet] normalize the data ranges " . "if concat-feats was in the input data\n"; exit(1); } - print STDERR "matched: $before_range $first_range $second_range\n"; + # print STDERR "matched: $before_range $first_range $second_range\n"; if ($first_range !~ m/^((\d*):(\d*)|)(,(\d*):(\d*)|)$/) { print STDERR "normalize_data_range.pl: could not make sense of input line $_"; exit(1); diff --git a/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh b/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh index c575166534e..4b12a94eee9 100755 --- a/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh +++ b/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh @@ -43,5 +43,9 @@ utils/data/combine_data.sh $destdir ${srcdir} ${destdir}_speed0.9 ${destdir}_spe rm -r ${destdir}_speed0.9 ${destdir}_speed1.1 echo "$0: generated 3-way speed-perturbed version of data in $srcdir, in $destdir" -utils/validate_data_dir.sh --no-feats $destdir +if [ -f $srcdir/text ]; then + utils/validate_data_dir.sh --no-feats $destdir +else + utils/validate_data_dir.sh --no-feats --no-text $destdir +fi diff --git a/egs/wsj/s5/utils/data/perturb_data_dir_speed_random.sh b/egs/wsj/s5/utils/data/perturb_data_dir_speed_random.sh new file mode 100755 index 00000000000..1eb7ebb874c --- /dev/null +++ b/egs/wsj/s5/utils/data/perturb_data_dir_speed_random.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +# Copyright 2017 Vimal Manohar + +# Apache 2.0 + +speeds="0.9 1.0 1.1" 
+ +. utils/parse_options.sh + +if [ $# != 2 ]; then + echo "Usage: perturb_data_dir_speed_random.sh " + echo "Applies 3-way speed perturbation using factors of 0.9, 1.0 and 1.1 on random subsets." + echo "e.g.:" + echo " $0 data/train data/train_spr" + echo "Note: if /feats.scp already exists, this will refuse to run." + exit 1 +fi + +srcdir=$1 +destdir=$2 + +if [ ! -f $srcdir/wav.scp ]; then + echo "$0: expected $srcdir/wav.scp to exist" + exit 1 +fi + +if [ -f $destdir/feats.scp ]; then + echo "$0: $destdir/feats.scp already exists: refusing to run this (please delete $destdir/feats.scp if you want this to run)" + exit 1 +fi + +echo "$0: making sure the utt2dur file is present in ${srcdir}, because " +echo "... obtaining it after speed-perturbing would be very slow, and" +echo "... you might need it." +utils/data/get_utt2dur.sh ${srcdir} + +num_speeds=`echo $speeds | awk '{print NF}'` +utils/split_data.sh --per-reco $srcdir $num_speeds + +speed_dirs= +i=1 +for speed in $speeds; do + if [ $speed != 1.0 ]; then + utils/data/perturb_data_dir_speed.sh $speed ${srcdir}/split${num_speeds}reco/$i ${destdir}_speed$speed || exit 1 + speed_dirs="${speed_dirs} ${destdir}_speed$speed" + else + speed_dirs="$speed_dirs ${srcdir}/split${num_speeds}reco/$i" + fi + # BUGFIX: advance to the next per-reco split; without this every speed factor + # was applied to split 1 and the remaining subsets were silently dropped. + i=$((i+1)) +done + +utils/data/combine_data.sh $destdir ${speed_dirs} || exit 1 + +rm -r $speed_dirs ${srcdir}/split${num_speeds}reco + +echo "$0: generated $num_speeds-way speed-perturbed version of random subsets of data in $srcdir, in $destdir" +if [ -f $srcdir/text ]; then + utils/validate_data_dir.sh --no-feats $destdir +else + utils/validate_data_dir.sh --no-feats --no-text $destdir +fi + + diff --git a/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh b/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh index bc76939643c..ee3c281bdbb 100755 --- a/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh +++ b/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh @@ -7,6 +7,11 @@ # the wav.scp to perturb the volume (typically useful for
training data when # using systems that don't have cepstral mean normalization). +reco2vol= +force=false +scale_low=0.125 +scale_high=2 + . utils/parse_options.sh if [ $# != 1 ]; then @@ -25,29 +30,86 @@ if [ ! -f $data/wav.scp ]; then exit 1 fi -if grep -q "sox --vol" $data/wav.scp; then +volume_perturb_done=`head -n100 $data/wav.scp | python -c " +import sys, re +for line in sys.stdin.readlines(): + if len(line.strip()) == 0: + continue + # Handle three cases of rxfilenames appropriately; 'input piped command', 'file offset' and 'filename' + parts = line.strip().split() + if line.strip()[-1] == '|': + if re.search('sox --vol', ' '.join(parts[-11:])): + print 'true' + sys.exit(0) + elif re.search(':[0-9]+$', line.strip()) is not None: + continue + else: + if ' '.join(parts[1:3]) == 'sox --vol': + print 'true' + sys.exit(0) +print 'false' +"` || exit 1 + +if $volume_perturb_done; then echo "$0: It looks like the data was already volume perturbed. Not doing anything." exit 0 fi -cat $data/wav.scp | python -c " +if [ -z "$reco2vol" ]; then + cat $data/wav.scp | python -c " import sys, os, subprocess, re, random random.seed(0) -scale_low = 1.0/8 -scale_high = 2.0 +scale_low = $scale_low +scale_high = $scale_high +volume_writer = open('$data/reco2vol', 'w') for line in sys.stdin.readlines(): if len(line.strip()) == 0: continue # Handle three cases of rxfilenames appropriately; 'input piped command', 'file offset' and 'filename' + vol = random.uniform(scale_low, scale_high) + + parts = line.strip().split() if line.strip()[-1] == '|': - print '{0} sox --vol {1} -t wav - -t wav - |'.format(line.strip(), random.uniform(scale_low, scale_high)) + print '{0} sox --vol {1} -t wav - -t wav - |'.format(line.strip(), vol) elif re.search(':[0-9]+$', line.strip()) is not None: - parts = line.split() - print '{id} wav-copy {wav} - | sox --vol {vol} -t wav - -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = random.uniform(scale_low, scale_high)) + print '{id} wav-copy 
{wav} - | sox --vol {vol} -t wav - -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = vol) else: - parts = line.split() - print '{id} sox --vol {vol} -t wav {wav} -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = random.uniform(scale_low, scale_high)) + print '{id} sox --vol {vol} -t wav {wav} -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = vol) + volume_writer.write('{id} {vol}\n'.format(id = parts[0], vol = vol)) " > $data/wav.scp_scaled || exit 1; +else + cat $data/wav.scp | python -c " +import sys, os, subprocess, re +volumes = {} +for line in open('$reco2vol'): + if len(line.strip()) == 0: + continue + parts = line.strip().split() + volumes[parts[0]] = float(parts[1]) + +for line in sys.stdin.readlines(): + if len(line.strip()) == 0: + continue + # Handle three cases of rxfilenames appropriately; 'input piped command', 'file offset' and 'filename' + + parts = line.strip().split() + id = parts[0] + + if id not in volumes: + raise Exception('Could not find volume for id {id}'.format(id = id)) + + vol = volumes[id] + + if line.strip()[-1] == '|': + print '{0} sox --vol {1} -t wav - -t wav - |'.format(line.strip(), vol) + elif re.search(':[0-9]+$', line.strip()) is not None: + print '{id} wav-copy {wav} - | sox --vol {vol} -t wav - -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = vol) + else: + print '{id} sox --vol {vol} -t wav {wav} -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = vol) +" > $data/wav.scp_scaled || exit 1; + + cp $reco2vol $data/reco2vol +fi len1=$(cat $data/wav.scp | wc -l) len2=$(cat $data/wav.scp_scaled | wc -l) diff --git a/egs/wsj/s5/utils/data/resample_data_dir.sh b/egs/wsj/s5/utils/data/resample_data_dir.sh new file mode 100755 index 00000000000..8781ee4c503 --- /dev/null +++ b/egs/wsj/s5/utils/data/resample_data_dir.sh @@ -0,0 +1,35 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0. 
+ +if [ $# -ne 2 ]; then + echo "Usage: $0 " + exit 1 +fi + +freq=$1 +dir=$2 + +sox=`which sox` || { echo "Could not find sox in PATH"; exit 1; } + +if [ -f $dir/feats.scp ]; then + mkdir -p $dir/.backup + mv $dir/feats.scp $dir/.backup/ + if [ -f $dir/cmvn.scp ]; then + mv $dir/cmvn.scp $dir/.backup/ + fi + echo "$0: feats.scp already exists. Moving it to $dir/.backup" +fi + +mv $dir/wav.scp $dir/wav.scp.tmp +cat $dir/wav.scp.tmp | python -c "import sys +for line in sys.stdin.readlines(): + splits = line.strip().split() + if splits[-1] == '|': + out_line = line.strip() + ' $sox -t wav - -c 1 -b 16 -t wav - rate $freq |' + else: + # BUGFIX: keep the recording-id as the scp key; previously the id was passed + # to 'cat' as if it were a file and the output line had no key at all. + out_line = '{0} cat {1} | $sox -t wav - -c 1 -b 16 -t wav - rate $freq |'.format(splits[0], ' '.join(splits[1:])) + print (out_line)" > ${dir}/wav.scp +rm $dir/wav.scp.tmp + diff --git a/egs/wsj/s5/utils/data/subsegment_data_dir.sh b/egs/wsj/s5/utils/data/subsegment_data_dir.sh index 18a00c3df7d..4c664f16441 100755 --- a/egs/wsj/s5/utils/data/subsegment_data_dir.sh +++ b/egs/wsj/s5/utils/data/subsegment_data_dir.sh @@ -24,14 +24,15 @@ segment_end_padding=0.0 . utils/parse_options.sh -if [ $# != 4 ]; then +if [ $# != 4 ] && [ $# != 3 ]; then echo "Usage: " - echo " $0 [options] " + echo " $0 [options] [] " echo "This script sub-segments a data directory. is to" echo "have lines of the form " echo "and is of the form ... ." echo "This script appropriately combines the with the original" echo "segments file, if necessary, and if not, creates a segments file." + echo " is an optional argument." echo "e.g.:" echo " $0 data/train [options] exp/tri3b_resegment/segments exp/tri3b_resegment/text data/train_resegmented" echo " Options:" @@ -50,11 +51,23 @@ export LC_ALL=C srcdir=$1 subsegments=$2 -new_text=$3 -dir=$4 +add_subsegment_text=false +if [ $# -eq 4 ]; then + new_text=$3 + dir=$4 + add_subsegment_text=true -for f in "$subsegments" "$new_text" "$srcdir/utt2spk"; do + if [ !
-f "$new_text" ]; then + echo "$0: no such file $new_text" + exit 1 + fi + +else + dir=$3 +fi + +for f in "$subsegments" "$srcdir/utt2spk"; do if [ ! -f "$f" ]; then echo "$0: no such file $f" exit 1; @@ -65,9 +78,11 @@ if ! mkdir -p $dir; then echo "$0: failed to create directory $dir" fi -if ! cmp <(awk '{print $1}' <$subsegments) <(awk '{print $1}' <$new_text); then - echo "$0: expected the first fields of the files $subsegments and $new_text to be identical" - exit 1 +if $add_subsegment_text; then + if ! cmp <(awk '{print $1}' <$subsegments) <(awk '{print $1}' <$new_text); then + echo "$0: expected the first fields of the files $subsegments and $new_text to be identical" + exit 1 + fi fi # create the utt2spk in $dir @@ -86,8 +101,11 @@ awk '{print $1, $2}' < $subsegments > $dir/new2old_utt utils/apply_map.pl -f 2 $srcdir/utt2spk < $dir/new2old_utt >$dir/utt2spk # .. and the new spk2utt file. utils/utt2spk_to_spk2utt.pl <$dir/utt2spk >$dir/spk2utt -# the new text file is just what the user provides. -cp $new_text $dir/text + +if $add_subsegment_text; then + # the new text file is just what the user provides. + cp $new_text $dir/text +fi # copy the source wav.scp cp $srcdir/wav.scp $dir @@ -125,6 +143,10 @@ if [ -f $srcdir/feats.scp ]; then frame_shift=$(utils/data/get_frame_shift.sh $srcdir) echo "$0: note: frame shift is $frame_shift [affects feats.scp]" + utils/data/get_utt2num_frames.sh --cmd "run.pl" --nj 1 $srcdir + awk '{print $1" "$2}' $subsegments | \ + utils/apply_map.pl -f 2 $srcdir/utt2num_frames > \ + $dir/utt2max_frames # The subsegments format is . # e.g. 'utt_foo-1 utt_foo 7.21 8.93' @@ -147,10 +169,22 @@ if [ -f $srcdir/feats.scp ]; then # utt_foo-1 some command|[721:892] # Lastly, utils/data/normalize_data_range.pl will only do something nontrivial if # the original data-dir already had data-ranges in square brackets. 
- awk -v s=$frame_shift '{print $1, $2, int(($3/s)+0.5), int(($4/s)-0.5);}' <$subsegments| \ + cat $subsegments | awk -v s=$frame_shift '{print $1, $2, int(($3/s)+0.5), int(($4/s)-0.5);}' | \ utils/apply_map.pl -f 2 $srcdir/feats.scp | \ awk '{p=NF-1; for (n=1;n$dir/feats.scp + utils/data/normalize_data_range.pl | \ + utils/data/fix_subsegmented_feats.pl $dir/utt2max_frames >$dir/feats.scp + + cat $dir/feats.scp | perl -ne 'm/^(\S+) .+\[(\d+):(\d+)\]$/; print "$1 " . ($3-$2+1) . "\n"' > \ + $dir/utt2num_frames + + if [ -f $srcdir/vad.scp ]; then + cat $subsegments | awk -v s=$frame_shift '{print $1, $2, int(($3/s)+0.5), int(($4/s)-0.5);}' | \ + utils/apply_map.pl -f 2 $srcdir/vad.scp | \ + awk '{p=NF-1; for (n=1;n$dir/vad.scp + fi fi @@ -184,6 +218,7 @@ utils/data/fix_data_dir.sh $dir validate_opts= [ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats" [ ! -f $srcdir/wav.scp ] && validate_opts="$validate_opts --no-wav" +! $add_subsegment_text && validate_opts="$validate_opts --no-text" utils/data/validate_data_dir.sh $validate_opts $dir diff --git a/egs/wsj/s5/utils/fix_data_dir.sh b/egs/wsj/s5/utils/fix_data_dir.sh index cbbcbe8f8c4..8ebfc8d49fe 100755 --- a/egs/wsj/s5/utils/fix_data_dir.sh +++ b/egs/wsj/s5/utils/fix_data_dir.sh @@ -6,6 +6,11 @@ # It puts the original contents of data-dir into # data-dir/.backup +utt_extra_files= +spk_extra_files= + +. 
utils/parse_options.sh + if [ $# != 1 ]; then echo "Usage: utils/data/fix_data_dir.sh " echo "e.g.: utils/data/fix_data_dir.sh data/train" @@ -111,7 +116,7 @@ function filter_speakers { filter_file $tmpdir/speakers $data/spk2utt utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk - for s in cmvn.scp spk2gender; do + for s in cmvn.scp spk2gender $spk_extra_files; do f=$data/$s if [ -f $f ]; then filter_file $tmpdir/speakers $f @@ -159,7 +164,7 @@ function filter_utts { fi fi - for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav; do + for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $utt_extra_files; do if [ -f $data/$x ]; then cp $data/$x $data/.backup/$x if ! cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then diff --git a/egs/wsj/s5/utils/perturb_data_dir_speed.sh b/egs/wsj/s5/utils/perturb_data_dir_speed.sh index 20ff86755eb..e3d56d58b9c 100755 --- a/egs/wsj/s5/utils/perturb_data_dir_speed.sh +++ b/egs/wsj/s5/utils/perturb_data_dir_speed.sh @@ -112,4 +112,9 @@ cat $srcdir/utt2dur | utils/apply_map.pl -f 1 $destdir/utt_map | \ rm $destdir/spk_map $destdir/utt_map 2>/dev/null echo "$0: generated speed-perturbed version of data in $srcdir, in $destdir" -utils/validate_data_dir.sh --no-feats $destdir + +if [ -f $srcdir/text ]; then + utils/validate_data_dir.sh --no-feats $destdir +else + utils/validate_data_dir.sh --no-feats --no-text $destdir +fi diff --git a/egs/wsj/s5/utils/split_data.sh b/egs/wsj/s5/utils/split_data.sh index ab0dbbf35c7..94ba4f555ce 100755 --- a/egs/wsj/s5/utils/split_data.sh +++ b/egs/wsj/s5/utils/split_data.sh @@ -16,20 +16,28 @@ # limitations under the License. 
split_per_spk=true +split_per_reco=false if [ "$1" == "--per-utt" ]; then split_per_spk=false shift +elif [ "$1" == "--per-reco" ]; then + split_per_spk=false + split_per_reco=true + shift fi if [ $# != 2 ]; then - echo "Usage: $0 [--per-utt] " + echo "Usage: $0 [--per-utt|--per-reco] " echo "E.g.: $0 data/train 50" echo "It creates its output in e.g. data/train/split50/{1,2,3,...50}, or if the " echo "--per-utt option was given, in e.g. data/train/split50utt/{1,2,3,...50}." + echo "If the --per-reco option was given, in e.g. data/train/split50reco/{1,2,3,...50}." echo "" echo "This script will not split the data-dir if it detects that the output is newer than the input." echo "By default it splits per speaker (so each speaker is in only one split dir)," echo "but with the --per-utt option it will ignore the speaker information while splitting." + echo "But if --per-reco option is given, it splits per recording " + echo "(so each recording is in only one split dir)" exit 1 fi @@ -67,10 +75,14 @@ if [ -f $data/text ] && [ $nu -ne $nt ]; then echo "** use utils/fix_data_dir.sh to fix this." fi - if $split_per_spk; then utt2spk_opt="--utt2spk=$data/utt2spk" utt="" +elif $split_per_reco; then + utils/data/get_reco2utt.sh $data + utils/spk2utt_to_utt2spk.pl $data/reco2utt > $data/utt2reco + utt2spk_opt="--utt2spk=$data/utt2reco" + utt="reco" else utt2spk_opt= utt="utt" @@ -94,6 +106,7 @@ if ! 
$need_to_split; then fi utt2spks=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n/utt2spk; done) +utt2recos=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n/utt2reco; done) directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n; done) @@ -108,11 +121,20 @@ fi which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock trap 'rm -f $data/.split_lock' EXIT HUP INT PIPE TERM -utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1 +if $split_per_reco; then + utils/split_scp.pl $utt2spk_opt $data/utt2reco $utt2recos || exit 1 +else + utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1 +fi for n in `seq $numsplit`; do dsn=$data/split${numsplit}${utt}/$n - utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1; + + if $split_per_reco; then + utils/filter_scp.pl $dsn/utt2reco $data/utt2spk > $dsn/utt2spk + fi + + utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1 done maybe_wav_scp= @@ -122,7 +144,7 @@ if [ ! -f $data/segments ]; then fi # split some things that are indexed by utterance. 
-for f in feats.scp text vad.scp utt2lang $maybe_wav_scp; do +for f in feats.scp text vad.scp utt2lang $maybe_wav_scp utt2dur utt2num_frames; do if [ -f $data/$f ]; then utils/filter_scps.pl JOB=1:$numsplit \ $data/split${numsplit}${utt}/JOB/utt2spk $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1; @@ -154,6 +176,12 @@ if [ -f $data/segments ]; then $data/split${numsplit}${utt}/JOB/tmp.reco $data/wav.scp \ $data/split${numsplit}${utt}/JOB/wav.scp || exit 1 fi + if [ -f $data/reco2utt ]; then + utils/filter_scps.pl JOB=1:$numsplit \ + $data/split${numsplit}${utt}/JOB/tmp.reco $data/reco2utt \ + $data/split${numsplit}${utt}/JOB/reco2utt || exit 1 + fi + for f in $data/split${numsplit}${utt}/*/tmp.reco; do rm $f; done fi diff --git a/egs/wsj/s5/utils/subset_data_dir.sh b/egs/wsj/s5/utils/subset_data_dir.sh index 5fe3217ddad..9533d0216c9 100755 --- a/egs/wsj/s5/utils/subset_data_dir.sh +++ b/egs/wsj/s5/utils/subset_data_dir.sh @@ -108,6 +108,7 @@ function do_filtering { [ -f $srcdir/vad.scp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/vad.scp >$destdir/vad.scp [ -f $srcdir/utt2lang ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/utt2lang >$destdir/utt2lang [ -f $srcdir/utt2dur ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/utt2dur >$destdir/utt2dur + [ -f $srcdir/utt2uniq ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/utt2uniq >$destdir/utt2uniq [ -f $srcdir/wav.scp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/wav.scp >$destdir/wav.scp [ -f $srcdir/spk2warp ] && utils/filter_scp.pl $destdir/spk2utt <$srcdir/spk2warp >$destdir/spk2warp [ -f $srcdir/utt2warp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/utt2warp >$destdir/utt2warp @@ -126,6 +127,10 @@ function do_filtering { [ -f $srcdir/stm ] && utils/filter_scp.pl $destdir/reco < $srcdir/stm > $destdir/stm rm $destdir/reco + else + awk '{print $1;}' $destdir/wav.scp | sort | uniq > $destdir/reco + [ -f $srcdir/reco2file_and_channel ] && \ + utils/filter_scp.pl $destdir/reco 
<$srcdir/reco2file_and_channel >$destdir/reco2file_and_channel fi srcutts=`cat $srcdir/utt2spk | wc -l` destutts=`cat $destdir/utt2spk | wc -l` diff --git a/src/Makefile b/src/Makefile index 52b23261b76..b7ac6f60bd4 100644 --- a/src/Makefile +++ b/src/Makefile @@ -6,16 +6,16 @@ SHELL := /bin/bash SUBDIRS = base matrix util feat tree thread gmm transform \ - fstext hmm lm decoder lat kws cudamatrix nnet \ + fstext hmm simplehmm lm decoder lat kws cudamatrix nnet segmenter \ bin fstbin gmmbin fgmmbin featbin \ nnetbin latbin sgmm2 sgmm2bin nnet2 nnet3 chain nnet3bin nnet2bin kwsbin \ - ivector ivectorbin online2 online2bin lmbin chainbin + ivector ivectorbin online2 online2bin lmbin chainbin segmenterbin simplehmmbin MEMTESTDIRS = base matrix util feat tree thread gmm transform \ - fstext hmm lm decoder lat nnet kws chain \ + fstext hmm simplehmm lm decoder lat nnet kws chain segmenter \ bin fstbin gmmbin fgmmbin featbin \ nnetbin latbin sgmm2 nnet2 nnet3 nnet2bin nnet3bin sgmm2bin kwsbin \ - ivector ivectorbin online2 online2bin lmbin + ivector ivectorbin online2 online2bin lmbin segmenterbin simplehmmbin CUDAMEMTESTDIR = cudamatrix @@ -150,9 +150,9 @@ $(EXT_SUBDIRS) : mklibdir ext_depend # this is necessary for correct parallel compilation #1)The tools depend on all the libraries -bin fstbin gmmbin fgmmbin sgmm2bin featbin nnetbin nnet2bin nnet3bin chainbin latbin ivectorbin lmbin kwsbin online2bin: \ - base matrix util feat tree thread gmm transform sgmm2 fstext hmm \ - lm decoder lat cudamatrix nnet nnet2 nnet3 ivector chain kws online2 +bin fstbin gmmbin fgmmbin sgmm2bin featbin nnetbin nnet2bin nnet3bin chainbin latbin ivectorbin lmbin kwsbin online2bin segmenterbin simplehmmbin: \ + base matrix util feat tree thread gmm transform sgmm2 fstext hmm simplehmm \ + lm decoder lat cudamatrix nnet nnet2 nnet3 ivector chain kws online2 segmenter #2)The libraries have inter-dependencies base: base/.depend.mk @@ -175,6 +175,8 @@ nnet2: base util matrix thread lat gmm 
hmm tree transform cudamatrix nnet3: base util matrix thread lat gmm hmm tree transform cudamatrix chain fstext chain: lat hmm tree fstext matrix cudamatrix util thread base ivector: base util matrix thread transform tree gmm +segmenter: base matrix util gmm thread tree +simplehmm: base tree matrix util thread hmm #3)Dependencies for optional parts of Kaldi onlinebin: base matrix util feat tree gmm transform sgmm2 fstext hmm lm decoder lat cudamatrix nnet nnet2 online thread # python-kaldi-decoding: base matrix util feat tree thread gmm transform sgmm2 fstext hmm decoder lat online diff --git a/src/bin/Makefile b/src/bin/Makefile index 687040889b3..1948ba2d681 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -24,7 +24,8 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ matrix-logprob matrix-sum \ build-pfile-from-ali get-post-on-ali tree-info am-info \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ - transform-vec align-text matrix-dim + transform-vec align-text matrix-dim weight-pdf-post weight-matrix \ + matrix-add-offset matrix-dot-product compute-fscore OBJFILES = diff --git a/src/bin/compute-fscore.cc b/src/bin/compute-fscore.cc new file mode 100644 index 00000000000..eb231fe361e --- /dev/null +++ b/src/bin/compute-fscore.cc @@ -0,0 +1,153 @@ +// bin/compute-fscore.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + + try { + const char *usage = + "Compute F1-score, precision, recall etc.\n" + "Takes two alignment files and computes statistics\n" + "\n" + "Usage: compute-fscore [options] \n" + " e.g.: compute-fscore ark:data/train/text ark:hyp_text\n"; + + ParseOptions po(usage); + + std::string mode = "strict"; + std::string mask_rspecifier; + + po.Register("mode", &mode, + "Scoring mode: \"present\"|\"all\"|\"strict\":\n" + " \"present\" means score those we have transcriptions for\n" + " \"all\" means treat absent transcriptions as empty\n" + " \"strict\" means die if all in ref not also in hyp"); + po.Register("mask", &mask_rspecifier, + "Only score on frames where mask is 1"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string ref_rspecifier = po.GetArg(1); + std::string hyp_rspecifier = po.GetArg(2); + + if (mode != "strict" && mode != "present" && mode != "all") { + KALDI_ERR << "--mode option invalid: expected \"present\"|\"all\"|\"strict\", got " + << mode; + } + + int64 num_tp = 0, num_fp = 0, num_tn = 0, num_fn = 0, num_frames = 0; + int32 num_absent_sents = 0; + + // Both text and integers are loaded as vector of strings, + SequentialInt32VectorReader ref_reader(ref_rspecifier); + RandomAccessInt32VectorReader hyp_reader(hyp_rspecifier); + RandomAccessInt32VectorReader mask_reader(mask_rspecifier); + + // Main loop, accumulate WER stats, + for (; !ref_reader.Done(); ref_reader.Next()) { + const std::string &key = ref_reader.Key(); + const std::vector &ref_ali = ref_reader.Value(); + std::vector hyp_ali; + if (!hyp_reader.HasKey(key)) { + if (mode == "strict") + KALDI_ERR << "No hypothesis for key " << key << " and strict " + "mode 
specifier."; + num_absent_sents++; + if (mode == "present") // do not score this one. + continue; + } else { + hyp_ali = hyp_reader.Value(key); + } + + std::vector mask_ali; + if (!mask_rspecifier.empty()) { + if (!mask_reader.HasKey(key)) { + if (mode == "strict") + KALDI_ERR << "No hypothesis for key " << key << " and strict " + "mode specifier."; + num_absent_sents++; + if (mode == "present") // do not score this one. + continue; + } else { + mask_ali = mask_reader.Value(key); + } + } + + for (int32 i = 0; i < ref_ali.size(); i++) { + if ( (i < hyp_ali.size() && hyp_ali[i] != 0 && hyp_ali[i] != 1) || + (i < ref_ali.size() && ref_ali[i] != 0 && ref_ali[i] != 1) || + (i < mask_ali.size() && mask_ali[i] != 0 && mask_ali[i] != 1) ) { + KALDI_ERR << "Expecting alignment to be 0s or 1s"; + } + + if (!mask_rspecifier.empty() && (std::abs(static_cast(ref_ali.size()) - static_cast(mask_ali.size())) > 2) ) + KALDI_ERR << "Length mismatch: mask vs ref"; + + if (!mask_rspecifier.empty() && (i > mask_ali.size() || mask_ali[i] == 0)) continue; + num_frames++; + + if (ref_ali[i] == 1 && i > hyp_ali.size()) { num_fn++; continue; } + if (ref_ali[i] == 0 && i > hyp_ali.size()) { num_tn++; continue; } + + if (ref_ali[i] == 1 && hyp_ali[i] == 1) num_tp++; + else if (ref_ali[i] == 0 && hyp_ali[i] == 1) num_fp++; + else if (ref_ali[i] == 1 && hyp_ali[i] == 0) num_fn++; + else if (ref_ali[i] == 0 && hyp_ali[i] == 0) num_tn++; + else + KALDI_ERR << "Unknown condition"; + } + } + + // Print the ouptut, + std::cout.precision(2); + std::cerr.precision(2); + + BaseFloat precision = static_cast(num_tp) / (num_tp + num_fp); + BaseFloat recall = static_cast(num_tp) / (num_tp + num_fn); + + std::cout << "F1 " << 2 * precision * recall / (precision + recall) << "\n"; + std::cout << "Precision " << precision << "\n"; + std::cout << "Recall " << recall << "\n"; + std::cout << "Specificity " + << static_cast(num_tn) / (num_tn + num_fp) << "\n"; + std::cout << "Accuracy " + << static_cast(num_tp + 
num_tn) / num_frames << "\n"; + + std::cerr << "TP " << num_tp << "\n"; + std::cerr << "FP " << num_fp << "\n"; + std::cerr << "TN " << num_tn << "\n"; + std::cerr << "FN " << num_fn << "\n"; + std::cerr << "Length " << num_frames << "\n"; + + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/bin/copy-matrix.cc b/src/bin/copy-matrix.cc index d7b8181c64c..56f2e51d90f 100644 --- a/src/bin/copy-matrix.cc +++ b/src/bin/copy-matrix.cc @@ -36,16 +36,30 @@ int main(int argc, char *argv[]) { " e.g.: copy-matrix --binary=false 1.mat -\n" " copy-matrix ark:2.trans ark,t:-\n" "See also: copy-feats\n"; - + bool binary = true; + bool apply_log = false; + bool apply_exp = false; + bool apply_softmax_per_row = false; + BaseFloat apply_power = 1.0; BaseFloat scale = 1.0; + ParseOptions po(usage); po.Register("binary", &binary, "Write in binary mode (only relevant if output is a wxfilename)"); po.Register("scale", &scale, "This option can be used to scale the matrices being copied."); - + po.Register("apply-log", &apply_log, + "This option can be used to apply log on the matrices. 
" + "Must be avoided if matrix has negative quantities."); + po.Register("apply-exp", &apply_exp, + "This option can be used to apply exp on the matrices"); + po.Register("apply-power", &apply_power, + "This option can be used to apply a power on the matrices"); + po.Register("apply-softmax-per-row", &apply_softmax_per_row, + "This option can be used to apply softmax per row of the matrices"); + po.Read(argc, argv); if (po.NumArgs() != 2) { @@ -53,6 +67,10 @@ int main(int argc, char *argv[]) { exit(1); } + if ( (apply_log && apply_exp) || (apply_softmax_per_row && apply_exp) || + (apply_softmax_per_row && apply_log) ) + KALDI_ERR << "Only one of apply-log, apply-exp and " + << "apply-softmax-per-row can be given"; std::string matrix_in_fn = po.GetArg(1), matrix_out_fn = po.GetArg(2); @@ -68,11 +86,15 @@ int main(int argc, char *argv[]) { if (in_is_rspecifier != out_is_wspecifier) KALDI_ERR << "Cannot mix archives with regular files (copying matrices)"; - + if (!in_is_rspecifier) { Matrix mat; ReadKaldiObject(matrix_in_fn, &mat); if (scale != 1.0) mat.Scale(scale); + if (apply_log) mat.ApplyLog(); + if (apply_exp) mat.ApplyExp(); + if (apply_softmax_per_row) mat.ApplySoftMaxPerRow(); + if (apply_power != 1.0) mat.ApplyPow(apply_power); Output ko(matrix_out_fn, binary); mat.Write(ko.Stream(), binary); KALDI_LOG << "Copied matrix to " << matrix_out_fn; @@ -82,9 +104,14 @@ int main(int argc, char *argv[]) { BaseFloatMatrixWriter writer(matrix_out_fn); SequentialBaseFloatMatrixReader reader(matrix_in_fn); for (; !reader.Done(); reader.Next(), num_done++) { - if (scale != 1.0) { + if (scale != 1.0 || apply_log || apply_exp || + apply_power != 1.0 || apply_softmax_per_row) { Matrix mat(reader.Value()); - mat.Scale(scale); + if (scale != 1.0) mat.Scale(scale); + if (apply_log) mat.ApplyLog(); + if (apply_exp) mat.ApplyExp(); + if (apply_softmax_per_row) mat.ApplySoftMaxPerRow(); + if (apply_power != 1.0) mat.ApplyPow(apply_power); writer.Write(reader.Key(), mat); } else { 
writer.Write(reader.Key(), reader.Value()); diff --git a/src/bin/matrix-add-offset.cc b/src/bin/matrix-add-offset.cc new file mode 100644 index 00000000000..90f72ba3254 --- /dev/null +++ b/src/bin/matrix-add-offset.cc @@ -0,0 +1,84 @@ +// bin/matrix-add-offset.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Add an offset vector to the rows of matrices in a table.\n" + "\n" + "Usage: matrix-add-offset [options] " + " \n" + "e.g.: matrix-add-offset log_post.mat neg_priors.vec log_like.mat\n" + "See also: matrix-sum-rows, matrix-sum, vector-sum\n"; + + + ParseOptions po(usage); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + std::string rspecifier = po.GetArg(1); + std::string vector_rxfilename = po.GetArg(2); + std::string wspecifier = po.GetArg(3); + + SequentialBaseFloatMatrixReader mat_reader(rspecifier); + BaseFloatMatrixWriter mat_writer(wspecifier); + + int32 num_done = 0; + + Vector vec; + { + bool binary_in; + Input ki(vector_rxfilename, &binary_in); + vec.Read(ki.Stream(), binary_in); + } + + for (; !mat_reader.Done(); mat_reader.Next()) { + 
std::string key = mat_reader.Key(); + Matrix mat(mat_reader.Value()); + if (vec.Dim() != mat.NumCols()) { + KALDI_ERR << "Mismatch in vector dimension and " + << "number of columns in matrix; " + << vec.Dim() << " vs " << mat.NumCols(); + } + mat.AddVecToRows(1.0, vec); + mat_writer.Write(key, mat); + num_done++; + } + + KALDI_LOG << "Added offset to " << num_done << " matrices."; + + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/bin/matrix-dot-product.cc b/src/bin/matrix-dot-product.cc new file mode 100644 index 00000000000..a292cab9a40 --- /dev/null +++ b/src/bin/matrix-dot-product.cc @@ -0,0 +1,183 @@ +// bin/matrix-dot-product.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Get element-wise dot product of matrices. Always returns a matrix " + "that is the same size as the first matrix.\n" + "If there is a mismatch in number of rows, the utterance is skipped, " + "unless the mismatch is within a tolerance. 
If the second matrix has " + "number of rows that is larger than the first matrix by less than the " + "specified tolerance, then a submatrix of the second matrix is " + "multiplied element-wise with the first matrix.\n" + "\n" + "Usage: matrix-dot-product [options] " + "[ ...] " + "\n" + " e.g.: matrix-dot-product ark:1.weights ark:2.weights " + "ark:combine.weights\n" + "or \n" + "Usage: matrix-dot-product [options] " + "[ ...] " + "\n" + " e.g.: matrix-sum --binary=false 1.mat 2.mat product.mat\n" + "See also: matrix-sum, matrix-sum-rows\n"; + + bool binary = true; + int32 length_tolerance = 0; + + ParseOptions po(usage); + + po.Register("binary", &binary, "If true, write output as binary (only " + "relevant for usage types two or three"); + po.Register("length-tolerance", &length_tolerance, + "Tolerance length mismatch of this many frames"); + + po.Read(argc, argv); + + if (po.NumArgs() < 2) { + po.PrintUsage(); + exit(1); + } + + int32 N = po.NumArgs(); + std::string matrix_in_fn1 = po.GetArg(1), + matrix_out_fn = po.GetArg(N); + + if (ClassifyWspecifier(matrix_out_fn, NULL, NULL, NULL) != kNoWspecifier) { + // output to table. 
+ + // Output matrix + BaseFloatMatrixWriter matrix_writer(matrix_out_fn); + + // Input matrices + SequentialBaseFloatMatrixReader matrix_reader1(matrix_in_fn1); + std::vector + matrix_readers(N-2, + static_cast(NULL)); + std::vector matrix_in_fns(N-2); + for (int32 i = 2; i < N; ++i) { + matrix_readers[i-2] = new RandomAccessBaseFloatMatrixReader( + po.GetArg(i)); + matrix_in_fns[i-2] = po.GetArg(i); + } + int32 n_utts = 0, n_total_matrices = 0, + n_success = 0, n_missing = 0, n_other_errors = 0; + + for (; !matrix_reader1.Done(); matrix_reader1.Next()) { + std::string key = matrix_reader1.Key(); + Matrix matrix1 = matrix_reader1.Value(); + matrix_reader1.FreeCurrent(); + n_utts++; + n_total_matrices++; + + Matrix matrix_out(matrix1); + + int32 i = 0; + for (i = 0; i < N-2; ++i) { + bool failed = false; // Indicates failure for this key. + if (matrix_readers[i]->HasKey(key)) { + const Matrix &matrix2 = matrix_readers[i]->Value(key); + n_total_matrices++; + if (SameDim(matrix2, matrix_out)) { + matrix_out.MulElements(matrix2); + } else { + KALDI_WARN << "Dimension mismatch for utterance " << key + << " : " << matrix2.NumRows() << " by " + << matrix2.NumCols() << " for " + << "system " << (i + 2) << ", rspecifier: " + << matrix_in_fns[i] << " vs " << matrix_out.NumRows() + << " by " << matrix_out.NumCols() + << " primary matrix, rspecifier:" << matrix_in_fn1; + if (matrix2.NumRows() - matrix_out.NumRows() <= + length_tolerance) { + KALDI_WARN << "Tolerated length mismatch for key " << key; + matrix_out.MulElements(matrix2.Range(0, matrix_out.NumRows(), + 0, matrix2.NumCols())); + } else { + KALDI_WARN << "Skipping key " << key; + failed = true; + n_other_errors++; + } + } + } else { + KALDI_WARN << "No matrix found for utterance " << key << " for " + << "system " << (i + 2) << ", rspecifier: " + << matrix_in_fns[i]; + failed = true; + n_missing++; + } + + if (failed) break; + } + + if (i != N-2) // Skipping utterance + continue; + + matrix_writer.Write(key, 
matrix_out); + n_success++; + } + + KALDI_LOG << "Processed " << n_utts << " utterances: with a total of " + << n_total_matrices << " matrices across " << (N-1) + << " different systems."; + KALDI_LOG << "Produced output for " << n_success << " utterances; " + << n_missing << " total missing matrices and skipped " + << n_other_errors << "matrices."; + + DeletePointers(&matrix_readers); + + return (n_success != 0 && n_missing < (n_success - n_missing)) ? 0 : 1; + } else { + for (int32 i = 1; i < N; i++) { + if (ClassifyRspecifier(po.GetArg(i), NULL, NULL) != kNoRspecifier) { + KALDI_ERR << "Wrong usage: if last argument is not " + << "table, the other arguments must not be tables."; + } + } + + Matrix mat1; + ReadKaldiObject(po.GetArg(1), &mat1); + + for (int32 i = 2; i < N; i++) { + Matrix mat; + ReadKaldiObject(po.GetArg(i), &mat); + + mat1.MulElements(mat); + } + + WriteKaldiObject(mat1, po.GetArg(N), binary); + KALDI_LOG << "Multiplied " << (po.NumArgs() - 1) << " matrices; " + << "wrote product to " << PrintableWxfilename(po.GetArg(N)); + + return 0; + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/bin/matrix-sum-rows.cc b/src/bin/matrix-sum-rows.cc index 7e60483eef2..ee6504ba2b1 100644 --- a/src/bin/matrix-sum-rows.cc +++ b/src/bin/matrix-sum-rows.cc @@ -34,9 +34,13 @@ int main(int argc, char *argv[]) { "e.g.: matrix-sum-rows ark:- ark:- | vector-sum ark:- sum.vec\n" "See also: matrix-sum, vector-sum\n"; + bool do_average = false; ParseOptions po(usage); + po.Register("do-average", &do_average, + "Do average instead of sum"); + po.Read(argc, argv); if (po.NumArgs() != 2) { @@ -45,28 +49,28 @@ int main(int argc, char *argv[]) { } std::string rspecifier = po.GetArg(1); std::string wspecifier = po.GetArg(2); - + SequentialBaseFloatMatrixReader mat_reader(rspecifier); BaseFloatVectorWriter vec_writer(wspecifier); - + int32 num_done = 0; int64 num_rows_done = 0; - + for (; !mat_reader.Done(); 
mat_reader.Next()) { std::string key = mat_reader.Key(); Matrix mat(mat_reader.Value()); Vector vec(mat.NumCols()); - vec.AddRowSumMat(1.0, mat, 0.0); + vec.AddRowSumMat(!do_average ? 1.0 : 1.0 / mat.NumRows(), mat, 0.0); // Do the summation in double, to minimize roundoff. Vector float_vec(vec); vec_writer.Write(key, float_vec); num_done++; num_rows_done += mat.NumRows(); } - + KALDI_LOG << "Summed rows " << num_done << " matrices, " << num_rows_done << " rows in total."; - + return (num_done != 0 ? 0 : 1); } catch(const std::exception &e) { std::cerr << e.what(); diff --git a/src/bin/vector-scale.cc b/src/bin/vector-scale.cc index 60d4d3121d2..ea68ae31ad0 100644 --- a/src/bin/vector-scale.cc +++ b/src/bin/vector-scale.cc @@ -30,11 +30,14 @@ int main(int argc, char *argv[]) { const char *usage = "Scale a set of vectors in a Table (useful for speaker vectors and " "per-frame weights)\n" - "Usage: vector-scale [options] \n"; + "Usage: vector-scale [options] \n"; ParseOptions po(usage); BaseFloat scale = 1.0; + bool binary = false; + po.Register("binary", &binary, "If true, write output as binary " + "not relevant for archives"); po.Register("scale", &scale, "Scaling factor for vectors"); po.Read(argc, argv); @@ -43,17 +46,33 @@ int main(int argc, char *argv[]) { exit(1); } - std::string rspecifier = po.GetArg(1); - std::string wspecifier = po.GetArg(2); + std::string vector_in_fn = po.GetArg(1); + std::string vector_out_fn = po.GetArg(2); - BaseFloatVectorWriter vec_writer(wspecifier); - - SequentialBaseFloatVectorReader vec_reader(rspecifier); - for (; !vec_reader.Done(); vec_reader.Next()) { - Vector vec(vec_reader.Value()); + if (ClassifyWspecifier(vector_in_fn, NULL, NULL, NULL) != kNoWspecifier) { + if (ClassifyRspecifier(vector_in_fn, NULL, NULL) == kNoRspecifier) { + KALDI_ERR << "Cannot mix archives and regular files"; + } + BaseFloatVectorWriter vec_writer(vector_out_fn); + SequentialBaseFloatVectorReader vec_reader(vector_in_fn); + for (; 
!vec_reader.Done(); vec_reader.Next()) { + Vector vec(vec_reader.Value()); + vec.Scale(scale); + vec_writer.Write(vec_reader.Key(), vec); + } + } else { + if (ClassifyRspecifier(vector_in_fn, NULL, NULL) != kNoRspecifier) { + KALDI_ERR << "Cannot mix archives and regular files"; + } + bool binary_in; + Input ki(vector_in_fn, &binary_in); + Vector vec; + vec.Read(ki.Stream(), binary_in); vec.Scale(scale); - vec_writer.Write(vec_reader.Key(), vec); + Output ko(vector_out_fn, binary); + vec.Write(ko.Stream(), binary); } + return 0; } catch(const std::exception &e) { std::cerr << e.what(); diff --git a/src/bin/weight-matrix.cc b/src/bin/weight-matrix.cc new file mode 100644 index 00000000000..c6823b8da29 --- /dev/null +++ b/src/bin/weight-matrix.cc @@ -0,0 +1,84 @@ +// bin/weight-matrix.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "base/kaldi-common.h" +#include "util/common-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + + const char *usage = + "Takes archives (typically per-utterance) of features and " + "per-frame weights,\n" + "and weights the features by the per-frame weights\n" + "\n" + "Usage: weight-matrix " + "\n"; + + ParseOptions po(usage); + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string matrix_rspecifier = po.GetArg(1), + weights_rspecifier = po.GetArg(2), + matrix_wspecifier = po.GetArg(3); + + SequentialBaseFloatMatrixReader matrix_reader(matrix_rspecifier); + RandomAccessBaseFloatVectorReader weights_reader(weights_rspecifier); + BaseFloatMatrixWriter matrix_writer(matrix_wspecifier); + + int32 num_done = 0, num_err = 0; + + for (; !matrix_reader.Done(); matrix_reader.Next()) { + std::string key = matrix_reader.Key(); + Matrix mat = matrix_reader.Value(); + if (!weights_reader.HasKey(key)) { + KALDI_WARN << "No weight vectors for utterance " << key; + num_err++; + continue; + } + const Vector &weights = weights_reader.Value(key); + if (weights.Dim() != mat.NumRows()) { + KALDI_WARN << "Weights for utterance " << key + << " have wrong size, " << weights.Dim() + << " vs. " << mat.NumRows(); + num_err++; + continue; + } + mat.MulRowsVec(weights); + matrix_writer.Write(key, mat); + num_done++; + } + KALDI_LOG << "Applied per-frame weights for " << num_done + << " matrices; errors on " << num_err; + return (num_done != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/bin/weight-pdf-post.cc b/src/bin/weight-pdf-post.cc new file mode 100644 index 00000000000..c7477a046c8 --- /dev/null +++ b/src/bin/weight-pdf-post.cc @@ -0,0 +1,154 @@ +// bin/weight-pdf-post.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "gmm/am-diag-gmm.h" +#include "hmm/transition-model.h" +#include "hmm/hmm-utils.h" +#include "hmm/posterior.h" + +namespace kaldi { + +void WeightPdfPost(const ConstIntegerSet &pdf_set, + BaseFloat pdf_scale, + Posterior *post) { + for (size_t i = 0; i < post->size(); i++) { + std::vector > this_post; + this_post.reserve((*post)[i].size()); + for (size_t j = 0; j < (*post)[i].size(); j++) { + int32 pdf_id = (*post)[i][j].first; + BaseFloat weight = (*post)[i][j].second; + if (pdf_set.count(pdf_id) != 0) { // is a silence. 
+ if (pdf_scale != 0.0) + this_post.push_back(std::make_pair(pdf_id, weight*pdf_scale)); + } else { + this_post.push_back(std::make_pair(pdf_id, weight)); + } + } + (*post)[i].swap(this_post); + } +} + +void WeightPdfPostDistributed(const ConstIntegerSet &pdf_set, + BaseFloat pdf_scale, + Posterior *post) { + for (size_t i = 0; i < post->size(); i++) { + std::vector > this_post; + this_post.reserve((*post)[i].size()); + BaseFloat sil_weight = 0.0, nonsil_weight = 0.0; + for (size_t j = 0; j < (*post)[i].size(); j++) { + int32 pdf_id = (*post)[i][j].first; + BaseFloat weight = (*post)[i][j].second; + if (pdf_set.count(pdf_id) != 0) + sil_weight += weight; + else + nonsil_weight += weight; + } + // This "distributed" weighting approach doesn't make sense if we have + // negative weights. + KALDI_ASSERT(sil_weight >= 0.0 && nonsil_weight >= 0.0); + if (sil_weight + nonsil_weight == 0.0) continue; + BaseFloat frame_scale = (sil_weight * pdf_scale + nonsil_weight) / + (sil_weight + nonsil_weight); + if (frame_scale != 0.0) { + for (size_t j = 0; j < (*post)[i].size(); j++) { + int32 pdf_id = (*post)[i][j].first; + BaseFloat weight = (*post)[i][j].second; + this_post.push_back(std::make_pair(pdf_id, weight * frame_scale)); + } + } + (*post)[i].swap(this_post); + } +} + +} // namespace kaldi + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + try { + const char *usage = + "Apply weight to specific pdfs or tids in posts\n" + "Usage: weight-pdf-post [options] " + " \n" + "e.g.:\n" + " weight-pdf-post 0.00001 0:2 ark:1.post ark:nosil.post\n"; + + ParseOptions po(usage); + + bool distribute = false; + + po.Register("distribute", &distribute, "If true, rather than weighting the " + "individual posteriors, apply the weighting to the " + "whole frame: " + "i.e. 
on time t, scale all posterior entries by " + "p(sil)*silence-weight + p(non-sil)*1.0"); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + std::string pdf_weight_str = po.GetArg(1), + pdfs_str = po.GetArg(2), + posteriors_rspecifier = po.GetArg(3), + posteriors_wspecifier = po.GetArg(4); + + BaseFloat pdf_weight = 0.0; + if (!ConvertStringToReal(pdf_weight_str, &pdf_weight)) + KALDI_ERR << "Invalid pdf-weight parameter: expected float, got \"" + << pdf_weight << '"'; + std::vector pdfs; + if (!SplitStringToIntegers(pdfs_str, ":", false, &pdfs)) + KALDI_ERR << "Invalid pdf string string " << pdfs_str; + if (pdfs.empty()) + KALDI_WARN <<"No pdf specified, this will have no effect"; + ConstIntegerSet pdf_set(pdfs); // faster lookup. + + int32 num_posteriors = 0; + SequentialPosteriorReader posterior_reader(posteriors_rspecifier); + PosteriorWriter posterior_writer(posteriors_wspecifier); + + for (; !posterior_reader.Done(); posterior_reader.Next()) { + num_posteriors++; + // Posterior is vector > > + Posterior post = posterior_reader.Value(); + // Posterior is vector > > + if (distribute) + WeightPdfPostDistributed(pdf_set, + pdf_weight, &post); + else + WeightPdfPost(pdf_set, + pdf_weight, &post); + + posterior_writer.Write(posterior_reader.Key(), post); + } + KALDI_LOG << "Done " << num_posteriors << " posteriors."; + return (num_posteriors != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/bin/weight-post.cc b/src/bin/weight-post.cc index d536896eaaa..bbaad465195 100644 --- a/src/bin/weight-post.cc +++ b/src/bin/weight-post.cc @@ -26,32 +26,38 @@ int main(int argc, char *argv[]) { try { using namespace kaldi; - typedef kaldi::int32 int32; + typedef kaldi::int32 int32; + + int32 length_tolerance = 2; const char *usage = "Takes archives (typically per-utterance) of posteriors and per-frame weights,\n" "and weights the posteriors by the per-frame weights\n" "\n" "Usage: weight-post \n"; - + ParseOptions po(usage); + + po.Register("length-tolerance", &length_tolerance, + "Tolerate this many frames of length mismatch"); + po.Read(argc, argv); if (po.NumArgs() != 3) { po.PrintUsage(); exit(1); } - + std::string post_rspecifier = po.GetArg(1), weights_rspecifier = po.GetArg(2), post_wspecifier = po.GetArg(3); SequentialPosteriorReader posterior_reader(post_rspecifier); RandomAccessBaseFloatVectorReader weights_reader(weights_rspecifier); - PosteriorWriter post_writer(post_wspecifier); - + PosteriorWriter post_writer(post_wspecifier); + int32 num_done = 0, num_err = 0; - + for (; !posterior_reader.Done(); posterior_reader.Next()) { std::string key = posterior_reader.Key(); Posterior post = posterior_reader.Value(); @@ -61,7 +67,8 @@ int main(int argc, char *argv[]) { continue; } const Vector &weights = weights_reader.Value(key); - if (weights.Dim() != static_cast(post.size())) { + if (std::abs(weights.Dim() - static_cast(post.size())) > + length_tolerance) { KALDI_WARN << "Weights for utterance " << key << " have wrong size, " << weights.Dim() << " vs. " << post.size(); @@ -71,7 +78,7 @@ int main(int argc, char *argv[]) { for (size_t i = 0; i < post.size(); i++) { if (weights(i) == 0.0) post[i].clear(); for (size_t j = 0; j < post[i].size(); j++) - post[i][j].second *= weights(i); + post[i][j].second *= i < weights.Dim() ? 
weights(i) : 0.0; } post_writer.Write(key, post); num_done++; diff --git a/src/feat/feature-fbank.cc b/src/feat/feature-fbank.cc index c54069696b5..3c53ef1ec08 100644 --- a/src/feat/feature-fbank.cc +++ b/src/feat/feature-fbank.cc @@ -28,9 +28,9 @@ FbankComputer::FbankComputer(const FbankOptions &opts): if (opts.energy_floor > 0.0) log_energy_floor_ = Log(opts.energy_floor); - int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); - if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two... - srfft_ = new SplitRadixRealFft(padded_window_size); + int32 num_fft_bins = opts.frame_opts.NumFftBins(); + if ((num_fft_bins & (num_fft_bins-1)) == 0) // Is a power of two... + srfft_ = new SplitRadixRealFft(num_fft_bins); // We'll definitely need the filterbanks info for VTLN warping factor 1.0. // [note: this call caches it.] @@ -76,7 +76,7 @@ void FbankComputer::Compute(BaseFloat signal_log_energy, const MelBanks &mel_banks = *(GetMelBanks(vtln_warp)); - KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && + KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.NumFftBins() && feature->Dim() == this->Dim()); diff --git a/src/feat/feature-mfcc.cc b/src/feat/feature-mfcc.cc index c1962a5c1d1..47912cc8693 100644 --- a/src/feat/feature-mfcc.cc +++ b/src/feat/feature-mfcc.cc @@ -29,7 +29,7 @@ void MfccComputer::Compute(BaseFloat signal_log_energy, BaseFloat vtln_warp, VectorBase *signal_frame, VectorBase *feature) { - KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && + KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.NumFftBins() && feature->Dim() == this->Dim()); const MelBanks &mel_banks = *(GetMelBanks(vtln_warp)); @@ -98,9 +98,9 @@ MfccComputer::MfccComputer(const MfccOptions &opts): if (opts.energy_floor > 0.0) log_energy_floor_ = Log(opts.energy_floor); - int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); - if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of 
two... - srfft_ = new SplitRadixRealFft(padded_window_size); + int32 num_fft_bins = opts.frame_opts.NumFftBins(); + if ((num_fft_bins & (num_fft_bins-1)) == 0) // Is a power of two... + srfft_ = new SplitRadixRealFft(num_fft_bins); // We'll definitely need the filterbanks info for VTLN warping factor 1.0. // [note: this call caches it.] diff --git a/src/feat/feature-spectrogram.cc b/src/feat/feature-spectrogram.cc index 953f38fc54f..f5f1c420462 100644 --- a/src/feat/feature-spectrogram.cc +++ b/src/feat/feature-spectrogram.cc @@ -48,7 +48,7 @@ void SpectrogramComputer::Compute(BaseFloat signal_log_energy, BaseFloat vtln_warp, VectorBase *signal_frame, VectorBase *feature) { - KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && + KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.NumFftBins() && feature->Dim() == this->Dim()); diff --git a/src/feat/feature-spectrogram.h b/src/feat/feature-spectrogram.h index ec318556f24..6ca0697ef78 100644 --- a/src/feat/feature-spectrogram.h +++ b/src/feat/feature-spectrogram.h @@ -39,10 +39,13 @@ struct SpectrogramOptions { FrameExtractionOptions frame_opts; BaseFloat energy_floor; bool raw_energy; // If true, compute energy before preemphasis and windowing + bool use_energy; // append an extra dimension with energy to the filter banks + BaseFloat low_freq; // e.g. 20; lower frequency cutoff + BaseFloat high_freq; // an upper frequency cutoff; 0 -> no cutoff, negative SpectrogramOptions() : energy_floor(0.0), // not in log scale: a small value e.g. 
1.0e-10 - raw_energy(true) {} + raw_energy(true), use_energy(true), low_freq(0), high_freq(0) {} void Register(OptionsItf *opts) { frame_opts.Register(opts); @@ -50,6 +53,12 @@ struct SpectrogramOptions { "Floor on energy (absolute, not relative) in Spectrogram computation"); opts->Register("raw-energy", &raw_energy, "If true, compute energy before preemphasis and windowing"); + opts->Register("use-energy", &use_energy, + "Add an extra dimension with energy to the spectrogram output."); + opts->Register("low-freq", &low_freq, + "Low cutoff frequency for mel bins"); + opts->Register("high-freq", &high_freq, + "High cutoff frequency for mel bins (if < 0, offset from Nyquist)"); } }; diff --git a/src/feat/feature-window.cc b/src/feat/feature-window.cc index 65c0a2a29c3..7b86e71dbb7 100644 --- a/src/feat/feature-window.cc +++ b/src/feat/feature-window.cc @@ -163,7 +163,7 @@ void ExtractWindow(int64 sample_offset, BaseFloat *log_energy_pre_window) { KALDI_ASSERT(sample_offset >= 0 && wave.Dim() != 0); int32 frame_length = opts.WindowSize(), - frame_length_padded = opts.PaddedWindowSize(); + num_fft_bins = opts.NumFftBins(); int64 num_samples = sample_offset + wave.Dim(), start_sample = FirstSampleOfFrame(f, opts), end_sample = start_sample + frame_length; @@ -175,8 +175,8 @@ void ExtractWindow(int64 sample_offset, KALDI_ASSERT(sample_offset == 0 || start_sample >= sample_offset); } - if (window->Dim() != frame_length_padded) - window->Resize(frame_length_padded, kUndefined); + if (window->Dim() != num_fft_bins) + window->Resize(num_fft_bins, kUndefined); // wave_start and wave_end are start and end indexes into 'wave', for the // piece of wave that we're trying to extract. 
@@ -206,8 +206,8 @@ void ExtractWindow(int64 sample_offset, } } - if (frame_length_padded > frame_length) - window->Range(frame_length, frame_length_padded - frame_length).SetZero(); + if (num_fft_bins > frame_length) + window->Range(frame_length, num_fft_bins - frame_length).SetZero(); SubVector frame(*window, 0, frame_length); diff --git a/src/feat/feature-window.h b/src/feat/feature-window.h index bbb24fd8988..d6acf7e2bed 100644 --- a/src/feat/feature-window.h +++ b/src/feat/feature-window.h @@ -42,6 +42,7 @@ struct FrameExtractionOptions { std::string window_type; // e.g. Hamming window bool round_to_power_of_two; BaseFloat blackman_coeff; + int32 num_fft_bins; bool snip_edges; // May be "hamming", "rectangular", "povey", "hanning", "blackman" // "povey" is a window I made to be similar to Hamming but to go to zero at the @@ -57,6 +58,7 @@ struct FrameExtractionOptions { window_type("povey"), round_to_power_of_two(true), blackman_coeff(0.42), + num_fft_bins(128), snip_edges(true){ } void Register(OptionsItf *opts) { @@ -78,6 +80,8 @@ struct FrameExtractionOptions { opts->Register("round-to-power-of-two", &round_to_power_of_two, "If true, round window size to power of two by zero-padding " "input to FFT."); + opts->Register("num-fft-bins", &num_fft_bins, + "Number of FFT bins to compute spectrogram"); opts->Register("snip-edges", &snip_edges, "If true, end effects will be handled by outputting only frames that " "completely fit in the file, and the number of frames depends on the " @@ -94,6 +98,13 @@ struct FrameExtractionOptions { return (round_to_power_of_two ? RoundUpToNearestPowerOfTwo(WindowSize()) : WindowSize()); } + int32 NumFftBins() const { + int32 padded_window_size = PaddedWindowSize(); + if (num_fft_bins > padded_window_size) + return (round_to_power_of_two ? 
RoundUpToNearestPowerOfTwo(num_fft_bins) : + num_fft_bins); + return padded_window_size; + } }; diff --git a/src/feat/mel-computations.cc b/src/feat/mel-computations.cc index 714d963f01b..db3f3334ca2 100644 --- a/src/feat/mel-computations.cc +++ b/src/feat/mel-computations.cc @@ -37,13 +37,7 @@ MelBanks::MelBanks(const MelBanksOptions &opts, int32 num_bins = opts.num_bins; if (num_bins < 3) KALDI_ERR << "Must have at least 3 mel bins"; BaseFloat sample_freq = frame_opts.samp_freq; - int32 window_length = static_cast(frame_opts.samp_freq*0.001*frame_opts.frame_length_ms); - int32 window_length_padded = - (frame_opts.round_to_power_of_two ? - RoundUpToNearestPowerOfTwo(window_length) : - window_length); - KALDI_ASSERT(window_length_padded % 2 == 0); - int32 num_fft_bins = window_length_padded/2; + int32 num_fft_bins = frame_opts.NumFftBins(); BaseFloat nyquist = 0.5 * sample_freq; BaseFloat low_freq = opts.low_freq, high_freq; @@ -59,8 +53,8 @@ MelBanks::MelBanks(const MelBanksOptions &opts, << " and high-freq " << high_freq << " vs. nyquist " << nyquist; - BaseFloat fft_bin_width = sample_freq / window_length_padded; - // fft-bin width [think of it as Nyquist-freq / half-window-length] + BaseFloat fft_bin_width = sample_freq / num_fft_bins; + // fft-bin width [think of it as Nyquist-freq / num_fft_bins] BaseFloat mel_low_freq = MelScale(low_freq); BaseFloat mel_high_freq = MelScale(high_freq); @@ -104,9 +98,9 @@ MelBanks::MelBanks(const MelBanksOptions &opts, center_freqs_(bin) = InverseMelScale(center_mel); // this_bin will be a vector of coefficients that is only // nonzero where this mel bin is active. - Vector this_bin(num_fft_bins); + Vector this_bin(num_fft_bins / 2); int32 first_index = -1, last_index = -1; - for (int32 i = 0; i < num_fft_bins; i++) { + for (int32 i = 0; i < num_fft_bins / 2; i++) { BaseFloat freq = (fft_bin_width * i); // Center frequency of this fft // bin. 
BaseFloat mel = MelScale(freq); diff --git a/src/feat/pitch-functions.cc b/src/feat/pitch-functions.cc index 430e9bdb53a..07e1d181243 100644 --- a/src/feat/pitch-functions.cc +++ b/src/feat/pitch-functions.cc @@ -1402,7 +1402,8 @@ OnlineProcessPitch::OnlineProcessPitch( dim_ ((opts.add_pov_feature ? 1 : 0) + (opts.add_normalized_log_pitch ? 1 : 0) + (opts.add_delta_pitch ? 1 : 0) - + (opts.add_raw_log_pitch ? 1 : 0)) { + + (opts.add_raw_log_pitch ? 1 : 0) + + (opts.add_raw_pov ? 1 : 0)) { KALDI_ASSERT(dim_ > 0 && " At least one of the pitch features should be chosen. " "Check your post-process-pitch options."); @@ -1425,6 +1426,8 @@ void OnlineProcessPitch::GetFrame(int32 frame, (*feat)(index++) = GetDeltaPitchFeature(frame_delayed); if (opts_.add_raw_log_pitch) (*feat)(index++) = GetRawLogPitchFeature(frame_delayed); + if (opts_.add_raw_pov) + (*feat)(index++) = GetRawPov(frame_delayed); KALDI_ASSERT(index == dim_); } @@ -1482,6 +1485,13 @@ BaseFloat OnlineProcessPitch::GetNormalizedLogPitchFeature(int32 frame) { return normalized_log_pitch * opts_.pitch_scale; } +BaseFloat OnlineProcessPitch::GetRawPov(int32 frame) const { + Vector tmp(kRawFeatureDim); + src_->GetFrame(frame, &tmp); // (NCCF, pitch) from pitch extractor + BaseFloat nccf = tmp(0); + return NccfToPov(nccf); +} + // inline void OnlineProcessPitch::GetNormalizationWindow(int32 t, diff --git a/src/feat/pitch-functions.h b/src/feat/pitch-functions.h index 70e85380be6..b94ac661c10 100644 --- a/src/feat/pitch-functions.h +++ b/src/feat/pitch-functions.h @@ -231,6 +231,7 @@ struct ProcessPitchOptions { bool add_normalized_log_pitch; bool add_delta_pitch; bool add_raw_log_pitch; + bool add_raw_pov; ProcessPitchOptions() : pitch_scale(2.0), @@ -245,7 +246,7 @@ struct ProcessPitchOptions { add_pov_feature(true), add_normalized_log_pitch(true), add_delta_pitch(true), - add_raw_log_pitch(false) { } + add_raw_log_pitch(false), add_raw_pov(false) { } void Register(ParseOptions *opts) { @@ -286,6 +287,8 @@ struct 
ProcessPitchOptions { "features"); opts->Register("add-raw-log-pitch", &add_raw_log_pitch, "If true, log(pitch) is added to output features"); + opts->Register("add-raw-pov", &add_raw_pov, + "If true, add NCCF converted to POV"); } }; @@ -396,6 +399,10 @@ class OnlineProcessPitch: public OnlineFeatureInterface { /// Called from GetFrame(). inline BaseFloat GetNormalizedLogPitchFeature(int32 frame); + /// Computes and returns the raw POV for this frame. + /// Called from GetFrame(). + inline BaseFloat GetRawPov(int32 frame) const; + /// Computes the normalization window sizes. inline void GetNormalizationWindow(int32 frame, int32 src_frames_ready, diff --git a/src/featbin/Makefile b/src/featbin/Makefile index c51867b7d4c..d6d85893289 100644 --- a/src/featbin/Makefile +++ b/src/featbin/Makefile @@ -15,7 +15,8 @@ BINFILES = compute-mfcc-feats compute-plp-feats compute-fbank-feats \ process-kaldi-pitch-feats compare-feats wav-to-duration add-deltas-sdc \ compute-and-process-kaldi-pitch-feats modify-cmvn-stats wav-copy \ wav-reverberate append-vector-to-feats shift-feats concat-feats \ - append-post-to-feats post-to-feats + append-post-to-feats post-to-feats vector-to-feat \ + extract-column compute-snr-targets OBJFILES = diff --git a/src/featbin/apply-cmvn-sliding.cc b/src/featbin/apply-cmvn-sliding.cc index 4a6d02d16cd..105319761b5 100644 --- a/src/featbin/apply-cmvn-sliding.cc +++ b/src/featbin/apply-cmvn-sliding.cc @@ -35,10 +35,13 @@ int main(int argc, char *argv[]) { "Useful for speaker-id; see also apply-cmvn-online\n" "\n" "Usage: apply-cmvn-sliding [options] \n"; - + + std::string skip_dims_str; ParseOptions po(usage); SlidingWindowCmnOptions opts; opts.Register(&po); + po.Register("skip-dims", &skip_dims_str, "Dimensions for which to skip " + "normalization: colon-separated list of integers, e.g. 
13:14:15)"); po.Read(argc, argv); @@ -47,15 +50,24 @@ int main(int argc, char *argv[]) { exit(1); } + std::vector skip_dims; // optionally use "fake" + // (zero-mean/unit-variance) stats for some + // dims to disable normalization. + if (!SplitStringToIntegers(skip_dims_str, ":", false, &skip_dims)) { + KALDI_ERR << "Bad --skip-dims option (should be colon-separated list of " + << "integers)"; + } + + int32 num_done = 0, num_err = 0; - + std::string feat_rspecifier = po.GetArg(1); std::string feat_wspecifier = po.GetArg(2); SequentialBaseFloatMatrixReader feat_reader(feat_rspecifier); BaseFloatMatrixWriter feat_writer(feat_wspecifier); - - for (;!feat_reader.Done(); feat_reader.Next()) { + + for (; !feat_reader.Done(); feat_reader.Next()) { std::string utt = feat_reader.Key(); Matrix feat(feat_reader.Value()); if (feat.NumRows() == 0) { @@ -67,7 +79,7 @@ int main(int argc, char *argv[]) { feat.NumCols(), kUndefined); SlidingWindowCmn(opts, feat, &cmvn_feat); - + feat_writer.Write(utt, cmvn_feat); num_done++; } diff --git a/src/featbin/compute-snr-targets.cc b/src/featbin/compute-snr-targets.cc new file mode 100644 index 00000000000..cdb7ef66c2a --- /dev/null +++ b/src/featbin/compute-snr-targets.cc @@ -0,0 +1,273 @@ +// featbin/compute-snr-targets.cc + +// Copyright 2015-2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Compute snr targets using clean and noisy speech features.\n" + "The targets can be of 3 types -- \n" + "Irm (Ideal Ratio Mask) = Clean fbank / (Clean fbank + Noise fbank)\n" + "FbankMask = Clean fbank / Noisy fbank\n" + "Snr (Signal To Noise Ratio) = Clean fbank / Noise fbank\n" + "Both input and output features are assumed to be in log domain.\n" + "ali-rspecifier and silence-phones are used to identify whether " + "a particular frame is \"clean\" or not. Silence frames in " + "\"clean\" fbank are treated as \"noise\" and hence the SNR for those " + "frames are -inf in log scale.\n" + "Usage: compute-snr-targets [options] \n" + " or compute-snr-targets [options] --binary-targets \n" + "e.g.: compute-snr-targets scp:clean.scp scp:noisy.scp ark:targets.ark\n"; + + std::string target_type = "Irm"; + std::string ali_rspecifier; + std::string silence_phones_str; + std::string floor_str = "-inf", ceiling_str = "inf"; + int32 length_tolerance = 0; + bool binary_targets = false; + int32 target_dim = -1; + + ParseOptions po(usage); + po.Register("target_type", &target_type, "Target type can be FbankMask or IRM"); + po.Register("ali-rspecifier", &ali_rspecifier, "If provided, all the " + "energy in the silence region of clean file is considered noise"); + po.Register("silence-phones", &silence_phones_str, "Comma-separated list of " + "silence phones"); + po.Register("floor", &floor_str, "If specified, the target is floored at " + "this value. 
You may want to do this if you are using targets " + "in original log form as is usual in the case of Snr, but may " + "not if you are applying Exp() as is usual in the case of Irm"); + po.Register("ceiling", &ceiling_str, "If specified, the target is ceiled " + "at this value. You may want to do this if you expect " + "infinities or very large values, particularly for Snr targets."); + po.Register("length-tolerance", &length_tolerance, "Tolerate differences " + "in utterance lengths of these many frames"); + po.Register("binary-targets", &binary_targets, "If specified, then the " + "targets are created considering each frame to be either " + "completely signal or completely noise as decided by the " + "ali-rspecifier option. When ali-rspecifier is not specified, " + "then the entire utterance is considered to be just signal." + "If this option is specified, then only a single argument " + "-- the clean features -- is must be specified."); + po.Register("target-dim", &target_dim, "Overrides the target dimension. 
" + "Applicable only with --binary-targets is specified"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3 && po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::vector silence_phones; + if (!silence_phones_str.empty()) { + if (!SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones)) { + KALDI_ERR << "Invalid silence-phones string " << silence_phones_str; + } + std::sort(silence_phones.begin(), silence_phones.end()); + } + + double floor = kLogZeroDouble, ceiling = -kLogZeroDouble; + + if (floor_str != "-inf") + if (!ConvertStringToReal(floor_str, &floor)) { + KALDI_ERR << "Invalid --floor value " << floor_str; + } + + if (ceiling_str != "inf") + if (!ConvertStringToReal(ceiling_str, &ceiling)) { + KALDI_ERR << "Invalid --ceiling value " << ceiling_str; + } + + int32 num_done = 0, num_err = 0, num_success = 0; + int64 num_sil_frames = 0; + int64 num_speech_frames = 0; + + if (!binary_targets) { + // This is the 'normal' case, where we have both clean and + // noise/corrupted input features. + // The word 'noisy' in the variable names is used to mean 'corrupted'. 
+ std::string clean_rspecifier = po.GetArg(1), + noisy_rspecifier = po.GetArg(2), + targets_wspecifier = po.GetArg(3); + + SequentialBaseFloatMatrixReader noisy_reader(noisy_rspecifier); + RandomAccessBaseFloatMatrixReader clean_reader(clean_rspecifier); + BaseFloatMatrixWriter kaldi_writer(targets_wspecifier); + + RandomAccessInt32VectorReader alignment_reader(ali_rspecifier); + + for (; !noisy_reader.Done(); noisy_reader.Next(), num_done++) { + const std::string &key = noisy_reader.Key(); + Matrix total_energy(noisy_reader.Value()); + // Although this is called 'energy', it is actually log filterbank + // features of noise or corrupted files + // Actually noise feats in the case of Irm and Snr + + // TODO: Support multiple corrupted version for a particular clean file + std::string uniq_key = key; + if (!clean_reader.HasKey(uniq_key)) { + KALDI_WARN << "Could not find uniq key " << uniq_key << " " + << "in clean feats " << clean_rspecifier; + num_err++; + continue; + } + + Matrix clean_energy(clean_reader.Value(uniq_key)); + + if (target_type == "Irm") { + total_energy.LogAddExpMat(1.0, clean_energy, kNoTrans); + } + + if (!ali_rspecifier.empty()) { + if (!alignment_reader.HasKey(uniq_key)) { + KALDI_WARN << "Could not find uniq key " << uniq_key + << "in alignment " << ali_rspecifier; + num_err++; + continue; + } + const std::vector &ali = alignment_reader.Value(key); + + if (std::abs(static_cast (ali.size()) - clean_energy.NumRows()) > length_tolerance) { + KALDI_WARN << "Mismatch in number of frames in alignment " + << "and feats; " << static_cast(ali.size()) + << " vs " << clean_energy.NumRows(); + num_err++; + continue; + } + + int32 length = std::min(static_cast(ali.size()), clean_energy.NumRows()); + if (ali.size() < length) + // TODO: Support this case + KALDI_ERR << "This code currently does not support the case " + << "where alignment smaller than features because " + << "it is not expected to happen"; + + KALDI_ASSERT(clean_energy.NumRows() == length); 
+ KALDI_ASSERT(total_energy.NumRows() == length); + + if (clean_energy.NumRows() < length) clean_energy.Resize(length, clean_energy.NumCols(), kCopyData); + if (total_energy.NumRows() < length) total_energy.Resize(length, total_energy.NumCols(), kCopyData); + + for (int32 i = 0; i < clean_energy.NumRows(); i++) { + if (std::binary_search(silence_phones.begin(), silence_phones.end(), ali[i])) { + clean_energy.Row(i).Set(kLogZeroDouble); + num_sil_frames++; + } else num_speech_frames++; + } + } + + clean_energy.AddMat(-1.0, total_energy); + if (ceiling_str != "inf") { + clean_energy.ApplyCeiling(ceiling); + } + + if (floor_str != "-inf") { + clean_energy.ApplyFloor(floor); + } + + kaldi_writer.Write(key, Matrix(clean_energy)); + num_success++; + } + } else { + // Copying tables of features. + std::string feats_rspecifier = po.GetArg(1), + targets_wspecifier = po.GetArg(2); + + SequentialBaseFloatMatrixReader feats_reader(feats_rspecifier); + BaseFloatMatrixWriter kaldi_writer(targets_wspecifier); + + RandomAccessInt32VectorReader alignment_reader(ali_rspecifier); + + int64 num_sil_frames = 0; + int64 num_speech_frames = 0; + + for (; !feats_reader.Done(); feats_reader.Next(), num_done++) { + const std::string &key = feats_reader.Key(); + const Matrix &feats = feats_reader.Value(); + + Matrix targets; + + if (target_dim < 0) + targets.Resize(feats.NumRows(), feats.NumCols()); + else + targets.Resize(feats.NumRows(), target_dim); + + if (target_type == "Snr") + targets.Set(-kLogZeroDouble); + + if (!ali_rspecifier.empty()) { + if (!alignment_reader.HasKey(key)) { + KALDI_WARN << "Could not find uniq key " << key + << " in alignment " << ali_rspecifier; + num_err++; + continue; + } + + const std::vector &ali = alignment_reader.Value(key); + + if (std::abs(static_cast (ali.size()) - feats.NumRows()) > length_tolerance) { + KALDI_WARN << "Mismatch in number of frames in alignment " + << "and feats; " << static_cast(ali.size()) + << " vs " << feats.NumRows(); + num_err++; 
+ continue; + } + + int32 length = std::min(static_cast(ali.size()), feats.NumRows()); + KALDI_ASSERT(ali.size() >= length); + + for (int32 i = 0; i < feats.NumRows(); i++) { + if (std::binary_search(silence_phones.begin(), silence_phones.end(), ali[i])) { + targets.Row(i).Set(kLogZeroDouble); + num_sil_frames++; + } else { + num_speech_frames++; + } + } + + if (ceiling_str != "inf") { + targets.ApplyCeiling(ceiling); + } + + if (floor_str != "-inf") { + targets.ApplyFloor(floor); + } + + kaldi_writer.Write(key, targets); + } + } + + KALDI_LOG << "Computed SNR targets for " << num_success + << " out of " << num_done << " utterances; failed for " + << num_err; + KALDI_LOG << "Got [ " << num_speech_frames << "," + << num_sil_frames << "] frames of speech and silence"; + return (num_success > 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/featbin/extract-column.cc b/src/featbin/extract-column.cc new file mode 100644 index 00000000000..7fa6644af03 --- /dev/null +++ b/src/featbin/extract-column.cc @@ -0,0 +1,84 @@ +// featbin/extract-column.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace std; + + const char *usage = + "Extract a column out of a matrix. \n" + "This is most useful to extract log-energies \n" + "from feature files\n" + "\n" + "Usage: extract-column [options] --column-index= " + " \n" + " e.g. extract-column ark:feats-in.ark ark:energies.ark\n" + "See also: select-feats, subset-feats, subsample-feats, extract-rows\n"; + + ParseOptions po(usage); + + int32 column_index = 0; + + po.Register("column-index", &column_index, + "Index of column to extract"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + string feat_rspecifier = po.GetArg(1); + string vector_wspecifier = po.GetArg(2); + + SequentialBaseFloatMatrixReader reader(feat_rspecifier); + BaseFloatVectorWriter writer(vector_wspecifier); + + int32 num_done = 0, num_err = 0; + + string line; + + for (; !reader.Done(); reader.Next(), num_done++) { + const Matrix& feats(reader.Value()); + Vector col(feats.NumRows()); + if (column_index >= feats.NumCols()) { + KALDI_ERR << "Column index " << column_index << " is " + << "not less than number of columns " << feats.NumCols(); + } + col.CopyColFromMat(feats, column_index); + writer.Write(reader.Key(), col); + } + + KALDI_LOG << "Processed " << num_done << " matrices successfully; " + << "errors on " << num_err; + + return (num_done > 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/featbin/vector-to-feat.cc b/src/featbin/vector-to-feat.cc new file mode 100644 index 00000000000..1fe521db864 --- /dev/null +++ b/src/featbin/vector-to-feat.cc @@ -0,0 +1,100 @@ +// featbin/vector-to-feat.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Convert a vector into a single feature so that it can be appended \n" + "to other feature matrices\n" + "Usage: vector-to-feats \n" + "or: vector-to-feats \n" + "e.g.: vector-to-feats scp:weights.scp ark:weight_feats.ark\n" + " or: vector-to-feats weight_vec feat_mat\n" + "See also: copy-feats, copy-matrix, paste-feats, \n" + "subsample-feats, splice-feats\n"; + + ParseOptions po(usage); + bool compress = false, binary = true; + + po.Register("binary", &binary, "Binary-mode output (not relevant if writing " + "to archive)"); + po.Register("compress", &compress, "If true, write output in compressed form" + "(only currently supported for wxfilename, i.e. 
archive/script," + "output)"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + int32 num_done = 0; + + if (ClassifyRspecifier(po.GetArg(1), NULL, NULL) != kNoRspecifier) { + std::string vector_rspecifier = po.GetArg(1); + std::string feature_wspecifier = po.GetArg(2); + + SequentialBaseFloatVectorReader vector_reader(vector_rspecifier); + BaseFloatMatrixWriter feat_writer(feature_wspecifier); + CompressedMatrixWriter compressed_feat_writer(feature_wspecifier); + + for (; !vector_reader.Done(); vector_reader.Next(), ++num_done) { + const Vector &vec = vector_reader.Value(); + Matrix feat(vec.Dim(), 1); + feat.CopyColFromVec(vec, 0); + + if (!compress) + feat_writer.Write(vector_reader.Key(), feat); + else + compressed_feat_writer.Write(vector_reader.Key(), + CompressedMatrix(feat)); + } + KALDI_LOG << "Converted " << num_done << " vectors into features"; + return (num_done != 0 ? 0 : 1); + } + + KALDI_ASSERT(!compress && "Compression not yet supported for single files"); + + std::string vector_rxfilename = po.GetArg(1), + feature_wxfilename = po.GetArg(2); + + Vector vec; + ReadKaldiObject(vector_rxfilename, &vec); + + Matrix feat(vec.Dim(), 1); + feat.CopyColFromVec(vec, 0); + + WriteKaldiObject(feat, feature_wxfilename, binary); + + KALDI_LOG << "Converted vector " << PrintableRxfilename(vector_rxfilename) + << " to " << PrintableWxfilename(feature_wxfilename); + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/featbin/wav-reverberate.cc b/src/featbin/wav-reverberate.cc index a9e6d3509c1..3b92f6e0b3e 100644 --- a/src/featbin/wav-reverberate.cc +++ b/src/featbin/wav-reverberate.cc @@ -156,6 +156,8 @@ int main(int argc, char *argv[]) { bool normalize_output = true; BaseFloat volume = 0; BaseFloat duration = 0; + std::string reverb_wxfilename; + std::string additive_noise_wxfilename; po.Register("multi-channel-output", &multi_channel_output, "Specifies if the 
output should be multi-channel or not"); @@ -212,6 +214,14 @@ int main(int argc, char *argv[]) { "after reverberating and possibly adding noise. " "If you set this option to a nonzero value, it will be as " "if you had also specified --normalize-output=false."); + po.Register("reverb-out-wxfilename", &reverb_wxfilename, + "Output the reverberated wave file, i.e. before adding the " + "additive noise. " + "Useful for computing SNR features or for debugging"); + po.Register("additive-noise-out-wxfilename", + &additive_noise_wxfilename, + "Output the additive noise file used to corrupt the input wave." + "Useful for computing SNR features or for debugging"); po.Read(argc, argv); if (po.NumArgs() != 2) { @@ -314,10 +324,23 @@ int main(int argc, char *argv[]) { int32 num_samp_output = (duration > 0 ? samp_freq_input * duration : (shift_output ? num_samp_input : num_samp_input + num_samp_rir - 1)); + Matrix out_matrix(num_output_channels, num_samp_output); + Matrix out_reverb_matrix; + if (!reverb_wxfilename.empty()) + out_reverb_matrix.Resize(num_output_channels, num_samp_output); + + Matrix out_noise_matrix; + if (!additive_noise_wxfilename.empty()) + out_noise_matrix.Resize(num_output_channels, num_samp_output); + for (int32 output_channel = 0; output_channel < num_output_channels; output_channel++) { Vector input(num_samp_input); + + Vector out_reverb(0); + Vector out_noise(0); + input.CopyRowFromMat(input_matrix, input_channel); float power_before_reverb = VecVec(input, input) / input.Dim(); @@ -337,6 +360,16 @@ int main(int argc, char *argv[]) { } } + if (!reverb_wxfilename.empty()) { + out_reverb.Resize(input.Dim()); + out_reverb.CopyFromVec(input); + } + + if (!additive_noise_wxfilename.empty()) { + out_noise.Resize(input.Dim()); + out_noise.SetZero(); + } + if (additive_signal_matrices.size() > 0) { Vector noise(0); int32 this_noise_channel = (multi_channel_output ? 
output_channel : noise_channel); @@ -345,33 +378,86 @@ int main(int argc, char *argv[]) { for (int32 i = 0; i < additive_signal_matrices.size(); i++) { noise.Resize(additive_signal_matrices[i].NumCols()); noise.CopyRowFromMat(additive_signal_matrices[i], this_noise_channel); - AddNoise(&noise, snr_vector[i], start_time_vector[i], - samp_freq_input, early_energy, &input); + + if (!additive_noise_wxfilename.empty()) { + AddNoise(&noise, snr_vector[i], start_time_vector[i], + samp_freq_input, early_energy, &out_noise); + } else { + AddNoise(&noise, snr_vector[i], start_time_vector[i], + samp_freq_input, early_energy, &input); + } + } + + if (!additive_noise_wxfilename.empty()) { + input.AddVec(1.0, out_noise); } } float power_after_reverb = VecVec(input, input) / input.Dim(); - if (volume > 0) + if (volume > 0) { input.Scale(volume); - else if (normalize_output) + out_reverb.Scale(volume); + out_noise.Scale(volume); + } else if (normalize_output) { input.Scale(sqrt(power_before_reverb / power_after_reverb)); + out_reverb.Scale(sqrt(power_before_reverb / power_after_reverb)); + out_noise.Scale(sqrt(power_before_reverb / power_after_reverb)); + } if (num_samp_output <= num_samp_input) { // trim the signal from the start out_matrix.CopyRowFromVec(input.Range(shift_index, num_samp_output), output_channel); + + if (!reverb_wxfilename.empty()) { + out_reverb_matrix.CopyRowFromVec(out_reverb.Range(shift_index, num_samp_output), output_channel); + } + + if (!additive_noise_wxfilename.empty()) { + out_noise_matrix.CopyRowFromVec(out_noise.Range(shift_index, num_samp_output), output_channel); + } } else { - // repeat the signal to fill up the duration - Vector extended_input(num_samp_output); - extended_input.SetZero(); - AddVectorsOfUnequalLength(input.Range(shift_index, num_samp_input), &extended_input); - out_matrix.CopyRowFromVec(extended_input, output_channel); + { + // repeat the signal to fill up the duration + Vector extended_input(num_samp_output); + 
extended_input.SetZero(); + AddVectorsOfUnequalLength(input.Range(shift_index, num_samp_input), &extended_input); + out_matrix.CopyRowFromVec(extended_input, output_channel); + } + if (!reverb_wxfilename.empty()) { + // repeat the signal to fill up the duration + Vector extended_input(num_samp_output); + extended_input.SetZero(); + AddVectorsOfUnequalLength(out_reverb.Range(shift_index, num_samp_input), &extended_input); + out_reverb_matrix.CopyRowFromVec(extended_input, output_channel); + } + if (!additive_noise_wxfilename.empty()) { + // repeat the signal to fill up the duration + Vector extended_input(num_samp_output); + extended_input.SetZero(); + AddVectorsOfUnequalLength(out_noise.Range(shift_index, num_samp_input), &extended_input); + out_noise_matrix.CopyRowFromVec(extended_input, output_channel); + } } } + + { + WaveData out_wave(samp_freq_input, out_matrix); + Output ko(output_wave_file, false); + out_wave.Write(ko.Stream()); + } + + if (!reverb_wxfilename.empty()) { + WaveData out_wave(samp_freq_input, out_reverb_matrix); + Output ko(reverb_wxfilename, false); + out_wave.Write(ko.Stream()); + } - WaveData out_wave(samp_freq_input, out_matrix); - Output ko(output_wave_file, false); - out_wave.Write(ko.Stream()); + if (!additive_noise_wxfilename.empty()) { + WaveData out_wave(samp_freq_input, out_noise_matrix); + Output ko(additive_noise_wxfilename, false); + out_wave.Write(ko.Stream()); + } return 0; } catch(const std::exception &e) { diff --git a/src/gmm/diag-gmm.h b/src/gmm/diag-gmm.h index 1243d7a6bfd..32ef4f146d7 100644 --- a/src/gmm/diag-gmm.h +++ b/src/gmm/diag-gmm.h @@ -32,6 +32,8 @@ #include "matrix/matrix-lib.h" #include "tree/cluster-utils.h" #include "tree/clusterable-classes.h" +#include "util/kaldi-table.h" +#include "util/kaldi-holder.h" namespace kaldi { @@ -255,6 +257,14 @@ operator << (std::ostream &os, const kaldi::DiagGmm &gmm); std::istream & operator >> (std::istream &is, kaldi::DiagGmm &gmm); +typedef KaldiObjectHolder DiagGmmHolder; 
+ +typedef TableWriter DiagGmmWriter; +typedef SequentialTableReader SequentialDiagGmmReader; +typedef RandomAccessTableReader RandomAccessDiagGmmReader; +typedef RandomAccessTableReaderMapped +RandomAccessDiagGmmReaderMapped; + } // End namespace kaldi #include "gmm/diag-gmm-inl.h" // templated functions. diff --git a/src/gmmbin/Makefile b/src/gmmbin/Makefile index 7adb8bdc41e..caf4b1f8118 100644 --- a/src/gmmbin/Makefile +++ b/src/gmmbin/Makefile @@ -28,7 +28,8 @@ BINFILES = gmm-init-mono gmm-est gmm-acc-stats-ali gmm-align \ gmm-est-fmllr-raw gmm-est-fmllr-raw-gpost gmm-global-init-from-feats \ gmm-global-info gmm-latgen-faster-regtree-fmllr gmm-est-fmllr-global \ gmm-acc-mllt-global gmm-transform-means-global gmm-global-get-post \ - gmm-global-gselect-to-post gmm-global-est-lvtln-trans + gmm-global-gselect-to-post gmm-global-est-lvtln-trans \ + gmm-global-post-to-feats OBJFILES = diff --git a/src/gmmbin/gmm-global-copy.cc b/src/gmmbin/gmm-global-copy.cc index af31b03aa9a..b850cdced51 100644 --- a/src/gmmbin/gmm-global-copy.cc +++ b/src/gmmbin/gmm-global-copy.cc @@ -29,11 +29,13 @@ int main(int argc, char *argv[]) { const char *usage = "Copy a diagonal-covariance GMM\n" "Usage: gmm-global-copy [options] \n" + " or gmm-global-copy [options] \n" "e.g.: gmm-global-copy --binary=false 1.model - | less"; bool binary_write = true; ParseOptions po(usage); - po.Register("binary", &binary_write, "Write output in binary mode"); + po.Register("binary", &binary_write, + "Write in binary mode (only relevant if output is a wxfilename)"); po.Read(argc, argv); @@ -45,15 +47,39 @@ int main(int argc, char *argv[]) { std::string model_in_filename = po.GetArg(1), model_out_filename = po.GetArg(2); - DiagGmm gmm; - { - bool binary_read; - Input ki(model_in_filename, &binary_read); - gmm.Read(ki.Stream(), binary_read); - } - WriteKaldiObject(gmm, model_out_filename, binary_write); + // all these "fn"'s are either rspecifiers or filenames. 
+ + bool in_is_rspecifier = + (ClassifyRspecifier(model_in_filename, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(model_out_filename, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix archives with regular files (copying gmm models)"; - KALDI_LOG << "Written model to " << model_out_filename; + if (!in_is_rspecifier) { + DiagGmm gmm; + { + bool binary_read; + Input ki(model_in_filename, &binary_read); + gmm.Read(ki.Stream(), binary_read); + } + WriteKaldiObject(gmm, model_out_filename, binary_write); + + KALDI_LOG << "Written model to " << model_out_filename; + } else { + SequentialDiagGmmReader gmm_reader(model_in_filename); + DiagGmmWriter gmm_writer(model_out_filename); + + int32 num_done = 0; + for (; !gmm_reader.Done(); gmm_reader.Next(), num_done++) { + gmm_writer.Write(gmm_reader.Key(), gmm_reader.Value()); + } + + KALDI_LOG << "Wrote " << num_done << " GMM models to " << model_out_filename; + } } catch(const std::exception &e) { std::cerr << e.what() << '\n'; return -1; diff --git a/src/gmmbin/gmm-global-get-post.cc b/src/gmmbin/gmm-global-get-post.cc index b364c33cab4..35438a7e849 100644 --- a/src/gmmbin/gmm-global-get-post.cc +++ b/src/gmmbin/gmm-global-get-post.cc @@ -36,35 +36,51 @@ int main(int argc, char *argv[]) { " (e.g. 
in training UBMs, SGMMs, tied-mixture systems)\n" " For each frame, gives a list of the n best Gaussian indices,\n" " sorted from best to worst.\n" - "Usage: gmm-global-get-post [options] \n" - "e.g.: gmm-global-get-post --n=20 1.gmm \"ark:feature-command |\" \"ark,t:|gzip -c >post.1.gz\"\n"; + "Usage: gmm-global-get-post [options] []\n" + "e.g.: gmm-global-get-post --n=20 1.gmm \"ark:feature-command |\" \"ark,t:|gzip -c >post.1.gz\"\n" + " or : gmm-global-get-post --n=20 ark:1.gmm \"ark:feature-command |\" \"ark,t:|gzip -c >post.1.gz\"\n"; ParseOptions po(usage); int32 num_post = 50; BaseFloat min_post = 0.0; + std::string utt2spk_rspecifier; + po.Register("n", &num_post, "Number of Gaussians to keep per frame\n"); po.Register("min-post", &min_post, "Minimum posterior we will output " "before pruning and renormalizing (e.g. 0.01)"); + po.Register("utt2spk", &utt2spk_rspecifier, + "rspecifier for utterance to speaker map for reading " + "per-speaker GMM models"); po.Read(argc, argv); - if (po.NumArgs() != 3) { + if (po.NumArgs() < 3 || po.NumArgs() > 4) { po.PrintUsage(); exit(1); } - std::string model_filename = po.GetArg(1), + std::string model_in_filename = po.GetArg(1), feature_rspecifier = po.GetArg(2), - post_wspecifier = po.GetArg(3); + post_wspecifier = po.GetArg(3), + frame_loglikes_wspecifier = po.GetOptArg(4); - DiagGmm gmm; - ReadKaldiObject(model_filename, &gmm); + RandomAccessDiagGmmReaderMapped *gmm_reader = NULL; + DiagGmm diag_gmm; + KALDI_ASSERT(num_post > 0); KALDI_ASSERT(min_post < 1.0); - int32 num_gauss = gmm.NumGauss(); - if (num_post > num_gauss) { - KALDI_WARN << "You asked for " << num_post << " Gaussians but GMM " - << "only has " << num_gauss << ", returning this many. "; - num_post = num_gauss; + + if (ClassifyRspecifier(model_in_filename, NULL, NULL) + != kNoRspecifier) { // reading models from a Table. 
+ gmm_reader = new RandomAccessDiagGmmReaderMapped(model_in_filename, + utt2spk_rspecifier); + } else { + ReadKaldiObject(model_in_filename, &diag_gmm); + int32 num_gauss = diag_gmm.NumGauss(); + if (num_post > num_gauss) { + KALDI_WARN << "You asked for " << num_post << " Gaussians but GMM " + << "only has " << num_gauss << ", returning this many. "; + num_post = num_gauss; + } } double tot_like = 0.0; @@ -72,10 +88,11 @@ int main(int argc, char *argv[]) { SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); PosteriorWriter post_writer(post_wspecifier); + BaseFloatVectorWriter frame_loglikes_writer(frame_loglikes_wspecifier); int32 num_done = 0, num_err = 0; for (; !feature_reader.Done(); feature_reader.Next()) { - std::string utt = feature_reader.Key(); + const std::string &utt = feature_reader.Key(); const Matrix &feats = feature_reader.Value(); int32 T = feats.NumRows(); if (T == 0) { @@ -83,25 +100,46 @@ int main(int argc, char *argv[]) { num_err++; continue; } - if (feats.NumCols() != gmm.Dim()) { + + const DiagGmm *gmm; + if (gmm_reader) { + if (!gmm_reader->HasKey(utt)) { + KALDI_WARN << "Could not find GMM for utterance " << utt; + num_err++; + continue; + } + gmm = &(gmm_reader->Value(utt)); + } else { + gmm = &diag_gmm; + } + int32 num_gauss_to_compute = + num_post > gmm->NumGauss() ? 
gmm->NumGauss() : num_post; + + if (feats.NumCols() != gmm->Dim()) { KALDI_WARN << "Dimension mismatch for utterance " << utt - << ": got " << feats.NumCols() << ", expected " << gmm.Dim(); + << ": got " << feats.NumCols() << ", expected " + << gmm->Dim(); num_err++; continue; } - vector > gselect(T); - Matrix loglikes; - gmm.LogLikelihoods(feats, &loglikes); + gmm->LogLikelihoods(feats, &loglikes); + + Vector frame_loglikes; + if (!frame_loglikes_wspecifier.empty()) frame_loglikes.Resize(T); Posterior post(T); double log_like_this_file = 0.0; for (int32 t = 0; t < T; t++) { - log_like_this_file += - VectorToPosteriorEntry(loglikes.Row(t), num_post, + double log_like_this_frame = + VectorToPosteriorEntry(loglikes.Row(t), + num_gauss_to_compute, min_post, &(post[t])); + if (!frame_loglikes_wspecifier.empty()) + frame_loglikes(t) = log_like_this_frame; + log_like_this_file += log_like_this_frame; } KALDI_VLOG(1) << "Processed utterance " << utt << ", average likelihood " << (log_like_this_file / T) << " over " << T << " frames"; @@ -109,8 +147,13 @@ int main(int argc, char *argv[]) { tot_t += T; post_writer.Write(utt, post); + if (!frame_loglikes_wspecifier.empty()) + frame_loglikes_writer.Write(utt, frame_loglikes); + num_done++; } + + delete gmm_reader; KALDI_LOG << "Done " << num_done << " files, " << num_err << " with errors, average UBM log-likelihood is " diff --git a/src/gmmbin/gmm-global-post-to-feats.cc b/src/gmmbin/gmm-global-post-to-feats.cc new file mode 100644 index 00000000000..fa903b66014 --- /dev/null +++ b/src/gmmbin/gmm-global-post-to-feats.cc @@ -0,0 +1,103 @@ +// gmmbin/gmm-global-post-to-feats.cc + +// Copyright 2016 Brno University of Technology (Author: Karel Vesely) +// 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" +#include "hmm/posterior.h" +#include "gmm/diag-gmm.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Convert GMM global posteriors to features\n" + "\n" + "Usage: gmm-global-post-to-feats [options] \n" + "e.g.: gmm-global-post-to-feats ark:1.gmm ark:post.ark ark:feat.ark\n" + "See also: post-to-feats --post-dim, post-to-weights feat-to-post, append-vector-to-feats, append-post-to-feats\n"; + + ParseOptions po(usage); + std::string utt2spk_rspecifier; + + po.Register("utt2spk", &utt2spk_rspecifier, + "rspecifier for utterance to speaker map for reading " + "per-speaker GMM models"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), + post_rspecifier = po.GetArg(2), + feat_wspecifier = po.GetArg(3); + + DiagGmm diag_gmm; + RandomAccessDiagGmmReaderMapped *gmm_reader = NULL; + SequentialPosteriorReader post_reader(post_rspecifier); + BaseFloatMatrixWriter feat_writer(feat_wspecifier); + + if (ClassifyRspecifier(po.GetArg(1), NULL, NULL) + != kNoRspecifier) { // We're operating on tables, e.g. archives. 
+ gmm_reader = new RandomAccessDiagGmmReaderMapped(model_in_filename, + utt2spk_rspecifier); + } else { + ReadKaldiObject(model_in_filename, &diag_gmm); + } + + int32 num_done = 0, num_err = 0; + + for (; !post_reader.Done(); post_reader.Next()) { + const std::string &utt = post_reader.Key(); + + const DiagGmm *gmm = &diag_gmm; + if (gmm_reader) { + if (!gmm_reader->HasKey(utt)) { + KALDI_WARN << "Could not find GMM model for utterance " << utt; + num_err++; + continue; + } + gmm = &(gmm_reader->Value(utt)); + } + + int32 post_dim = gmm->NumGauss(); + + const Posterior &post = post_reader.Value(); + + Matrix output; + PosteriorToMatrix(post, post_dim, &output); + + feat_writer.Write(utt, output); + num_done++; + } + KALDI_LOG << "Done " << num_done << " utts, errors on " + << num_err; + + return (num_done == 0 ? -1 : 0); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/hmm/hmm-utils.cc b/src/hmm/hmm-utils.cc index fe6c5b32d6e..d0ec8595227 100644 --- a/src/hmm/hmm-utils.cc +++ b/src/hmm/hmm-utils.cc @@ -231,8 +231,6 @@ GetHmmAsFstSimple(std::vector phone_window, - - // The H transducer has a separate outgoing arc for each of the symbols in ilabel_info. fst::VectorFst *GetHTransducer (const std::vector > &ilabel_info, diff --git a/src/hmm/transition-model.cc b/src/hmm/transition-model.cc index 83edbaf5805..7973be69dcd 100644 --- a/src/hmm/transition-model.cc +++ b/src/hmm/transition-model.cc @@ -240,6 +240,16 @@ TransitionModel::TransitionModel(const ContextDependencyInterface &ctx_dep, Check(); } +void TransitionModel::Init(const ContextDependencyInterface &ctx_dep, + const HmmTopology &hmm_topo) { + topo_ = hmm_topo; + // First thing is to get all possible tuples. 
+ ComputeTuples(ctx_dep); + ComputeDerived(); + InitializeProbs(); + Check(); +} + int32 TransitionModel::TupleToTransitionState(int32 phone, int32 hmm_state, int32 pdf, int32 self_loop_pdf) const { Tuple tuple(phone, hmm_state, pdf, self_loop_pdf); // Note: if this ever gets too expensive, which is unlikely, we can refactor diff --git a/src/hmm/transition-model.h b/src/hmm/transition-model.h index 442de8fd2e0..e6a9221fb31 100644 --- a/src/hmm/transition-model.h +++ b/src/hmm/transition-model.h @@ -128,11 +128,17 @@ class TransitionModel { TransitionModel(const ContextDependencyInterface &ctx_dep, const HmmTopology &hmm_topo); - /// Constructor that takes no arguments: typically used prior to calling Read. TransitionModel() { } + + virtual ~TransitionModel() { } + + /// Does the same things as the constructor. + void Init(const ContextDependencyInterface &ctx_dep, + const HmmTopology &hmm_topo); - void Read(std::istream &is, bool binary); // note, no symbol table: topo object always read/written w/o symbols. + // note, no symbol table: topo object always read/written w/o symbols. + virtual void Read(std::istream &is, bool binary); void Write(std::ostream &os, bool binary) const; @@ -316,7 +322,6 @@ class TransitionModel { /// of pdfs). 
int32 num_pdfs_; - KALDI_DISALLOW_COPY_AND_ASSIGN(TransitionModel); }; diff --git a/src/ivectorbin/Makefile b/src/ivectorbin/Makefile index 71a855762fe..5df22a2bb8a 100644 --- a/src/ivectorbin/Makefile +++ b/src/ivectorbin/Makefile @@ -15,7 +15,9 @@ BINFILES = ivector-extractor-init ivector-extractor-acc-stats \ ivector-subtract-global-mean ivector-plda-scoring \ logistic-regression-train logistic-regression-eval \ logistic-regression-copy create-split-from-vad \ - ivector-extract-online ivector-adapt-plda + ivector-extract-online ivector-adapt-plda \ + ivector-extract-dense ivector-cluster \ + ivector-cluster-plda OBJFILES = diff --git a/src/lat/sausages.cc b/src/lat/sausages.cc index e6fd0b61dd9..89734f76a04 100644 --- a/src/lat/sausages.cc +++ b/src/lat/sausages.cc @@ -51,10 +51,19 @@ void MinimumBayesRisk::MbrDecode() { R_[q] = rhat; } if (R_[q] != 0) { - one_best_times_.push_back(times_[q]); BaseFloat confidence = 0.0; + bool first_time = true; for (int32 j = 0; j < gamma_[q].size(); j++) - if (gamma_[q][j].first == R_[q]) confidence = gamma_[q][j].second; + if (gamma_[q][j].first == R_[q]) { + KALDI_ASSERT(first_time); first_time = false; + confidence = gamma_[q][j].second; + KALDI_ASSERT(confidence > 0); + KALDI_ASSERT(begin_times_[q].count(R_[q]) > 0); + KALDI_ASSERT(end_times_[q].count(R_[q]) > 0); + one_best_times_.push_back(std::make_pair( + begin_times_[q][R_[q]] / confidence, + end_times_[q][R_[q]] / confidence)); + } one_best_confidences_.push_back(confidence); } } @@ -145,11 +154,13 @@ void MinimumBayesRisk::AccStats() { std::vector > gamma(Q+1); // temp. form of gamma. // index 1...Q [word] -> occ. + vector > tau_b(Q+1), tau_e(Q+1); + // The tau arrays below are the sums over words of the tau_b // and tau_e timing quantities mentioned in Appendix C of // the paper... we are using these to get averaged times for // the sausage bins, not specifically for the 1-best output. 
- Vector tau_b(Q+1), tau_e(Q+1); + //Vector tau_b(Q+1), tau_e(Q+1); double Ltmp = EditDistance(N, Q, alpha, alpha_dash, alpha_dash_arc); if (L_ != 0 && Ltmp > L_) { // L_ != 0 is to rule out 1st iter. @@ -189,8 +200,11 @@ void MinimumBayesRisk::AccStats() { // next: gamma(q, w(a)) += beta_dash_arc(q) AddToMap(w_a, beta_dash_arc(q), &(gamma[q])); // next: accumulating times, see decl for tau_b,tau_e - tau_b(q) += state_times_[s_a] * beta_dash_arc(q); - tau_e(q) += state_times_[n] * beta_dash_arc(q); + AddToMap(w_a, state_times_[s_a] * beta_dash_arc(q), &(tau_b[q]), false); + AddToMap(w_a, state_times_[n] * beta_dash_arc(q), &(tau_e[q]), false); + KALDI_ASSERT(tau_b[q].size() == tau_e[q].size()); + //tau_b(q) += state_times_[s_a] * beta_dash_arc(q); + //tau_e(q) += state_times_[n] * beta_dash_arc(q); break; case 2: beta_dash(s_a, q) += beta_dash_arc(q); @@ -203,8 +217,11 @@ void MinimumBayesRisk::AccStats() { // WARNING: there was an error in Appendix C. If we followed // the instructions there the next line would say state_times_[sa], but // it would be wrong. I will try to publish an erratum. - tau_b(q) += state_times_[n] * beta_dash_arc(q); - tau_e(q) += state_times_[n] * beta_dash_arc(q); + AddToMap(0, state_times_[n] * beta_dash_arc(q), &(tau_b[q]), false); + AddToMap(0, state_times_[n] * beta_dash_arc(q), &(tau_e[q]), false); + KALDI_ASSERT(tau_b[q].size() == tau_e[q].size()); + //tau_b(q) += state_times_[n] * beta_dash_arc(q); + //tau_e(q) += state_times_[n] * beta_dash_arc(q); break; default: KALDI_ERR << "Invalid b_arc value"; // error in code. @@ -221,8 +238,11 @@ void MinimumBayesRisk::AccStats() { AddToMap(0, beta_dash_arc(q), &(gamma[q])); // the statements below are actually redundant because // state_times_[1] is zero. 
- tau_b(q) += state_times_[1] * beta_dash_arc(q); - tau_e(q) += state_times_[1] * beta_dash_arc(q); + //tau_b(q) += state_times_[1] * beta_dash_arc(q); + //tau_e(q) += state_times_[1] * beta_dash_arc(q); + AddToMap(0, state_times_[1] * beta_dash_arc(q), &(tau_b[q]), false); + AddToMap(0, state_times_[1] * beta_dash_arc(q), &(tau_e[q]), false); + KALDI_ASSERT(tau_b[q].size() == tau_e[q].size()); } for (int32 q = 1; q <= Q; q++) { // a check (line 35) double sum = 0.0; @@ -249,9 +269,24 @@ void MinimumBayesRisk::AccStats() { // indexing. times_.clear(); times_.resize(Q); + begin_times_.clear(); + begin_times_.resize(Q); + end_times_.clear(); + end_times_.resize(Q); for (int32 q = 1; q <= Q; q++) { - times_[q-1].first = tau_b(q); - times_[q-1].second = tau_e(q); + KALDI_ASSERT(tau_b[q].size() == tau_e[q].size()); + for (map::iterator iter = tau_b[q].begin(); + iter != tau_b[q].end(); ++iter) { + times_[q-1].first += iter->second; + begin_times_[q-1].insert(std::make_pair(iter->first, iter->second)); + } + + for (map::iterator iter = tau_e[q].begin(); + iter != tau_e[q].end(); ++iter) { + times_[q-1].second += iter->second; + end_times_[q-1].insert(std::make_pair(iter->first, iter->second)); + } + if (times_[q-1].first > times_[q-1].second) // this is quite bad. KALDI_WARN << "Times out of order"; if (q > 1 && times_[q-2].second > times_[q-1].first) { diff --git a/src/lat/sausages.h b/src/lat/sausages.h index 8ada15e64b5..4f709bf1703 100644 --- a/src/lat/sausages.h +++ b/src/lat/sausages.h @@ -133,8 +133,8 @@ class MinimumBayesRisk { // used in the algorithm. /// Function used to increment map. - static inline void AddToMap(int32 i, double d, std::map *gamma) { - if (d == 0) return; + static inline void AddToMap(int32 i, double d, std::map *gamma, bool return_if_zero = true) { + if (return_if_zero && d == 0) return; std::pair pr(i, d); std::pair::iterator, bool> ret = gamma->insert(pr); if (!ret.second) // not inserted, so add to contents. 
@@ -178,6 +178,12 @@ class MinimumBayesRisk { // paper. We sort in reverse order on the second member (posterior), so more // likely word is first. + std::vector > begin_times_; + std::vector > end_times_; + // The average start and end times for each word in a confusion-network bin. + // These are the tau_b and tau_e quantities in Appendix C of the paper. + // Indexed from zero, like gamma_ and R_. + std::vector > times_; // The average start and end times for each confusion-network bin. This // is like an average over words, of the tau_b and tau_e quantities in diff --git a/src/matrix/compressed-matrix.cc b/src/matrix/compressed-matrix.cc index 2ac2c544bc8..6fc365c8f03 100644 --- a/src/matrix/compressed-matrix.cc +++ b/src/matrix/compressed-matrix.cc @@ -24,14 +24,14 @@ namespace kaldi { -//static +//static MatrixIndexT CompressedMatrix::DataSize(const GlobalHeader &header) { // Returns size in bytes of the data. if (header.format == 1) { return sizeof(GlobalHeader) + header.num_cols * (sizeof(PerColHeader) + header.num_rows); } else { - KALDI_ASSERT(header.format == 2) ; + KALDI_ASSERT(header.format == 2); return sizeof(GlobalHeader) + 2 * header.num_rows * header.num_cols; } @@ -40,7 +40,7 @@ MatrixIndexT CompressedMatrix::DataSize(const GlobalHeader &header) { template void CompressedMatrix::CopyFromMat( - const MatrixBase &mat) { + const MatrixBase &mat, int32 format) { if (data_ != NULL) { delete [] static_cast(data_); // call delete [] because was allocated with new float[] data_ = NULL; @@ -52,7 +52,7 @@ void CompressedMatrix::CopyFromMat( KALDI_COMPILE_TIME_ASSERT(sizeof(global_header) == 20); // otherwise // something weird is happening and our code probably won't work or // won't be robust across platforms. 
- + // Below, the point of the "safety_margin" is that the minimum // and maximum values in the matrix shouldn't coincide with // the minimum and maximum ranges of the 16-bit range, because @@ -80,16 +80,22 @@ void CompressedMatrix::CopyFromMat( global_header.num_rows = mat.NumRows(); global_header.num_cols = mat.NumCols(); - if (mat.NumRows() > 8) { - global_header.format = 1; // format where each row has a PerColHeader. + if (format <= 0) { + if (mat.NumRows() > 8) { + global_header.format = 1; // format where each row has a PerColHeader. + } else { + global_header.format = 2; // format where all data is uint16. + } + } else if (format == 1 || format == 2) { + global_header.format = format; } else { - global_header.format = 2; // format where all data is uint16. + KALDI_ERR << "Error format for compression:format should be <=2."; } - + int32 data_size = DataSize(global_header); data_ = AllocateData(data_size); - + *(reinterpret_cast(data_)) = global_header; if (global_header.format == 1) { @@ -124,10 +130,12 @@ void CompressedMatrix::CopyFromMat( // Instantiate the template for float and double. template -void CompressedMatrix::CopyFromMat(const MatrixBase &mat); +void CompressedMatrix::CopyFromMat(const MatrixBase &mat, + int32 format); template -void CompressedMatrix::CopyFromMat(const MatrixBase &mat); +void CompressedMatrix::CopyFromMat(const MatrixBase &mat, + int32 format); CompressedMatrix::CompressedMatrix( @@ -146,10 +154,10 @@ CompressedMatrix::CompressedMatrix( if (old_num_rows == 0) { return; } // Zero-size matrix stored as zero pointer. 
if (num_rows == 0 || num_cols == 0) { return; } - + GlobalHeader new_global_header; KALDI_COMPILE_TIME_ASSERT(sizeof(new_global_header) == 20); - + GlobalHeader *old_global_header = reinterpret_cast(cmat.Data()); new_global_header = *old_global_header; @@ -159,10 +167,10 @@ CompressedMatrix::CompressedMatrix( // We don't switch format from 1 -> 2 (in case of size reduction) yet; if this // is needed, we will do this below by creating a temporary Matrix. new_global_header.format = old_global_header->format; - + data_ = AllocateData(DataSize(new_global_header)); // allocate memory *(reinterpret_cast(data_)) = new_global_header; - + if (old_global_header->format == 1) { // Both have the format where we have a PerColHeader and then compress as // chars... @@ -196,7 +204,7 @@ CompressedMatrix::CompressedMatrix( reinterpret_cast(old_global_header + 1); uint16 *new_data = reinterpret_cast(reinterpret_cast(data_) + 1); - + old_data += col_offset + (old_num_cols * row_offset); for (int32 row = 0; row < num_rows; row++) { @@ -281,7 +289,7 @@ void CompressedMatrix::ComputeColHeader( // Now, sdata.begin(), sdata.begin() + quarter_nr, and sdata.begin() + // 3*quarter_nr, and sdata.end() - 1, contain the elements that would appear // at those positions in sorted order. - + header->percentile_0 = std::min(FloatToUint16(global_header, sdata[0]), 65532); header->percentile_25 = @@ -297,7 +305,7 @@ void CompressedMatrix::ComputeColHeader( header->percentile_100 = std::max( FloatToUint16(global_header, sdata[num_rows-1]), header->percentile_75 + static_cast(1)); - + } else { // handle this pathological case. std::sort(sdata.begin(), sdata.end()); // Note: we know num_rows is at least 1. 
@@ -382,7 +390,7 @@ void CompressedMatrix::CompressColumn( unsigned char *byte_data) { ComputeColHeader(global_header, data, stride, num_rows, header); - + float p0 = Uint16ToFloat(global_header, header->percentile_0), p25 = Uint16ToFloat(global_header, header->percentile_25), p75 = Uint16ToFloat(global_header, header->percentile_75), @@ -491,7 +499,7 @@ void CompressedMatrix::CopyToMat(MatrixBase *mat, mat->CopyFromMat(temp, kTrans); return; } - + if (data_ == NULL) { KALDI_ASSERT(mat->NumRows() == 0); KALDI_ASSERT(mat->NumCols() == 0); @@ -501,7 +509,7 @@ void CompressedMatrix::CopyToMat(MatrixBase *mat, int32 num_cols = h->num_cols, num_rows = h->num_rows; KALDI_ASSERT(mat->NumRows() == num_rows); KALDI_ASSERT(mat->NumCols() == num_cols); - + if (h->format == 1) { PerColHeader *per_col_header = reinterpret_cast(h+1); unsigned char *byte_data = reinterpret_cast(per_col_header + @@ -625,7 +633,7 @@ void CompressedMatrix::CopyToMat(int32 row_offset, GlobalHeader *h = reinterpret_cast(data_); int32 num_rows = h->num_rows, num_cols = h->num_cols, tgt_cols = dest->NumCols(), tgt_rows = dest->NumRows(); - + if (h->format == 1) { // format where we have a per-column header and use one byte per // element. diff --git a/src/matrix/compressed-matrix.h b/src/matrix/compressed-matrix.h index 4e4238c43da..585fddbce21 100644 --- a/src/matrix/compressed-matrix.h +++ b/src/matrix/compressed-matrix.h @@ -35,12 +35,12 @@ namespace kaldi { /// column). /// The basic idea is for each column (in the normal configuration) -/// we work out the values at the 0th, 25th, 50th and 100th percentiles +/// we work out the values at the 0th, 25th, 75th and 100th percentiles /// and store them as 16-bit integers; we then encode each value in /// the column as a single byte, in 3 separate ranges with different -/// linear encodings (0-25th, 25-50th, 50th-100th). -/// If the matrix has 8 rows or fewer, we simply store all values as -/// uint16. +/// linear encodings (0-25th, 25-75th, 75th-100th). 
+/// If the matrix has 8 rows or fewer or format=2, we simply store all values +/// as uint16. class CompressedMatrix { public: @@ -49,7 +49,9 @@ class CompressedMatrix { ~CompressedMatrix() { Clear(); } template - CompressedMatrix(const MatrixBase &mat): data_(NULL) { CopyFromMat(mat); } + CompressedMatrix(const MatrixBase &mat, int32 format = 0): data_(NULL) { + CopyFromMat(mat, format); + } /// Initializer that can be used to select part of an existing /// CompressedMatrix without un-compressing and re-compressing (note: unlike @@ -65,7 +67,7 @@ class CompressedMatrix { /// This will resize *this and copy the contents of mat to *this. template - void CopyFromMat(const MatrixBase &mat); + void CopyFromMat(const MatrixBase &mat, int32 format = 0); CompressedMatrix(const CompressedMatrix &mat); diff --git a/src/matrix/kaldi-matrix.cc b/src/matrix/kaldi-matrix.cc index 50c23a7be63..818825e1df7 100644 --- a/src/matrix/kaldi-matrix.cc +++ b/src/matrix/kaldi-matrix.cc @@ -396,6 +396,87 @@ void MatrixBase::AddMat(const Real alpha, const MatrixBase& A, } } +template +void MatrixBase::LogAddExpMat(const Real alpha, const MatrixBase& A, + MatrixTransposeType transA) { + if (alpha == 0) return; + + if (&A == this) { + if (transA == kNoTrans) { + Add(alpha + 1.0); + } else { + KALDI_ASSERT(num_rows_ == num_cols_ && "AddMat: adding to self (transposed): not symmetric."); + Real *data = data_; + if (alpha == 1.0) { // common case-- handle separately. + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < row; col++) { + Real *lower = data + (row * stride_) + col, + *upper = data + (col * stride_) + row; + Real sum = LogAdd(*lower, *upper); + *lower = *upper = sum; + } + *(data + (row * stride_) + row) += Log(2.0); // diagonal. 
+ } + } else { + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < row; col++) { + Real *lower = data + (row * stride_) + col, + *upper = data + (col * stride_) + row; + Real lower_tmp = *lower; + if (alpha > 0) { + *lower = LogAdd(*lower, Log(alpha) + *upper); + *upper = LogAdd(*upper, Log(alpha) + lower_tmp); + } else { + KALDI_ASSERT(alpha < 0); + *lower = LogSub(*lower, Log(-alpha) + *upper); + *upper = LogSub(*upper, Log(-alpha) + lower_tmp); + } + } + if (alpha > -1.0) + *(data + (row * stride_) + row) += Log(1.0 + alpha); // diagonal. + else + KALDI_ERR << "Cannot subtract log-matrices if the difference is " + << "negative"; + } + } + } + } else { + int aStride = (int) A.stride_; + Real *adata = A.data_, *data = data_; + if (transA == kNoTrans) { + KALDI_ASSERT(A.num_rows_ == num_rows_ && A.num_cols_ == num_cols_); + if (num_rows_ == 0) return; + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < num_cols_; col++) { + Real *value = data + (row * stride_) + col, + *aValue = adata + (row * aStride) + col; + if (alpha > 0) + *value = LogAdd(*value, Log(alpha) + *aValue); + else { + KALDI_ASSERT(alpha < 0); + *value = LogSub(*value, Log(-alpha) + *aValue); + } + } + } + } else { + KALDI_ASSERT(A.num_cols_ == num_rows_ && A.num_rows_ == num_cols_); + if (num_rows_ == 0) return; + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < num_cols_; col++) { + Real *value = data + (row * stride_) + col, + *aValue = adata + (col * aStride) + row; + if (alpha > 0) + *value = LogAdd(*value, Log(alpha) + *aValue); + else { + KALDI_ASSERT(alpha < 0); + *value = LogSub(*value, Log(-alpha) + *aValue); + } + } + } + } + } +} + template template void MatrixBase::AddSp(const Real alpha, const SpMatrix &S) { @@ -2546,6 +2627,15 @@ Real MatrixBase::ApplySoftMax() { return max + Log(sum); } +template +void MatrixBase::ApplySoftMaxPerRow() { + for (MatrixIndexT i = 0; i < 
num_rows_; i++) { + Row(i).ApplySoftMax(); + kaldi::ApproxEqual(Row(i).Sum(), 1.0); + } + KALDI_ASSERT(Max() <= 1.0 && Min() >= 0.0); +} + template void MatrixBase::Tanh(const MatrixBase &src) { KALDI_ASSERT(SameDim(*this, src)); diff --git a/src/matrix/kaldi-matrix.h b/src/matrix/kaldi-matrix.h index 25b999fe062..1af5d6c7d9b 100644 --- a/src/matrix/kaldi-matrix.h +++ b/src/matrix/kaldi-matrix.h @@ -455,6 +455,11 @@ class MatrixBase { /// Apply soft-max to the collection of all elements of the /// matrix and return normalizer (log sum of exponentials). Real ApplySoftMax(); + + /// Softmax nonlinearity + /// Y = Softmax(X) : Yij = e^Xij / sum_k(e^Xik), done to each row + /// for each row, the max value is first subtracted for good numerical stability + void ApplySoftMaxPerRow(); /// Set each element to the sigmoid of the corresponding element of "src". void Sigmoid(const MatrixBase &src); @@ -545,6 +550,10 @@ class MatrixBase { /// *this += alpha * M [or M^T] void AddMat(const Real alpha, const MatrixBase &M, MatrixTransposeType transA = kNoTrans); + + /// *this += alpha * M [or M^T] when the matrices are stored as log + void LogAddExpMat(const Real alpha, const MatrixBase &M, + MatrixTransposeType transA = kNoTrans); /// *this = beta * *this + alpha * M M^T, for symmetric matrices. It only /// updates the lower triangle of *this. 
It will leave the matrix asymmetric; diff --git a/src/matrix/sparse-matrix.cc b/src/matrix/sparse-matrix.cc index 477d36f190a..c5bc868f48e 100644 --- a/src/matrix/sparse-matrix.cc +++ b/src/matrix/sparse-matrix.cc @@ -281,6 +281,14 @@ void SparseVector::Resize(MatrixIndexT dim, dim_ = dim; } +template +void SparseVector::Scale(BaseFloat scale) { + typename std::vector >::iterator it = pairs_.begin(); + for (; it != pairs_.end(); ++it) { + it->second *= scale; + } +} + template MatrixIndexT SparseMatrix::NumRows() const { return rows_.size(); @@ -574,6 +582,14 @@ void SparseMatrix::Resize(MatrixIndexT num_rows, rows_[row].Resize(num_cols, kCopyData); } } + +template +void SparseMatrix::Scale(BaseFloat scale) { + for (typename std::vector >::iterator it = rows_.begin(); + it != rows_.end(); ++it) { + it->Scale(scale); + } +} template void SparseMatrix::AppendSparseMatrixRows( @@ -705,15 +721,16 @@ MatrixIndexT GeneralMatrix::NumCols() const { } -void GeneralMatrix::Compress() { +void GeneralMatrix::Compress(int32 format) { if (mat_.NumRows() != 0) { - cmat_.CopyFromMat(mat_); + cmat_.CopyFromMat(mat_, format); mat_.Resize(0, 0); } } void GeneralMatrix::Uncompress() { if (cmat_.NumRows() != 0) { + mat_.Resize(cmat_.NumRows(), cmat_.NumCols(), kUndefined); cmat_.CopyToMat(&mat_); cmat_.Clear(); } @@ -1052,6 +1069,18 @@ void GeneralMatrix::AddToMat(BaseFloat alpha, MatrixBase *mat, } } + +void GeneralMatrix::Scale(BaseFloat scale) { + if(Type() == kCompressedMatrix) + Uncompress(); + if (Type() == kFullMatrix) { + mat_.Scale(scale); + } else if (Type() == kSparseMatrix) { + smat_.Scale(scale); + } +} + + template Real SparseVector::Max(int32 *index_out) const { KALDI_ASSERT(dim_ > 0 && pairs_.size() <= static_cast(dim_)); diff --git a/src/matrix/sparse-matrix.h b/src/matrix/sparse-matrix.h index 9f9362542e1..8ad62e0ac51 100644 --- a/src/matrix/sparse-matrix.h +++ b/src/matrix/sparse-matrix.h @@ -98,6 +98,8 @@ class SparseVector { /// Resizes to this dimension. 
resize_type == kUndefined /// behaves the same as kSetZero. void Resize(MatrixIndexT dim, MatrixResizeType resize_type = kSetZero); + + void Scale(BaseFloat scale); void Write(std::ostream &os, bool binary) const; @@ -196,6 +198,8 @@ class SparseMatrix { void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type = kSetZero); + void Scale(BaseFloat scale); + // Use the Matrix::CopyFromSmat() function to copy from this to Matrix. Also // see Matrix::AddSmat(). There is not very extensive functionality for // SparseMat just yet (e.g. no matrix multiply); we will add things as needed @@ -228,8 +232,10 @@ class GeneralMatrix { public: GeneralMatrixType Type() const; - void Compress(); // If it was a full matrix, compresses, changing Type() to - // kCompressedMatrix; otherwise does nothing. + /// If it was a full matrix, compresses, changing Type() to + /// kCompressedMatrix; otherwise does nothing. + /// format shows the compression format. + void Compress(int32 format = 0); void Uncompress(); // If it was a compressed matrix, uncompresses, changing // Type() to kFullMatrix; otherwise does nothing. @@ -284,6 +290,8 @@ class GeneralMatrix { void AddToMat(BaseFloat alpha, CuMatrixBase *cu_mat, MatrixTransposeType trans = kNoTrans) const; + void Scale(BaseFloat alpha); + /// Assignment from regular matrix. 
GeneralMatrix &operator= (const MatrixBase &mat); diff --git a/src/nnet3/nnet-chain-training.cc b/src/nnet3/nnet-chain-training.cc index 5fe28e8142b..8468c7703aa 100644 --- a/src/nnet3/nnet-chain-training.cc +++ b/src/nnet3/nnet-chain-training.cc @@ -136,14 +136,14 @@ void NnetChainTrainer::ProcessOutputs(const NnetChainExample &eg, computer->AcceptInput(sup.name, &nnet_output_deriv); objf_info_[sup.name].UpdateStats(sup.name, opts_.nnet_config.print_interval, - num_minibatches_processed_++, tot_weight, tot_objf, tot_l2_term); - + if (use_xent) { xent_deriv.Scale(opts_.chain_config.xent_regularize); computer->AcceptInput(xent_name, &xent_deriv); } } + num_minibatches_processed_++; } void NnetChainTrainer::UpdateParamsWithMaxChange() { diff --git a/src/nnet3/nnet-combine.cc b/src/nnet3/nnet-combine.cc index ba904b1c93a..186709fade4 100644 --- a/src/nnet3/nnet-combine.cc +++ b/src/nnet3/nnet-combine.cc @@ -504,15 +504,28 @@ double NnetCombiner::ComputeObjfAndDerivFromNnet( end = egs_.end(); for (; iter != end; ++iter) prob_computer_->Compute(*iter); - const SimpleObjectiveInfo *objf_info = prob_computer_->GetObjective("output"); - if (objf_info == NULL) - KALDI_ERR << "Error getting objective info (unsuitable egs?)"; - KALDI_ASSERT(objf_info->tot_weight > 0.0); + + double tot_weight = 0.0; + double tot_objf = 0.0; + + { + const unordered_map &objf_info = prob_computer_->GetAllObjectiveInfo(); + unordered_map::const_iterator objf_it = objf_info.begin(), + objf_end = objf_info.end(); + + for (; objf_it != objf_end; ++objf_it) { + tot_objf += objf_it->second.tot_objective; + tot_weight += objf_it->second.tot_weight; + } + } + + KALDI_ASSERT(tot_weight > 0.0); + const Nnet &deriv = prob_computer_->GetDeriv(); VectorizeNnet(deriv, nnet_params_deriv); // we prefer to deal with normalized objective functions. 
- nnet_params_deriv->Scale(1.0 / objf_info->tot_weight); - return objf_info->tot_objective / objf_info->tot_weight; + nnet_params_deriv->Scale(1.0 / tot_weight); + return tot_objf / tot_weight; } diff --git a/src/nnet3/nnet-component-itf.cc b/src/nnet3/nnet-component-itf.cc index 4a2a8d1c09a..63e01be8792 100644 --- a/src/nnet3/nnet-component-itf.cc +++ b/src/nnet3/nnet-component-itf.cc @@ -89,6 +89,10 @@ Component* Component::NewComponentOfType(const std::string &component_type) { ans = new SoftmaxComponent(); } else if (component_type == "LogSoftmaxComponent") { ans = new LogSoftmaxComponent(); + } else if (component_type == "LogComponent") { + ans = new LogComponent(); + } else if (component_type == "ExpComponent") { + ans = new ExpComponent(); } else if (component_type == "RectifiedLinearComponent") { ans = new RectifiedLinearComponent(); } else if (component_type == "NormalizeComponent") { @@ -119,6 +123,8 @@ Component* Component::NewComponentOfType(const std::string &component_type) { ans = new NoOpComponent(); } else if (component_type == "ClipGradientComponent") { ans = new ClipGradientComponent(); + } else if (component_type == "ScaleGradientComponent") { + ans = new ScaleGradientComponent(); } else if (component_type == "ElementwiseProductComponent") { ans = new ElementwiseProductComponent(); } else if (component_type == "ConvolutionComponent") { @@ -314,11 +320,14 @@ std::string NonlinearComponent::Info() const { std::stringstream stream; if (InputDim() == OutputDim()) { stream << Type() << ", dim=" << InputDim(); - } else { + } else if (OutputDim() - InputDim() == 1) { // Note: this is a very special case tailored for class NormalizeComponent. 
stream << Type() << ", input-dim=" << InputDim() << ", output-dim=" << OutputDim() << ", add-log-stddev=true"; + } else { + stream << Type() << ", input-dim=" << InputDim() + << ", output-dim=" << OutputDim(); } if (self_repair_lower_threshold_ != BaseFloat(kUnsetThreshold)) @@ -327,7 +336,7 @@ std::string NonlinearComponent::Info() const { stream << ", self-repair-upper-threshold=" << self_repair_upper_threshold_; if (self_repair_scale_ != 0.0) stream << ", self-repair-scale=" << self_repair_scale_; - if (count_ > 0 && value_sum_.Dim() == dim_ && deriv_sum_.Dim() == dim_) { + if (count_ > 0 && value_sum_.Dim() == dim_) { stream << ", count=" << std::setprecision(3) << count_ << std::setprecision(6); stream << ", self-repaired-proportion=" @@ -337,10 +346,12 @@ std::string NonlinearComponent::Info() const { Vector value_avg(value_avg_dbl); value_avg.Scale(1.0 / count_); stream << ", value-avg=" << SummarizeVector(value_avg); - Vector deriv_avg_dbl(deriv_sum_); - Vector deriv_avg(deriv_avg_dbl); - deriv_avg.Scale(1.0 / count_); - stream << ", deriv-avg=" << SummarizeVector(deriv_avg); + if (deriv_sum_.Dim() == dim_) { + Vector deriv_avg_dbl(deriv_sum_); + Vector deriv_avg(deriv_avg_dbl); + deriv_avg.Scale(1.0 / count_); + stream << ", deriv-avg=" << SummarizeVector(deriv_avg); + } } return stream.str(); } diff --git a/src/nnet3/nnet-component-itf.h b/src/nnet3/nnet-component-itf.h index 7cf438a025e..fae95de9651 100644 --- a/src/nnet3/nnet-component-itf.h +++ b/src/nnet3/nnet-component-itf.h @@ -401,6 +401,11 @@ class UpdatableComponent: public Component { /// Sets the learning rate directly, bypassing learning_rate_factor_. virtual void SetActualLearningRate(BaseFloat lrate) { learning_rate_ = lrate; } + /// Sets the learning rate factor + virtual void SetLearningRateFactor(BaseFloat lrate_factor) { + learning_rate_factor_ = lrate_factor; + } + /// \brief Sets is_gradient_ to true and sets learning_rate_ to 1, ignoring /// learning_rate_factor_. 
virtual void SetAsGradient() { learning_rate_ = 1.0; is_gradient_ = true; } @@ -410,12 +415,15 @@ class UpdatableComponent: public Component { /// a different value than x will returned. BaseFloat LearningRate() const { return learning_rate_; } + /// Gets the learning rate factor + BaseFloat LearningRateFactor() const { return learning_rate_factor_; } + /// Gets per-component max-change value. Note: the components themselves do /// not enforce the per-component max-change; it's enforced in class /// NnetTrainer by querying the max-changes for each component. /// See NnetTrainer::UpdateParamsWithMaxChange() in nnet3/nnet-training.cc. BaseFloat MaxChange() const { return max_change_; } - + virtual std::string Info() const; /// The following new virtual function returns the total dimension of diff --git a/src/nnet3/nnet-component-test.cc b/src/nnet3/nnet-component-test.cc index fdc9849dfc2..a7939226b3e 100644 --- a/src/nnet3/nnet-component-test.cc +++ b/src/nnet3/nnet-component-test.cc @@ -381,6 +381,11 @@ bool TestSimpleComponentDataDerivative(const Component &c, KALDI_LOG << "Accepting deriv differences since " << "it is ClipGradientComponent."; return true; + } + else if (c.Type() == "ScaleGradientComponent") { + KALDI_LOG << "Accepting deriv differences since " + << "it is ScaleGradientComponent."; + return true; } return ans; } diff --git a/src/nnet3/nnet-diagnostics.cc b/src/nnet3/nnet-diagnostics.cc index 302e2cbfa50..d0e801e27df 100644 --- a/src/nnet3/nnet-diagnostics.cc +++ b/src/nnet3/nnet-diagnostics.cc @@ -92,34 +92,49 @@ void NnetComputeProb::ProcessOutputs(const NnetExample &eg, << "mismatch for '" << io.name << "': " << output.NumCols() << " (nnet) vs. 
" << io.features.NumCols() << " (egs)\n"; } + + const Vector *deriv_weights = NULL; + if (config_.apply_deriv_weights && io.deriv_weights.Dim() > 0) + deriv_weights = &(io.deriv_weights); { BaseFloat tot_weight, tot_objf; bool supply_deriv = config_.compute_deriv; ComputeObjectiveFunction(io.features, obj_type, io.name, supply_deriv, computer, - &tot_weight, &tot_objf); + &tot_weight, &tot_objf, deriv_weights); SimpleObjectiveInfo &totals = objf_info_[io.name]; totals.tot_weight += tot_weight; totals.tot_objective += tot_objf; } - if (obj_type == kLinear && config_.compute_accuracy) { + if (config_.compute_accuracy) { BaseFloat tot_weight, tot_accuracy; + PerDimObjectiveInfo &totals = accuracy_info_[io.name]; + + if (config_.compute_per_dim_accuracy && + totals.tot_objective_vec.Dim() == 0) { + totals.tot_objective_vec.Resize(output.NumCols()); + totals.tot_weight_vec.Resize(output.NumCols()); + } + ComputeAccuracy(io.features, output, - &tot_weight, &tot_accuracy); - SimpleObjectiveInfo &totals = accuracy_info_[io.name]; + &tot_weight, &tot_accuracy, deriv_weights, + config_.compute_per_dim_accuracy ? + &totals.tot_weight_vec : NULL, + config_.compute_per_dim_accuracy ? + &totals.tot_objective_vec : NULL); totals.tot_weight += tot_weight; totals.tot_objective += tot_accuracy; } - num_minibatches_processed_++; } } + num_minibatches_processed_++; } bool NnetComputeProb::PrintTotalStats() const { bool ans = false; - unordered_map::const_iterator - iter, end; { // First print regular objectives + unordered_map::const_iterator iter, end; iter = objf_info_.begin(); end = objf_info_.end(); for (; iter != end; ++iter) { @@ -137,15 +152,34 @@ bool NnetComputeProb::PrintTotalStats() const { ans = true; } } - { // now print accuracies. + { + unordered_map::const_iterator iter, end; + // now print accuracies. 
iter = accuracy_info_.begin(); end = accuracy_info_.end(); for (; iter != end; ++iter) { const std::string &name = iter->first; - const SimpleObjectiveInfo &info = iter->second; + const PerDimObjectiveInfo &info = iter->second; KALDI_LOG << "Overall accuracy for '" << name << "' is " << (info.tot_objective / info.tot_weight) << " per frame" << ", over " << info.tot_weight << " frames."; + + if (info.tot_weight_vec.Dim() > 0) { + Vector accuracy_vec(info.tot_weight_vec.Dim()); + for (size_t j = 0; j < info.tot_weight_vec.Dim(); j++) { + if (info.tot_weight_vec(j) != 0) { + accuracy_vec(j) = info.tot_objective_vec(j) + / info.tot_weight_vec(j); + } else { + accuracy_vec(j) = -1.0; + } + } + + KALDI_LOG << "Overall per-dim accuracy vector for '" << name + << "' is " << accuracy_vec << " per frame" + << ", over " << info.tot_weight << " frames."; + } // don't bother changing ans; the loop over the regular objective should // already have set it to true if we got any data. } @@ -156,12 +190,20 @@ bool NnetComputeProb::PrintTotalStats() const { void ComputeAccuracy(const GeneralMatrix &supervision, const CuMatrixBase &nnet_output, BaseFloat *tot_weight_out, - BaseFloat *tot_accuracy_out) { + BaseFloat *tot_accuracy_out, + const Vector *deriv_weights, + Vector *tot_weight_vec, + Vector *tot_accuracy_vec) { int32 num_rows = nnet_output.NumRows(), num_cols = nnet_output.NumCols(); KALDI_ASSERT(supervision.NumRows() == num_rows && supervision.NumCols() == num_cols); + if (tot_accuracy_vec || tot_weight_vec) + KALDI_ASSERT(tot_accuracy_vec && tot_weight_vec && + tot_accuracy_vec->Dim() == num_cols && + tot_weight_vec->Dim() == num_cols); + CuArray best_index(num_rows); nnet_output.FindRowMaxId(&best_index); std::vector best_index_cpu; @@ -181,27 +223,40 @@ void ComputeAccuracy(const GeneralMatrix &supervision, for (int32 r = 0; r < num_rows; r++) { SubVector vec(mat, r); BaseFloat row_sum = vec.Sum(); - KALDI_ASSERT(row_sum >= 0.0); + // KALDI_ASSERT(row_sum >= 0.0); // For 
conventional ASR systems int32 best_index; vec.Max(&best_index); // discard max value. + if (deriv_weights) + row_sum *= (*deriv_weights)(r); tot_weight += row_sum; - if (best_index == best_index_cpu[r]) + if (tot_weight_vec) + (*tot_weight_vec)(best_index) += row_sum; + if (best_index == best_index_cpu[r]) { tot_accuracy += row_sum; + if (tot_accuracy_vec) + (*tot_accuracy_vec)(best_index) += row_sum; + } } break; - } case kFullMatrix: { const Matrix &mat = supervision.GetFullMatrix(); for (int32 r = 0; r < num_rows; r++) { SubVector vec(mat, r); BaseFloat row_sum = vec.Sum(); - KALDI_ASSERT(row_sum >= 0.0); + // KALDI_ASSERT(row_sum >= 0.0); // For conventional ASR systems int32 best_index; vec.Max(&best_index); // discard max value. + if (deriv_weights) + row_sum *= (*deriv_weights)(r); tot_weight += row_sum; - if (best_index == best_index_cpu[r]) + if (tot_weight_vec) + (*tot_weight_vec)(best_index) += row_sum; + if (best_index == best_index_cpu[r]) { tot_accuracy += row_sum; + if (tot_accuracy_vec) + (*tot_accuracy_vec)(best_index) += row_sum; + } } break; } @@ -212,10 +267,17 @@ void ComputeAccuracy(const GeneralMatrix &supervision, BaseFloat row_sum = row.Sum(); int32 best_index; row.Max(&best_index); + if (deriv_weights) + row_sum *= (*deriv_weights)(r); KALDI_ASSERT(best_index < num_cols); tot_weight += row_sum; - if (best_index == best_index_cpu[r]) + if (tot_weight_vec) + (*tot_weight_vec)(best_index) += row_sum; + if (best_index == best_index_cpu[r]) { tot_accuracy += row_sum; + if (tot_accuracy_vec) + (*tot_accuracy_vec)(best_index) += row_sum; + } } break; } diff --git a/src/nnet3/nnet-diagnostics.h b/src/nnet3/nnet-diagnostics.h index fd2ceb1df9e..7c2f7ac7734 100644 --- a/src/nnet3/nnet-diagnostics.h +++ b/src/nnet3/nnet-diagnostics.h @@ -36,21 +36,30 @@ struct SimpleObjectiveInfo { double tot_objective; SimpleObjectiveInfo(): tot_weight(0.0), tot_objective(0.0) { } - }; +struct PerDimObjectiveInfo : SimpleObjectiveInfo { + Vector tot_weight_vec; + 
Vector tot_objective_vec; + PerDimObjectiveInfo(): SimpleObjectiveInfo() { } +}; struct NnetComputeProbOptions { bool debug_computation; bool compute_deriv; bool compute_accuracy; + bool compute_per_dim_accuracy; + bool apply_deriv_weights; + NnetOptimizeOptions optimize_config; NnetComputeOptions compute_config; CachingOptimizingCompilerOptions compiler_config; NnetComputeProbOptions(): debug_computation(false), compute_deriv(false), - compute_accuracy(true) { } + compute_accuracy(true), + compute_per_dim_accuracy(false), + apply_deriv_weights(true) { } void Register(OptionsItf *opts) { // compute_deriv is not included in the command line options // because it's not relevant for nnet3-compute-prob. @@ -58,6 +67,11 @@ struct NnetComputeProbOptions { "debug for the actual computation (very verbose!)"); opts->Register("compute-accuracy", &compute_accuracy, "If true, compute " "accuracy values as well as objective functions"); + opts->Register("compute-per-dim-accuracy", &compute_per_dim_accuracy, + "If true, compute accuracy values per-dim"); + opts->Register("apply-deriv-weights", &apply_deriv_weights, + "Apply per-frame deriv weights"); + // register the optimization options with the prefix "optimization". ParseOptions optimization_opts("optimization", opts); optimize_config.Register(&optimization_opts); @@ -100,11 +114,17 @@ class NnetComputeProb { // or NULL if there is no such info. const SimpleObjectiveInfo *GetObjective(const std::string &output_name) const; + // return objective info for all outputs + const unordered_map & GetAllObjectiveInfo() const { + return objf_info_; + } + // if config.compute_deriv == true, returns a reference to the // computed derivative. Otherwise crashes. 
const Nnet &GetDeriv() const; ~NnetComputeProb(); + private: void ProcessOutputs(const NnetExample &eg, NnetComputer *computer); @@ -120,7 +140,7 @@ class NnetComputeProb { unordered_map objf_info_; - unordered_map accuracy_info_; + unordered_map accuracy_info_; }; @@ -155,7 +175,10 @@ class NnetComputeProb { void ComputeAccuracy(const GeneralMatrix &supervision, const CuMatrixBase &nnet_output, BaseFloat *tot_weight, - BaseFloat *tot_accuracy); + BaseFloat *tot_accuracy, + const Vector *deriv_weights = NULL, + Vector *tot_weight_vec = NULL, + Vector *tot_accuracy_vec = NULL); } // namespace nnet3 diff --git a/src/nnet3/nnet-example-utils.cc b/src/nnet3/nnet-example-utils.cc index 088772bcba7..6dd10fba9b8 100644 --- a/src/nnet3/nnet-example-utils.cc +++ b/src/nnet3/nnet-example-utils.cc @@ -66,9 +66,9 @@ static void GetIoSizes(const std::vector &src, KALDI_ASSERT(*names_iter == io.name); int32 i = names_iter - names_begin; int32 this_dim = io.features.NumCols(); - if (dims[i] == -1) + if (dims[i] == -1) { dims[i] = this_dim; - else if(dims[i] != this_dim) { + } else if (dims[i] != this_dim) { KALDI_ERR << "Merging examples with inconsistent feature dims: " << dims[i] << " vs. " << this_dim << " for '" << io.name << "'."; @@ -90,9 +90,20 @@ static void MergeIo(const std::vector &src, const std::vector &sizes, bool compress, NnetExample *merged_eg) { + // The total number of Indexes we have across all examples. 
int32 num_feats = names.size(); + std::vector cur_size(num_feats, 0); + + // The features in the different NnetIo in the Indexes across all examples std::vector > output_lists(num_feats); + + // The deriv weights in the different NnetIo in the Indexes across all + // examples + std::vector const*> > + output_deriv_weights(num_feats); + + // Initialize the merged_eg merged_eg->io.clear(); merged_eg->io.resize(num_feats); for (int32 f = 0; f < num_feats; f++) { @@ -105,20 +116,27 @@ static void MergeIo(const std::vector &src, std::vector::const_iterator names_begin = names.begin(), names_end = names.end(); - std::vector::const_iterator iter = src.begin(), end = src.end(); - for (int32 n = 0; iter != end; ++iter,++n) { - std::vector::const_iterator iter2 = iter->io.begin(), - end2 = iter->io.end(); - for (; iter2 != end2; ++iter2) { - const NnetIo &io = *iter2; + std::vector::const_iterator eg_iter = src.begin(), + eg_end = src.end(); + for (int32 n = 0; eg_iter != eg_end; ++eg_iter, ++n) { + std::vector::const_iterator io_iter = eg_iter->io.begin(), + io_end = eg_iter->io.end(); + for (; io_iter != io_end; ++io_iter) { + const NnetIo &io = *io_iter; std::vector::const_iterator names_iter = std::lower_bound(names_begin, names_end, io.name); KALDI_ASSERT(*names_iter == io.name); + int32 f = names_iter - names_begin; - int32 this_size = io.indexes.size(), - &this_offset = cur_size[f]; + int32 this_size = io.indexes.size(); + int32 &this_offset = cur_size[f]; KALDI_ASSERT(this_size + this_offset <= sizes[f]); + + // Add f^th Io's features and deriv_weights output_lists[f].push_back(&(io.features)); + output_deriv_weights[f].push_back(&(io.deriv_weights)); + + // Work on the Indexes for the f^th Io in merged_eg NnetIo &output_io = merged_eg->io[f]; std::copy(io.indexes.begin(), io.indexes.end(), output_io.indexes.begin() + this_offset); @@ -142,11 +160,28 @@ static void MergeIo(const std::vector &src, // the following won't do anything if the features were sparse. 
merged_eg->io[f].features.Compress(); } + + Vector &this_deriv_weights = merged_eg->io[f].deriv_weights; + this_deriv_weights.Resize( + merged_eg->io[f].indexes.size(), kUndefined); + this_deriv_weights.Set(1.0); + KALDI_ASSERT(this_deriv_weights.Dim() == + merged_eg->io[f].features.NumRows()); + + std::vector const*>::const_iterator + it = output_deriv_weights[f].begin(), + end = output_deriv_weights[f].end(); + + for (int32 i = 0, cur_offset = 0; it != end; ++it, i++) { + if((*it)->Dim() > 0) { + KALDI_ASSERT((*it)->Dim() == output_lists[f][i]->NumRows()); + this_deriv_weights.Range(cur_offset, (*it)->Dim()).CopyFromVec(**it); + } + cur_offset += output_lists[f][i]->NumRows(); + } } } - - void MergeExamples(const std::vector &src, bool compress, NnetExample *merged_eg) { @@ -1251,6 +1286,13 @@ void ExampleMerger::Finish() { stats_.PrintStats(); } +int32 NumOutputs(const NnetExample &eg) { + int32 num_outputs = 0; + for (size_t i = 0; i < eg.io.size(); i++) + if (eg.io[i].name.find("output") != std::string::npos) + num_outputs++; + return num_outputs; +} } // namespace nnet3 } // namespace kaldi diff --git a/src/nnet3/nnet-example-utils.h b/src/nnet3/nnet-example-utils.h index debd93599e9..de4efa73d5e 100644 --- a/src/nnet3/nnet-example-utils.h +++ b/src/nnet3/nnet-example-utils.h @@ -516,6 +516,8 @@ class ExampleMerger { }; +// Returns the number of outputs in an eg +int32 NumOutputs(const NnetExample &eg); } // namespace nnet3 } // namespace kaldi diff --git a/src/nnet3/nnet-example.cc b/src/nnet3/nnet-example.cc index c011f2a0b8a..f1efa80fcdc 100644 --- a/src/nnet3/nnet-example.cc +++ b/src/nnet3/nnet-example.cc @@ -19,6 +19,7 @@ // limitations under the License. 
#include "nnet3/nnet-example.h" +#include "nnet3/nnet-example-utils.h" #include "lat/lattice-functions.h" #include "hmm/posterior.h" @@ -31,6 +32,8 @@ void NnetIo::Write(std::ostream &os, bool binary) const { WriteToken(os, binary, name); WriteIndexVector(os, binary, indexes); features.Write(os, binary); + WriteToken(os, binary, ""); // for DerivWeights. Want to save space. + WriteVectorAsChar(os, binary, deriv_weights); WriteToken(os, binary, ""); KALDI_ASSERT(static_cast(features.NumRows()) == indexes.size()); } @@ -40,7 +43,14 @@ void NnetIo::Read(std::istream &is, bool binary) { ReadToken(is, binary, &name); ReadIndexVector(is, binary, &indexes); features.Read(is, binary); - ExpectToken(is, binary, ""); + std::string token; + ReadToken(is, binary, &token); + // in the future this back-compatibility code can be reworked. + if (token != "") { + KALDI_ASSERT(token == ""); + ReadVectorAsChar(is, binary, &deriv_weights); + ExpectToken(is, binary, ""); + } } bool NnetIo::operator == (const NnetIo &other) const { @@ -52,42 +62,75 @@ bool NnetIo::operator == (const NnetIo &other) const { Matrix this_mat, other_mat; features.GetMatrix(&this_mat); other.features.GetMatrix(&other_mat); - return ApproxEqual(this_mat, other_mat); + return (ApproxEqual(this_mat, other_mat) && + deriv_weights.ApproxEqual(other.deriv_weights)); } NnetIo::NnetIo(const std::string &name, - int32 t_begin, const MatrixBase &feats): + int32 t_begin, const MatrixBase &feats, + int32 skip_frame): name(name), features(feats) { - int32 num_rows = feats.NumRows(); - KALDI_ASSERT(num_rows > 0); - indexes.resize(num_rows); // sets all n,t,x to zeros. - for (int32 i = 0; i < num_rows; i++) - indexes[i].t = t_begin + i; + int32 num_skipped_rows = feats.NumRows(); + KALDI_ASSERT(num_skipped_rows > 0); + indexes.resize(num_skipped_rows); // sets all n,t,x to zeros. 
+ for (int32 i = 0; i < num_skipped_rows; i++) + indexes[i].t = t_begin + i * skip_frame; +} + +NnetIo::NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 t_begin, const MatrixBase &feats, + int32 skip_frame): + name(name), features(feats), deriv_weights(deriv_weights) { + int32 num_skipped_rows = feats.NumRows(); + KALDI_ASSERT(num_skipped_rows > 0); + indexes.resize(num_skipped_rows); // sets all n,t,x to zeros. + for (int32 i = 0; i < num_skipped_rows; i++) + indexes[i].t = t_begin + i * skip_frame; } void NnetIo::Swap(NnetIo *other) { name.swap(other->name); indexes.swap(other->indexes); features.Swap(&(other->features)); + deriv_weights.Swap(&(other->deriv_weights)); } NnetIo::NnetIo(const std::string &name, int32 dim, int32 t_begin, - const Posterior &labels): + const Posterior &labels, + int32 skip_frame): name(name) { - int32 num_rows = labels.size(); - KALDI_ASSERT(num_rows > 0); + int32 num_skipped_rows = labels.size(); + KALDI_ASSERT(num_skipped_rows > 0); SparseMatrix sparse_feats(dim, labels); features = sparse_feats; - indexes.resize(num_rows); // sets all n,t,x to zeros. - for (int32 i = 0; i < num_rows; i++) - indexes[i].t = t_begin + i; + indexes.resize(num_skipped_rows); // sets all n,t,x to zeros. + for (int32 i = 0; i < num_skipped_rows; i++) + indexes[i].t = t_begin + i * skip_frame; } - +NnetIo::NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 dim, + int32 t_begin, + const Posterior &labels, + int32 skip_frame): + name(name), deriv_weights(deriv_weights) { + int32 num_skipped_rows = labels.size(); + KALDI_ASSERT(num_skipped_rows > 0); + SparseMatrix sparse_feats(dim, labels); + features = sparse_feats; + indexes.resize(num_skipped_rows); // sets all n,t,x to zeros. 
+ for (int32 i = 0; i < num_skipped_rows; i++) + indexes[i].t = t_begin + i * skip_frame; +} void NnetExample::Write(std::ostream &os, bool binary) const { +#ifdef KALDI_PARANOID + KALDI_ASSERT(NumOutputs(eg) > 0); +#endif // Note: weight, label, input_frames and spk_info are members. This is a // struct. WriteToken(os, binary, ""); @@ -114,12 +157,12 @@ void NnetExample::Read(std::istream &is, bool binary) { } -void NnetExample::Compress() { +void NnetExample::Compress(int32 format) { std::vector::iterator iter = io.begin(), end = io.end(); // calling features.Compress() will do nothing if they are sparse or already // compressed. for (; iter != end; ++iter) - iter->features.Compress(); + iter->features.Compress(format); } diff --git a/src/nnet3/nnet-example.h b/src/nnet3/nnet-example.h index 347894e958c..ce2a76ce79d 100644 --- a/src/nnet3/nnet-example.h +++ b/src/nnet3/nnet-example.h @@ -45,12 +45,32 @@ struct NnetIo { /// a Matrix, or SparseMatrix (a SparseMatrix would be the natural format for posteriors). GeneralMatrix features; + /// This is a vector of per-frame weights, required to be between 0 and 1, + /// that is applied to the derivative during training (but not during model + /// combination, where the derivatives need to agree with the computed objf + /// values for the optimization code to work). + /// If this vector is empty it means we're not applying per-frame weights, + /// so it's equivalent to a vector of all ones. This vector is written + /// to disk compactly as unsigned char. + Vector deriv_weights; + /// This constructor creates NnetIo with name "name", indexes with n=0, x=0, /// and t values ranging from t_begin to t_begin + feats.NumRows() - 1, and /// the provided features. t_begin should be the frame that feats.Row(0) /// represents. 
NnetIo(const std::string &name, - int32 t_begin, const MatrixBase &feats); + int32 t_begin, + const MatrixBase &feats, + int32 skip_frame = 1); + + /// This is similar to the above constructor but also takes in a + /// a deriv weights argument. + NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 t_begin, + const MatrixBase &feats, + int32 skip_frame = 1); + /// This constructor sets "name" to the provided string, sets "indexes" with /// n=0, x=0, and t from t_begin to t_begin + labels.size() - 1, and the labels @@ -58,12 +78,30 @@ struct NnetIo { NnetIo(const std::string &name, int32 dim, int32 t_begin, - const Posterior &labels); + const Posterior &labels, + int32 skip_frame = 1); + + /// This is similar to the above constructor but also takes in a + /// a deriv weights argument. + NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 dim, + int32 t_begin, + const Posterior &labels, + int32 skip_frame = 1); void Swap(NnetIo *other); NnetIo() { } + // Compress the features in this NnetIo structure with specified format. + // the "format" will be 1 for the original format where each column has a + // PerColHeader, and 2 for the format, where everything is represented as + // 16-bit integers. + // If format <= 0, then format 1 will be used, unless the matrix has 8 or + // fewer rows (in which case format 2 will be used). + void Compress(int32 format = 0) { features.Compress(format); } + // Use default copy constructor and assignment operators. void Write(std::ostream &os, bool binary) const; @@ -96,7 +134,6 @@ struct NnetIoStructureCompare { /// more frames of input, used for standard cross-entropy training of neural /// nets (and possibly for other objective functions). struct NnetExample { - /// "io" contains the input and output. In principle there can be multiple /// types of both input and output, with different names. The order is /// irrelevant. 
@@ -111,8 +148,13 @@ struct NnetExample { void Swap(NnetExample *other) { io.swap(other->io); } - /// Compresses any (input) features that are not sparse. - void Compress(); + // Compresses any features that are not sparse and not compressed. + // The "format" is 1 for the original format where each column has a + // PerColHeader, and 2 for the format, where everything is represented as + // 16-bit integers. + // If format <= 0, then format 1 will be used, unless the matrix has 8 or + // fewer rows (in which case format 2 will be used). + void Compress(int32 format = 0); /// Caution: this operator == is not very efficient. It's only used in /// testing code. diff --git a/src/nnet3/nnet-nnet.cc b/src/nnet3/nnet-nnet.cc index dd90af739e7..c4020d99d1d 100644 --- a/src/nnet3/nnet-nnet.cc +++ b/src/nnet3/nnet-nnet.cc @@ -86,8 +86,14 @@ std::string Nnet::GetAsConfigLine(int32 node_index, bool include_dim) const { node.descriptor.WriteConfig(ans, node_names_); if (include_dim) ans << " dim=" << node.Dim(*this); - ans << " objective=" << (node.u.objective_type == kLinear ? 
"linear" : - "quadratic"); + + if (node.u.objective_type == kLinear) + ans << " objective=linear"; + else if (node.u.objective_type == kQuadratic) + ans << " objective=quadratic"; + else if (node.u.objective_type == kXentPerDim) + ans << " objective=xent-per-dim"; + break; case kComponent: ans << "component-node name=" << name << " component=" @@ -392,6 +398,8 @@ void Nnet::ProcessOutputNodeConfigLine( nodes_[node_index].u.objective_type = kLinear; } else if (objective_type == "quadratic") { nodes_[node_index].u.objective_type = kQuadratic; + } else if (objective_type == "xent-per-dim") { + nodes_[node_index].u.objective_type = kXentPerDim; } else { KALDI_ERR << "Invalid objective type: " << objective_type; } diff --git a/src/nnet3/nnet-nnet.h b/src/nnet3/nnet-nnet.h index 5eb87fd30f3..b3b36f8b87b 100644 --- a/src/nnet3/nnet-nnet.h +++ b/src/nnet3/nnet-nnet.h @@ -49,7 +49,12 @@ namespace nnet3 { /// - Objective type kQuadratic is used to mean the objective function /// f(x, y) = -0.5 (x-y).(x-y), which is to be maximized, as in the kLinear /// case. -enum ObjectiveType { kLinear, kQuadratic }; +/// - Objective type kXentPerDim is the objective function that is used +/// to learn a set of bernoulli random variables. 
+/// f(x, y) = x * y + (1-x) * Log(1-Exp(y)), where +/// x is the true probability of class 1 and +/// y is the predicted log probability of class 1 +enum ObjectiveType { kLinear, kQuadratic, kXentPerDim }; enum NodeType { kInput, kDescriptor, kComponent, kDimRange, kNone }; diff --git a/src/nnet3/nnet-optimize-utils.cc b/src/nnet3/nnet-optimize-utils.cc index 60ec93f3f18..72f4147931b 100644 --- a/src/nnet3/nnet-optimize-utils.cc +++ b/src/nnet3/nnet-optimize-utils.cc @@ -2523,7 +2523,7 @@ void ComputationExpander::ExpandRowRangesCommand( num_rows_new = expanded_computation_->submatrices[s1].num_rows; KALDI_ASSERT(static_cast(c_in.arg3) < computation_.indexes_ranges.size()); - KALDI_ASSERT(num_rows_old % 2 == 0); + //KALDI_ASSERT(num_rows_old % 2 == 0); int32 num_n_values = num_n_values_; diff --git a/src/nnet3/nnet-simple-component.cc b/src/nnet3/nnet-simple-component.cc index 91f8f5139b2..cfdddc9c44a 100644 --- a/src/nnet3/nnet-simple-component.cc +++ b/src/nnet3/nnet-simple-component.cc @@ -920,6 +920,87 @@ void ClipGradientComponent::Add(BaseFloat alpha, const Component &other_in) { num_clipped_ += alpha * other->num_clipped_; } + +void ScaleGradientComponent::Init(const CuVectorBase &scales) { + KALDI_ASSERT(scales.Dim() != 0); + scales_ = scales; +} + + +void ScaleGradientComponent::InitFromConfig(ConfigLine *cfl) { + std::string filename; + // Accepts "scales" config (for filename) or "dim" -> random init, for testing. 
+ if (cfl->GetValue("scales", &filename)) { + if (cfl->HasUnusedValues()) + KALDI_ERR << "Invalid initializer for layer of type " + << Type() << ": \"" << cfl->WholeLine() << "\""; + CuVector vec; + ReadKaldiObject(filename, &vec); + Init(vec); + } else { + int32 dim; + BaseFloat scale = 1.0; + bool scale_ok = cfl->GetValue("scale", &scale); + if (!cfl->GetValue("dim", &dim) || cfl->HasUnusedValues()) + KALDI_ERR << "Invalid initializer for layer of type " + << Type() << ": \"" << cfl->WholeLine() << "\""; + KALDI_ASSERT(dim > 0); + CuVector vec(dim); + if (scale_ok) { + vec.Set(scale); + } else { + vec.SetRandn(); + } + Init(vec); + } +} + + +std::string ScaleGradientComponent::Info() const { + std::ostringstream stream; + stream << Component::Info(); + PrintParameterStats(stream, "scales", scales_, true); + return stream.str(); +} + +void ScaleGradientComponent::Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const { + out->CopyFromMat(in); // does nothing if same matrix. +} + +void ScaleGradientComponent::Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &, // in_value + const CuMatrixBase &, // out_value + const CuMatrixBase &out_deriv, + Component *, // to_update + CuMatrixBase *in_deriv) const { + in_deriv->CopyFromMat(out_deriv); // does nothing if same memory. 
+ in_deriv->MulColsVec(scales_); +} + +Component* ScaleGradientComponent::Copy() const { + ScaleGradientComponent *ans = new ScaleGradientComponent(); + ans->scales_ = scales_; + return ans; +} + + +void ScaleGradientComponent::Write(std::ostream &os, bool binary) const { + WriteToken(os, binary, ""); + WriteToken(os, binary, ""); + scales_.Write(os, binary); + WriteToken(os, binary, ""); +} + +void ScaleGradientComponent::Read(std::istream &is, bool binary) { + ExpectOneOrTwoTokens(is, binary, "", ""); + scales_.Read(is, binary); + ExpectToken(is, binary, ""); +} + + void TanhComponent::Propagate(const ComponentPrecomputedIndexes *indexes, const CuMatrixBase &in, CuMatrixBase *out) const { @@ -2492,6 +2573,26 @@ void ConstantFunctionComponent::UnVectorize(const VectorBase ¶ms) output_.CopyFromVec(params); } +void ExpComponent::Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const { + // Applied exp function + out->CopyFromMat(in); + out->ApplyExp(); +} + +void ExpComponent::Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &,//in_value, + const CuMatrixBase &out_value, + const CuMatrixBase &out_deriv, + Component *to_update, + CuMatrixBase *in_deriv) const { + if (in_deriv != NULL) { + in_deriv->CopyFromMat(out_value); + in_deriv->MulElements(out_deriv); + } +} NaturalGradientAffineComponent::NaturalGradientAffineComponent(): max_change_per_sample_(0.0), @@ -2543,10 +2644,15 @@ void NaturalGradientAffineComponent::Read(std::istream &is, bool binary) { ReadBasicType(is, binary, &max_change_scale_stats_); ReadToken(is, binary, &token); } - if (token != "" && - token != "") - KALDI_ERR << "Expected or " - << ", got " << token; + + std::ostringstream ostr_beg, ostr_end; + ostr_beg << "<" << Type() << ">"; // e.g. "" + ostr_end << ""; // e.g. 
"" + + if (token != ostr_end.str() && + token != ostr_beg.str()) + KALDI_ERR << "Expected " << ostr_beg.str() << " or " + << ostr_end.str() << ", got " << token; SetNaturalGradientConfigs(); } @@ -2695,7 +2801,10 @@ void NaturalGradientAffineComponent::Write(std::ostream &os, WriteBasicType(os, binary, active_scaling_count_); WriteToken(os, binary, ""); WriteBasicType(os, binary, max_change_scale_stats_); - WriteToken(os, binary, ""); + + std::ostringstream ostr_end; + ostr_end << ""; // e.g. "" + WriteToken(os, binary, ostr_end.str()); } std::string NaturalGradientAffineComponent::Info() const { @@ -3078,6 +3187,126 @@ void SoftmaxComponent::StoreStats(const CuMatrixBase &out_value) { StoreStatsInternal(out_value, NULL); } +std::string LogComponent::Info() const { + std::stringstream stream; + stream << NonlinearComponent::Info() + << ", log-floor=" << log_floor_; + return stream.str(); +} + +void LogComponent::InitFromConfig(ConfigLine *cfl) { + cfl->GetValue("log-floor", &log_floor_); + NonlinearComponent::InitFromConfig(cfl); +} + +void LogComponent::Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const { + // Apllies log function (x >= epsi ? log(x) : log(epsi)). + out->CopyFromMat(in); + out->ApplyFloor(log_floor_); + out->ApplyLog(); +} + +void LogComponent::Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in_value, + const CuMatrixBase &out_value, + const CuMatrixBase &out_deriv, + Component *to_update, + CuMatrixBase *in_deriv) const { + if (in_deriv != NULL) { + CuMatrix divided_in_value(in_value), floored_in_value(in_value); + divided_in_value.Set(1.0); + floored_in_value.CopyFromMat(in_value); + floored_in_value.ApplyFloor(log_floor_); // (x > epsi ? x : epsi) + + divided_in_value.DivElements(floored_in_value); // (x > epsi ? 
1/x : 1/epsi) + in_deriv->CopyFromMat(in_value); + in_deriv->Add(-1.0 * log_floor_); // (x - epsi) + in_deriv->ApplyHeaviside(); // (x > epsi ? 1 : 0) + in_deriv->MulElements(divided_in_value); // (dy/dx: x > epsi ? 1/x : 0) + in_deriv->MulElements(out_deriv); // dF/dx = dF/dy * dy/dx + } +} + +void LogComponent::Read(std::istream &is, bool binary) { + std::ostringstream ostr_beg, ostr_end; + ostr_beg << "<" << Type() << ">"; // e.g. "" + ostr_end << ""; // e.g. "" + ExpectOneOrTwoTokens(is, binary, ostr_beg.str(), ""); + ReadBasicType(is, binary, &dim_); // Read dimension. + ExpectToken(is, binary, ""); + value_sum_.Read(is, binary); + ExpectToken(is, binary, ""); + deriv_sum_.Read(is, binary); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &count_); + value_sum_.Scale(count_); + deriv_sum_.Scale(count_); + + std::string token; + ReadToken(is, binary, &token); + if (token == "") { + ReadBasicType(is, binary, &self_repair_lower_threshold_); + ReadToken(is, binary, &token); + } + if (token == "") { + ReadBasicType(is, binary, &self_repair_upper_threshold_); + ReadToken(is, binary, &token); + } + if (token == "") { + ReadBasicType(is, binary, &self_repair_scale_); + ReadToken(is, binary, &token); + } + if (token == "") { + ReadBasicType(is, binary, &log_floor_); + ReadToken(is, binary, &token); + } + if (token != ostr_end.str()) { + KALDI_ERR << "Expected token " << ostr_end.str() + << ", got " << token; + } +} + +void LogComponent::Write(std::ostream &os, bool binary) const { + std::ostringstream ostr_beg, ostr_end; + ostr_beg << "<" << Type() << ">"; // e.g. "" + ostr_end << ""; // e.g. "" + WriteToken(os, binary, ostr_beg.str()); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, dim_); + // Write the values and derivatives in a count-normalized way, for + // greater readability in text form. 
+ WriteToken(os, binary, ""); + Vector temp(value_sum_); + if (count_ != 0.0) temp.Scale(1.0 / count_); + temp.Write(os, binary); + WriteToken(os, binary, ""); + + temp.Resize(deriv_sum_.Dim(), kUndefined); + temp.CopyFromVec(deriv_sum_); + if (count_ != 0.0) temp.Scale(1.0 / count_); + temp.Write(os, binary); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, count_); + if (self_repair_lower_threshold_ != kUnsetThreshold) { + WriteToken(os, binary, ""); + WriteBasicType(os, binary, self_repair_lower_threshold_); + } + if (self_repair_upper_threshold_ != kUnsetThreshold) { + WriteToken(os, binary, ""); + WriteBasicType(os, binary, self_repair_upper_threshold_); + } + if (self_repair_scale_ != 0.0) { + WriteToken(os, binary, ""); + WriteBasicType(os, binary, self_repair_scale_); + } + WriteToken(os, binary, ""); + WriteBasicType(os, binary, log_floor_); + WriteToken(os, binary, ostr_end.str()); +} + void LogSoftmaxComponent::Propagate(const ComponentPrecomputedIndexes *indexes, const CuMatrixBase &in, @@ -3118,12 +3347,18 @@ void FixedScaleComponent::InitFromConfig(ConfigLine *cfl) { Init(vec); } else { int32 dim; + BaseFloat scale = 1.0; + bool scale_ok = cfl->GetValue("scale", &scale); if (!cfl->GetValue("dim", &dim) || cfl->HasUnusedValues()) KALDI_ERR << "Invalid initializer for layer of type " << Type() << ": \"" << cfl->WholeLine() << "\""; KALDI_ASSERT(dim > 0); CuVector vec(dim); - vec.SetRandn(); + if (scale_ok) { + vec.Set(scale); + } else { + vec.SetRandn(); + } Init(vec); } } diff --git a/src/nnet3/nnet-simple-component.h b/src/nnet3/nnet-simple-component.h index 60fd1634598..a7935bcc7a7 100644 --- a/src/nnet3/nnet-simple-component.h +++ b/src/nnet3/nnet-simple-component.h @@ -703,6 +703,71 @@ class LogSoftmaxComponent: public NonlinearComponent { LogSoftmaxComponent &operator = (const LogSoftmaxComponent &other); // Disallow. 
}; +// The LogComponent outputs the log of input values as y = Log(max(x, epsi)) +class LogComponent: public NonlinearComponent { + public: + explicit LogComponent(const LogComponent &other): + NonlinearComponent(other), log_floor_(other.log_floor_) { } + LogComponent(): log_floor_(1e-20) { } + virtual std::string Type() const { return "LogComponent"; } + virtual int32 Properties() const { + return kSimpleComponent|kBackpropNeedsInput|kStoresStats; + } + + virtual std::string Info() const; + + virtual void InitFromConfig(ConfigLine *cfl); + + virtual void Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const; + virtual void Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in_value, + const CuMatrixBase &out_value, + const CuMatrixBase &out_deriv, + Component *to_update, + CuMatrixBase *in_deriv) const; + + virtual Component* Copy() const { return new LogComponent(*this); } + + virtual void Read(std::istream &is, bool binary); + + virtual void Write(std::ostream &os, bool binary) const; + + private: + LogComponent &operator = (const LogComponent &other); // Disallow. 
+ BaseFloat log_floor_; +}; + + +// The ExpComponent outputs the exp of input values as y = Exp(x) +class ExpComponent: public NonlinearComponent { + public: + explicit ExpComponent(const ExpComponent &other): + NonlinearComponent(other) { } + ExpComponent() { } + virtual std::string Type() const { return "ExpComponent"; } + virtual int32 Properties() const { + return kSimpleComponent|kBackpropNeedsOutput|kStoresStats; + } + virtual void Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const; + virtual void Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &, + const CuMatrixBase &out_value, + const CuMatrixBase &, + Component *to_update, + CuMatrixBase *in_deriv) const; + + virtual Component* Copy() const { return new ExpComponent(*this); } + private: + ExpComponent &operator = (const ExpComponent &other); // Disallow. +}; + + /// Keywords: natural gradient descent, NG-SGD, naturalgradient. For /// the top-level of the natural gradient code look here, and also in /// nnet-precondition-online.h. @@ -832,6 +897,8 @@ class FixedAffineComponent: public Component { // Function to provide access to linear_params_. 
const CuMatrix &LinearParams() const { return linear_params_; } + const CuVector &BiasParams() const { return bias_params_; } + protected: friend class AffineComponent; CuMatrix linear_params_; @@ -1135,6 +1202,46 @@ class ClipGradientComponent: public Component { }; +// Applied a per-element scale only on the gradient during back propagation +// Duplicates the input during forward propagation +class ScaleGradientComponent : public Component { + public: + ScaleGradientComponent() { } + virtual std::string Type() const { return "ScaleGradientComponent"; } + virtual std::string Info() const; + virtual int32 Properties() const { + return kSimpleComponent|kLinearInInput|kPropagateInPlace|kBackpropInPlace; + } + + void Init(const CuVectorBase &scales); + + // The ConfigLine cfl contains only the option scales=, + // where the string is the filename of a Kaldi-format matrix to read. + virtual void InitFromConfig(ConfigLine *cfl); + + virtual int32 InputDim() const { return scales_.Dim(); } + virtual int32 OutputDim() const { return scales_.Dim(); } + + virtual void Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const; + virtual void Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &, // in_value + const CuMatrixBase &, // out_value + const CuMatrixBase &out_deriv, + Component *, // to_update + CuMatrixBase *in_deriv) const; + virtual Component* Copy() const; + virtual void Read(std::istream &is, bool binary); + virtual void Write(std::ostream &os, bool binary) const; + + protected: + CuVector scales_; + KALDI_DISALLOW_COPY_AND_ASSIGN(ScaleGradientComponent); +}; + + /** PermuteComponent changes the order of the columns (i.e. the feature or activation dimensions). 
Output dimension i is mapped to input dimension column_map_[i], so it's like doing: diff --git a/src/nnet3/nnet-test-utils.cc b/src/nnet3/nnet-test-utils.cc index 18131aaa213..caaec3b416c 100644 --- a/src/nnet3/nnet-test-utils.cc +++ b/src/nnet3/nnet-test-utils.cc @@ -1130,7 +1130,7 @@ void ComputeExampleComputationRequestSimple( static void GenerateRandomComponentConfig(std::string *component_type, std::string *config) { - int32 n = RandInt(0, 30); + int32 n = RandInt(0, 33); BaseFloat learning_rate = 0.001 * RandInt(1, 100); std::ostringstream os; @@ -1436,6 +1436,21 @@ static void GenerateRandomComponentConfig(std::string *component_type, << " self-repair-scale=0.0"; break; } + case 31: { + *component_type = "LogComponent"; + os << "dim=" << RandInt(1, 50); + break; + } + case 32: { + *component_type = "ExpComponent"; + os << "dim=" << RandInt(1, 50); + break; + } + case 33: { + *component_type = "ScaleGradientComponent"; + os << "dim=" << RandInt(1, 100); + break; + } default: KALDI_ERR << "Error generating random component"; } diff --git a/src/nnet3/nnet-training.cc b/src/nnet3/nnet-training.cc index 2a081920738..c01308288b9 100644 --- a/src/nnet3/nnet-training.cc +++ b/src/nnet3/nnet-training.cc @@ -86,14 +86,17 @@ void NnetTrainer::ProcessOutputs(const NnetExample &eg, ObjectiveType obj_type = nnet_->GetNode(node_index).u.objective_type; BaseFloat tot_weight, tot_objf; bool supply_deriv = true; + const Vector *deriv_weights = NULL; + if (config_.apply_deriv_weights && io.deriv_weights.Dim() > 0) + deriv_weights = &(io.deriv_weights); ComputeObjectiveFunction(io.features, obj_type, io.name, supply_deriv, computer, - &tot_weight, &tot_objf); + &tot_weight, &tot_objf, deriv_weights); objf_info_[io.name].UpdateStats(io.name, config_.print_interval, - num_minibatches_processed_++, tot_weight, tot_objf); } } + num_minibatches_processed_++; } void NnetTrainer::UpdateParamsWithMaxChange() { @@ -183,11 +186,12 @@ bool NnetTrainer::PrintTotalStats() const { 
unordered_map::const_iterator iter = objf_info_.begin(), end = objf_info_.end(); - bool ans = false; + bool ans = true; for (; iter != end; ++iter) { const std::string &name = iter->first; const ObjectiveFunctionInfo &info = iter->second; - ans = ans || info.PrintTotalStats(name); + if (!info.PrintTotalStats(name)) + ans = false; } PrintMaxChangeStats(); return ans; @@ -220,11 +224,10 @@ void NnetTrainer::PrintMaxChangeStats() const { void ObjectiveFunctionInfo::UpdateStats( const std::string &output_name, int32 minibatches_per_phase, - int32 minibatch_counter, BaseFloat this_minibatch_weight, BaseFloat this_minibatch_tot_objf, BaseFloat this_minibatch_tot_aux_objf) { - int32 phase = minibatch_counter / minibatches_per_phase; + int32 phase = num_minibatches++ / minibatches_per_phase; if (phase != current_phase) { KALDI_ASSERT(phase == current_phase + 1); // or doesn't really make sense. PrintStatsForThisPhase(output_name, minibatches_per_phase); @@ -298,7 +301,8 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, bool supply_deriv, NnetComputer *computer, BaseFloat *tot_weight, - BaseFloat *tot_objf) { + BaseFloat *tot_objf, + const VectorBase *deriv_weights) { const CuMatrixBase &output = computer->GetOutput(output_name); if (output.NumCols() != supervision.NumCols()) @@ -307,6 +311,51 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, << " (nnet) vs. 
" << supervision.NumCols() << " (egs)\n"; switch (objective_type) { + case kXentPerDim: { + // objective is x * log(y) + (1-x) * log(1-y) + CuMatrix cu_post(supervision.NumRows(), supervision.NumCols(), + kUndefined); // x + cu_post.CopyFromGeneralMat(supervision); + + CuMatrix n_cu_post(cu_post.NumRows(), cu_post.NumCols()); + n_cu_post.Set(1.0); + n_cu_post.AddMat(-1.0, cu_post); // 1-x + + CuMatrix log_prob(output); // y + log_prob.ApplyLog(); // log(y) + + CuMatrix n_output(output.NumRows(), + output.NumCols(), kSetZero); + n_output.Set(1.0); + n_output.AddMat(-1.0, output); // 1-y + n_output.ApplyLog(); // log(1-y) + + BaseFloat num_elements = static_cast(cu_post.NumRows()); + if (deriv_weights) { + CuVector cu_deriv_weights(*deriv_weights); + num_elements = cu_deriv_weights.Sum(); + cu_post.MulRowsVec(cu_deriv_weights); + n_cu_post.MulRowsVec(cu_deriv_weights); + } + + *tot_weight = num_elements * cu_post.NumCols(); + *tot_objf = TraceMatMat(log_prob, cu_post, kTrans) + + TraceMatMat(n_output, n_cu_post, kTrans); + + if (supply_deriv) { + // deriv is x / y - (1-x) / (1-y) + n_output.ApplyExp(); // 1-y + n_cu_post.DivElements(n_output); // 1-x / (1-y) + + log_prob.ApplyExp(); // y + cu_post.DivElements(log_prob); // x / y + + cu_post.AddMat(-1.0, n_cu_post); // x / y - (1-x) / (1-y) + computer->AcceptInput(output_name, &cu_post); + } + + break; + } case kLinear: { // objective is x * y. switch (supervision.Type()) { @@ -316,20 +365,38 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, // The cross-entropy objective is computed by a simple dot product, // because after the LogSoftmaxLayer, the output is already in the form // of log-likelihoods that are normalized to sum to one. 
- *tot_weight = cu_post.Sum(); - *tot_objf = TraceMatSmat(output, cu_post, kTrans); - if (supply_deriv) { + if (deriv_weights) { CuMatrix output_deriv(output.NumRows(), output.NumCols(), kUndefined); cu_post.CopyToMat(&output_deriv); - computer->AcceptInput(output_name, &output_deriv); + CuVector cu_deriv_weights(*deriv_weights); + output_deriv.MulRowsVec(cu_deriv_weights); + *tot_weight = cu_deriv_weights.Sum(); + *tot_objf = TraceMatMat(output, output_deriv, kTrans); + if (supply_deriv) { + computer->AcceptInput(output_name, &output_deriv); + } + } else { + *tot_weight = cu_post.Sum(); + *tot_objf = TraceMatSmat(output, cu_post, kTrans); + if (supply_deriv) { + CuMatrix output_deriv(output.NumRows(), output.NumCols(), + kUndefined); + cu_post.CopyToMat(&output_deriv); + computer->AcceptInput(output_name, &output_deriv); + } } + break; } case kFullMatrix: { // there is a redundant matrix copy in here if we're not using a GPU // but we don't anticipate this code branch being used in many cases. 
CuMatrix cu_post(supervision.GetFullMatrix()); + if (deriv_weights) { + CuVector cu_deriv_weights(*deriv_weights); + cu_post.MulRowsVec(cu_deriv_weights); + } *tot_weight = cu_post.Sum(); *tot_objf = TraceMatMat(output, cu_post, kTrans); if (supply_deriv) @@ -341,6 +408,10 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, supervision.GetMatrix(&post); CuMatrix cu_post; cu_post.Swap(&post); + if (deriv_weights) { + CuVector cu_deriv_weights(*deriv_weights); + cu_post.MulRowsVec(cu_deriv_weights); + } *tot_weight = cu_post.Sum(); *tot_objf = TraceMatMat(output, cu_post, kTrans); if (supply_deriv) @@ -358,6 +429,11 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, diff.CopyFromGeneralMat(supervision); diff.AddMat(-1.0, output); *tot_weight = diff.NumRows(); + if (deriv_weights) { + CuVector cu_deriv_weights(*deriv_weights); + diff.MulRowsVec(cu_deriv_weights); + *tot_weight = deriv_weights->Sum(); + } *tot_objf = -0.5 * TraceMatMat(diff, diff, kTrans); if (supply_deriv) computer->AcceptInput(output_name, &diff); diff --git a/src/nnet3/nnet-training.h b/src/nnet3/nnet-training.h index 55d3e02ea67..4ce64305d60 100644 --- a/src/nnet3/nnet-training.h +++ b/src/nnet3/nnet-training.h @@ -43,6 +43,8 @@ struct NnetTrainerOptions { NnetOptimizeOptions optimize_config; NnetComputeOptions compute_config; CachingOptimizingCompilerOptions compiler_config; + bool apply_deriv_weights; + NnetTrainerOptions(): zero_component_stats(true), store_component_stats(true), @@ -50,7 +52,8 @@ struct NnetTrainerOptions { debug_computation(false), momentum(0.0), binary_write_cache(true), - max_param_change(2.0) { } + max_param_change(2.0), + apply_deriv_weights(true) { } void Register(OptionsItf *opts) { opts->Register("store-component-stats", &store_component_stats, "If true, store activations and derivatives for nonlinear " @@ -70,6 +73,9 @@ struct NnetTrainerOptions { "so that the 'effective' learning rate is the same as " "before (because momentum would 
normally increase the " "effective learning rate by 1/(1-momentum))"); + opts->Register("apply-deriv-weights", &apply_deriv_weights, + "If true, apply the per-frame derivative weights stored with " + "the example"); opts->Register("read-cache", &read_cache, "the location where we can read " "the cached computation from"); opts->Register("write-cache", &write_cache, "the location where we want to " @@ -93,6 +99,7 @@ struct NnetTrainerOptions { // Also see struct AccuracyInfo, in nnet-diagnostics.h. struct ObjectiveFunctionInfo { int32 current_phase; + int32 num_minibatches; double tot_weight; double tot_objf; @@ -105,7 +112,7 @@ struct ObjectiveFunctionInfo { double tot_aux_objf_this_phase; ObjectiveFunctionInfo(): - current_phase(0), + current_phase(0), num_minibatches(0), tot_weight(0.0), tot_objf(0.0), tot_aux_objf(0.0), tot_weight_this_phase(0.0), tot_objf_this_phase(0.0), tot_aux_objf_this_phase(0.0) { } @@ -116,7 +123,6 @@ struct ObjectiveFunctionInfo { // control how frequently we print logging messages. void UpdateStats(const std::string &output_name, int32 minibatches_per_phase, - int32 minibatch_counter, BaseFloat this_minibatch_weight, BaseFloat this_minibatch_tot_objf, BaseFloat this_minibatch_tot_aux_objf = 0.0); @@ -227,7 +233,8 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, bool supply_deriv, NnetComputer *computer, BaseFloat *tot_weight, - BaseFloat *tot_objf); + BaseFloat *tot_objf, + const VectorBase* deriv_weights = NULL); diff --git a/src/nnet3/nnet-utils.cc b/src/nnet3/nnet-utils.cc index 27415fe8775..9c00913f012 100644 --- a/src/nnet3/nnet-utils.cc +++ b/src/nnet3/nnet-utils.cc @@ -152,7 +152,7 @@ void ComputeSimpleNnetContext(const Nnet &nnet, // This will crash if the total context (left + right) is greater // than window_size. - int32 window_size = 100; + int32 window_size = 150; // by going "<= modulus" instead of "< modulus" we do one more computation // than we really need; it becomes a sanity check. 
for (int32 input_start = 0; input_start <= modulus; input_start++) diff --git a/src/nnet3bin/Makefile b/src/nnet3bin/Makefile index 2bae1dcdc43..1645157e84d 100644 --- a/src/nnet3bin/Makefile +++ b/src/nnet3bin/Makefile @@ -17,7 +17,8 @@ BINFILES = nnet3-init nnet3-info nnet3-get-egs nnet3-copy-egs nnet3-subset-egs \ nnet3-discriminative-merge-egs nnet3-discriminative-shuffle-egs \ nnet3-discriminative-compute-objf nnet3-discriminative-train \ nnet3-discriminative-subset-egs \ - nnet3-discriminative-compute-from-egs nnet3-latgen-faster-looped + nnet3-discriminative-compute-from-egs nnet3-latgen-faster-looped \ + nnet3-get-egs-multiple-targets nnet3-am-compute nnet3-copy-egs-overlap-detection OBJFILES = diff --git a/src/nnet3bin/nnet3-acc-lda-stats.cc b/src/nnet3bin/nnet3-acc-lda-stats.cc index c8911a4a39f..ca6ef27d451 100644 --- a/src/nnet3bin/nnet3-acc-lda-stats.cc +++ b/src/nnet3bin/nnet3-acc-lda-stats.cc @@ -87,13 +87,18 @@ class NnetLdaStatsAccumulator { // but we're about to do an outer product, so this doesn't dominate. Vector row(cu_row); + BaseFloat deriv_weight = 1.0; + if (output_supervision->deriv_weights.Dim() > 0 && r < output_supervision->deriv_weights.Dim()) { + deriv_weight = output_supervision->deriv_weights(r); + } + const SparseVector &post(smat.Row(r)); const std::pair *post_data = post.Data(), *post_end = post_data + post.NumElements(); for (; post_data != post_end; ++post_data) { MatrixIndexT pdf = post_data->first; BaseFloat weight = post_data->second; - BaseFloat pruned_weight = RandPrune(weight, rand_prune); + BaseFloat pruned_weight = RandPrune(weight, rand_prune) * deriv_weight; if (pruned_weight != 0.0) lda_stats_.Accumulate(row, pdf, pruned_weight); } @@ -110,11 +115,16 @@ class NnetLdaStatsAccumulator { // but we're about to do an outer product, so this doesn't dominate. 
Vector row(cu_row); + BaseFloat deriv_weight = 1.0; + if (output_supervision->deriv_weights.Dim() > 0 && r < output_supervision->deriv_weights.Dim()) { + deriv_weight = output_supervision->deriv_weights(r); + } + SubVector post(output_mat, r); int32 num_pdfs = post.Dim(); for (int32 pdf = 0; pdf < num_pdfs; pdf++) { BaseFloat weight = post(pdf); - BaseFloat pruned_weight = RandPrune(weight, rand_prune); + BaseFloat pruned_weight = RandPrune(weight, rand_prune) * deriv_weight; if (pruned_weight != 0.0) lda_stats_.Accumulate(row, pdf, pruned_weight); } diff --git a/src/nnet3bin/nnet3-am-compute.cc b/src/nnet3bin/nnet3-am-compute.cc new file mode 100644 index 00000000000..c91417c0aee --- /dev/null +++ b/src/nnet3bin/nnet3-am-compute.cc @@ -0,0 +1,186 @@ +// nnet3bin/nnet3-am-compute.cc + +// Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) +// 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "nnet3/nnet-am-decodable-simple.h" +#include "base/timer.h" +#include "nnet3/nnet-utils.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet3; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Propagate the features through neural network model " + "and write the pseudo log-likelihoods (after dividing by priors).\n" + "If --apply-exp=true, apply the Exp() function to the output " + "before writing it out.\n" + "\n" + "Usage: nnet3-am-compute [options] " + "\n" + " e.g.: nnet3-am-compute final.mdl scp:feats.scp ark:log_likes.ark\n" + "See also: nnet3-compute-from-egs, nnet3-compute\n"; + + ParseOptions po(usage); + Timer timer; + + NnetSimpleComputationOptions opts; + opts.acoustic_scale = 1.0; // by default do no scaling in this recipe. + + bool apply_exp = false; + std::string use_gpu = "yes"; + + std::string word_syms_filename; + std::string ivector_rspecifier, + online_ivector_rspecifier, + utt2spk_rspecifier; + int32 online_ivector_period = 0; + + opts.Register(&po); + + po.Register("ivectors", &ivector_rspecifier, "Rspecifier for " + "iVectors as vectors (i.e. not estimated online); per utterance " + "by default, or per speaker if you provide the --utt2spk option."); + po.Register("utt2spk", &utt2spk_rspecifier, "Rspecifier for " + "utt2spk option used to get ivectors per speaker"); + po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for " + "iVectors estimated online, as matrices. 
If you supply this," + " you must set the --online-ivector-period option."); + po.Register("online-ivector-period", &online_ivector_period, "Number of frames " + "between iVectors in matrices supplied to the --online-ivectors " + "option"); + po.Register("apply-exp", &apply_exp, "If true, apply exp function to " + "output"); + po.Register("use-gpu", &use_gpu, + "yes|no|optional|wait, only has effect if compiled with CUDA"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + +#if HAVE_CUDA==1 + CuDevice::Instantiate().SelectGpuId(use_gpu); +#endif + + std::string nnet_rxfilename = po.GetArg(1), + feature_rspecifier = po.GetArg(2), + matrix_wspecifier = po.GetArg(3); + + TransitionModel trans_model; + AmNnetSimple am_nnet; + { + bool binary_read; + Input ki(nnet_rxfilename, &binary_read); + trans_model.Read(ki.Stream(), binary_read); + am_nnet.Read(ki.Stream(), binary_read); + } + + RandomAccessBaseFloatMatrixReader online_ivector_reader( + online_ivector_rspecifier); + RandomAccessBaseFloatVectorReaderMapped ivector_reader( + ivector_rspecifier, utt2spk_rspecifier); + + CachingOptimizingCompiler compiler(am_nnet.GetNnet(), opts.optimize_config); + + BaseFloatMatrixWriter matrix_writer(matrix_wspecifier); + + int32 num_success = 0, num_fail = 0; + int64 frame_count = 0; + + SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); + + for (; !feature_reader.Done(); feature_reader.Next()) { + std::string utt = feature_reader.Key(); + const Matrix &features (feature_reader.Value()); + if (features.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + const Matrix *online_ivectors = NULL; + const Vector *ivector = NULL; + if (!ivector_rspecifier.empty()) { + if (!ivector_reader.HasKey(utt)) { + KALDI_WARN << "No iVector available for utterance " << utt; + num_fail++; + continue; + } else { + ivector = &ivector_reader.Value(utt); + } + } + if 
(!online_ivector_rspecifier.empty()) { + if (!online_ivector_reader.HasKey(utt)) { + KALDI_WARN << "No online iVector available for utterance " << utt; + num_fail++; + continue; + } else { + online_ivectors = &online_ivector_reader.Value(utt); + } + } + + DecodableNnetSimple nnet_computer( + opts, am_nnet.GetNnet(), am_nnet.Priors(), + features, &compiler, + ivector, online_ivectors, + online_ivector_period); + + Matrix matrix(nnet_computer.NumFrames(), + nnet_computer.OutputDim()); + for (int32 t = 0; t < nnet_computer.NumFrames(); t++) { + SubVector row(matrix, t); + nnet_computer.GetOutputForFrame(t, &row); + } + + if (apply_exp) + matrix.ApplyExp(); + + matrix_writer.Write(utt, matrix); + + frame_count += features.NumRows(); + num_success++; + } + +#if HAVE_CUDA==1 + CuDevice::Instantiate().PrintProfile(); +#endif + double elapsed = timer.Elapsed(); + KALDI_LOG << "Time taken "<< elapsed + << "s: real-time factor assuming 100 frames/sec is " + << (elapsed*100.0/frame_count); + KALDI_LOG << "Done " << num_success << " utterances, failed for " + << num_fail; + + if (num_success != 0) return 0; + else return 1; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/nnet3bin/nnet3-compute-from-egs.cc b/src/nnet3bin/nnet3-compute-from-egs.cc index 648b5e1408f..cb4cbc5a6e8 100644 --- a/src/nnet3bin/nnet3-compute-from-egs.cc +++ b/src/nnet3bin/nnet3-compute-from-egs.cc @@ -36,7 +36,8 @@ class NnetComputerFromEg { // Compute the output (which will have the same number of rows as the number // of Indexes in the output of the eg), and put it in "output". 
- void Compute(const NnetExample &eg, Matrix *output) { + void Compute(const NnetExample &eg, const std::string &output_name, + Matrix *output) { ComputationRequest request; bool need_backprop = false, store_stats = false; GetComputationRequest(nnet_, eg, need_backprop, store_stats, &request); @@ -47,7 +48,7 @@ class NnetComputerFromEg { NnetComputer computer(options, computation, nnet_, NULL); computer.AcceptInputs(nnet_, eg.io); computer.Run(); - const CuMatrixBase &nnet_output = computer.GetOutput("output"); + const CuMatrixBase &nnet_output = computer.GetOutput(output_name); output->Resize(nnet_output.NumRows(), nnet_output.NumCols()); nnet_output.CopyToMat(output); } @@ -80,11 +81,14 @@ int main(int argc, char *argv[]) { bool binary_write = true, apply_exp = false; std::string use_gpu = "yes"; + std::string output_name = "output"; ParseOptions po(usage); po.Register("binary", &binary_write, "Write output in binary mode"); po.Register("apply-exp", &apply_exp, "If true, apply exp function to " "output"); + po.Register("output-name", &output_name, "Do computation for " + "specified output"); po.Register("use-gpu", &use_gpu, "yes|no|optional|wait, only has effect if compiled with CUDA"); @@ -115,7 +119,7 @@ int main(int argc, char *argv[]) { for (; !example_reader.Done(); example_reader.Next(), num_egs++) { Matrix output; - computer.Compute(example_reader.Value(), &output); + computer.Compute(example_reader.Value(), output_name, &output); KALDI_ASSERT(output.NumRows() != 0); if (apply_exp) output.ApplyExp(); diff --git a/src/nnet3bin/nnet3-compute.cc b/src/nnet3bin/nnet3-compute.cc index 9305ef7e6b6..d46220c7ffd 100644 --- a/src/nnet3bin/nnet3-compute.cc +++ b/src/nnet3bin/nnet3-compute.cc @@ -159,6 +159,9 @@ int main(int argc, char *argv[]) { num_success++; } +#if HAVE_CUDA==1 + CuDevice::Instantiate().PrintProfile(); +#endif double elapsed = timer.Elapsed(); KALDI_LOG << "Time taken "<< elapsed << "s: real-time factor assuming 100 frames/sec is " diff --git 
a/src/nnet3bin/nnet3-copy-egs-overlap-detection.cc b/src/nnet3bin/nnet3-copy-egs-overlap-detection.cc new file mode 100644 index 00000000000..3f180a6393e --- /dev/null +++ b/src/nnet3bin/nnet3-copy-egs-overlap-detection.cc @@ -0,0 +1,187 @@ +// nnet3bin/nnet3-copy-egs.cc + +// Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) +// 2014 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "hmm/transition-model.h" +#include "nnet3/nnet-example.h" +#include "nnet3/nnet-example-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet3; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Copy examples (single frames or fixed-size groups of frames) for neural\n" + "network training, possibly changing the binary mode. 
Supports multiple wspecifiers, in\n" + "which case it will write the examples round-robin to the outputs.\n" + "\n" + "Usage: nnet3-copy-egs [options] \n" + "\n" + "e.g.\n" + "nnet3-copy-egs ark:train.egs ark,t:text.egs\n" + "or:\n" + "nnet3-copy-egs ark:train.egs ark:1.egs\n"; + + ParseOptions po(usage); + + bool add_silence_output = true; + bool add_speech_output = true; + int32 srand_seed = 0; + + std::string keep_proportion_positive_rxfilename; + std::string keep_proportion_negative_rxfilename; + + po.Register("add-silence-output", &add_silence_output, + "Add silence output"); + po.Register("add-speech-output", &add_speech_output, + "Add speech output"); + po.Register("srand", &srand_seed, "Seed for random number generator " + "(only relevant if --keep-proportion-vec is specified"); + po.Register("keep-proportion-positive-vec", &keep_proportion_positive_rxfilename, + "If a dimension of this is <1.0, this program will " + "randomly set deriv weight 0 for this proportion of the input samples of the " + "corresponding positive examples"); + po.Register("keep-proportion-negative-vec", &keep_proportion_negative_rxfilename, + "If a dimension of this is <1.0, this program will " + "randomly set deriv weight 0 for this proportion of the input samples of the " + "corresponding negative examples"); + + Vector p_positive_vec(3); + p_positive_vec.Set(1); + if (!keep_proportion_positive_rxfilename.empty()) + ReadKaldiObject(keep_proportion_positive_rxfilename, &p_positive_vec); + + Vector p_negative_vec(3); + p_negative_vec.Set(1); + if (!keep_proportion_negative_rxfilename.empty()) + ReadKaldiObject(keep_proportion_negative_rxfilename, &p_negative_vec); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string examples_rspecifier = po.GetArg(1); + std::string examples_wspecifier = po.GetArg(2); + + SequentialNnetExampleReader example_reader(examples_rspecifier); + NnetExampleWriter example_writer(examples_wspecifier); + + int64 
num_read = 0, num_written = 0; + for (; !example_reader.Done(); example_reader.Next(), num_read++) { + std::string key = example_reader.Key(); + NnetExample eg = example_reader.Value(); + + KALDI_ASSERT(eg.io.size() == 2); + NnetIo &io = eg.io[1]; + + KALDI_ASSERT(io.name == "output"); + + NnetIo silence_output(io); + silence_output.name = "output-silence"; + + NnetIo speech_output(io); + speech_output.name = "output-speech"; + + NnetIo overlap_speech_output(io); + overlap_speech_output.name = "output-overlap_speech"; + + io.features.Uncompress(); + + KALDI_ASSERT(io.features.Type() == kFullMatrix); + const Matrix &feats = io.features.GetFullMatrix(); + + typedef std::vector > SparseVec; + std::vector silence_post(feats.NumRows(), SparseVec()); + std::vector speech_post(feats.NumRows(), SparseVec()); + std::vector overlap_speech_post(feats.NumRows(), SparseVec()); + + Vector silence_deriv_weights(feats.NumRows()); + Vector speech_deriv_weights(feats.NumRows()); + Vector overlap_speech_deriv_weights(feats.NumRows()); + + for (int32 i = 0; i < feats.NumRows(); i++) { + if (feats(i,0) < 0.5) { + silence_deriv_weights(i) = WithProb(p_negative_vec(0)) ? 1.0 : 0.0; + silence_post[i].push_back(std::make_pair(0, 1)); + } else { + silence_deriv_weights(i) = WithProb(p_positive_vec(0)) ? 1.0 : 0.0; + silence_post[i].push_back(std::make_pair(1, 1)); + } + + if (feats(i,1) < 0.5) { + speech_deriv_weights(i) = WithProb(p_negative_vec(1)) ? 1.0 : 0.0; + speech_post[i].push_back(std::make_pair(0, 1)); + } else { + speech_deriv_weights(i) = WithProb(p_positive_vec(1)) ? 1.0 : 0.0; + speech_post[i].push_back(std::make_pair(1, 1)); + } + + if (feats(i,2) < 0.5) { + overlap_speech_deriv_weights(i) = WithProb(p_negative_vec(2)) ? 1.0 : 0.0; + overlap_speech_post[i].push_back(std::make_pair(0, 1)); + } else { + overlap_speech_deriv_weights(i) = WithProb(p_positive_vec(2)) ? 
1.0 : 0.0; + overlap_speech_post[i].push_back(std::make_pair(1, 1)); + } + } + + SparseMatrix silence_feats(2, silence_post); + SparseMatrix speech_feats(2, speech_post); + SparseMatrix overlap_speech_feats(2, overlap_speech_post); + + silence_output.features = silence_feats; + speech_output.features = speech_feats; + overlap_speech_output.features = overlap_speech_feats; + + io = overlap_speech_output; + io.deriv_weights.MulElements(overlap_speech_deriv_weights); + + if (add_silence_output) { + silence_output.deriv_weights.MulElements(silence_deriv_weights); + eg.io.push_back(silence_output); + } + + if (add_speech_output) { + speech_output.deriv_weights.MulElements(speech_deriv_weights); + eg.io.push_back(speech_output); + } + + example_writer.Write(key, eg); + num_written++; + } + + KALDI_LOG << "Read " << num_read << " neural-network training examples, wrote " + << num_written; + return (num_written == 0 ? 1 : 0); + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} + + + diff --git a/src/nnet3bin/nnet3-copy-egs.cc b/src/nnet3bin/nnet3-copy-egs.cc index 42413114af3..b3fbc766e9d 100644 --- a/src/nnet3bin/nnet3-copy-egs.cc +++ b/src/nnet3bin/nnet3-copy-egs.cc @@ -23,10 +23,57 @@ #include "hmm/transition-model.h" #include "nnet3/nnet-example.h" #include "nnet3/nnet-example-utils.h" +#include namespace kaldi { namespace nnet3 { +// rename io-name of eg w.r.t io_names list e.g. input/input-1,output/output-1 +// 'input' is renamed to input-1 and 'output' renamed to output-1. 
+void RenameIoNames(const std::string &io_names, + NnetExample *eg_modified) { + std::vector separated_io_names; + SplitStringToVector(io_names, ",", true, &separated_io_names); + int32 num_modified_io = separated_io_names.size(), + io_size = eg_modified->io.size(); + std::vector orig_io_list; + for (int32 io_ind = 0; io_ind < io_size; io_ind++) + orig_io_list.push_back(eg_modified->io[io_ind].name); + + for (int32 ind = 0; ind < num_modified_io; ind++) { + std::vector rename_io_name; + SplitStringToVector(separated_io_names[ind], "/", true, &rename_io_name); + // find the io in eg with specific name and rename it to new name. + + int32 rename_io_ind = + std::find(orig_io_list.begin(), orig_io_list.end(), rename_io_name[0]) - + orig_io_list.begin(); + + if (rename_io_ind >= io_size) + KALDI_ERR << "No io-node with name " << rename_io_name[0] + << "exists in eg."; + eg_modified->io[rename_io_ind].name = rename_io_name[1]; + } +} + +bool KeepOutputs(const std::vector &keep_outputs, + NnetExample *eg) { + std::vector io_new; + int32 num_outputs = 0; + for (std::vector::iterator it = eg->io.begin(); + it != eg->io.end(); ++it) { + if (it->name.find("output") != std::string::npos) { + if (!std::binary_search(keep_outputs.begin(), keep_outputs.end(), it->name)) + continue; + num_outputs++; + } + io_new.push_back(*it); + } + eg->io.swap(io_new); + + return num_outputs; +} + // returns an integer randomly drawn with expected value "expected_count" // (will be either floor(expected_count) or ceil(expected_count)). int32 GetCount(double expected_count) { @@ -58,7 +105,7 @@ bool ContainsSingleExample(const NnetExample &eg, end = io.indexes.end(); // Should not have an empty input/output type. 
KALDI_ASSERT(!io.indexes.empty()); - if (io.name == "input" || io.name == "output") { + if (io.name == "input" || io.name.find("output") != std::string::npos) { int32 min_t = iter->t, max_t = iter->t; for (; iter != end; ++iter) { int32 this_t = iter->t; @@ -75,7 +122,7 @@ bool ContainsSingleExample(const NnetExample &eg, *min_input_t = min_t; *max_input_t = max_t; } else { - KALDI_ASSERT(io.name == "output"); + KALDI_ASSERT(io.name.find("output") != std::string::npos); done_output = true; *min_output_t = min_t; *max_output_t = max_t; @@ -127,7 +174,7 @@ void FilterExample(const NnetExample &eg, min_t = min_input_t; max_t = max_input_t; is_input_or_output = true; - } else if (name == "output") { + } else if (name.find("output") != std::string::npos) { min_t = min_output_t; max_t = max_output_t; is_input_or_output = true; @@ -137,6 +184,7 @@ void FilterExample(const NnetExample &eg, if (!is_input_or_output) { // Just copy everything. io_out.indexes = io_in.indexes; io_out.features = io_in.features; + io_out.deriv_weights = io_in.deriv_weights; } else { const std::vector &indexes_in = io_in.indexes; std::vector &indexes_out = io_out.indexes; @@ -157,6 +205,19 @@ void FilterExample(const NnetExample &eg, } } KALDI_ASSERT(iter_out == keep.end()); + + if (io_in.deriv_weights.Dim() > 0) { + io_out.deriv_weights.Resize(num_kept, kUndefined); + int32 in_dim = 0, out_dim = 0; + iter_out = keep.begin(); + for (; iter_out != keep.end(); ++iter_out, in_dim++) { + if (*iter_out) + io_out.deriv_weights(out_dim++) = io_in.deriv_weights(in_dim); + } + KALDI_ASSERT(out_dim == num_kept); + KALDI_ASSERT(iter_out == keep.end()); + } + if (num_kept == 0) KALDI_ERR << "FilterExample removed all indexes for '" << name << "'"; @@ -249,6 +310,22 @@ bool SelectFromExample(const NnetExample &eg, return true; } +bool RemoveZeroDerivOutputs(NnetExample *eg) { + std::vector io_new; + int32 num_outputs = 0; + for (std::vector::iterator it = eg->io.begin(); + it != eg->io.end(); ++it) { + if 
(it->name.find("output") != std::string::npos) { + if (it->deriv_weights.Dim() > 0 && it->deriv_weights.Sum() == 0) + continue; + num_outputs++; + } + io_new.push_back(*it); + } + eg->io.swap(io_new); + + return (num_outputs > 0); +} } // namespace nnet3 } // namespace kaldi @@ -276,6 +353,8 @@ int main(int argc, char *argv[]) { int32 srand_seed = 0; int32 frame_shift = 0; BaseFloat keep_proportion = 1.0; + std::string keep_outputs_str; + bool remove_zero_deriv_outputs = false; // The following config variables, if set, can be used to extract a single // frame of labels from a multi-frame example, and/or to reduce the amount @@ -285,6 +364,8 @@ int main(int argc, char *argv[]) { // you can set frame to a number to select a single frame with a particular // offset, or to 'random' to select a random single frame. std::string frame_str; + std::string weight_str; + std::string output_str; ParseOptions po(usage); po.Register("random", &random, "If true, will write frames to output " @@ -307,7 +388,21 @@ int main(int argc, char *argv[]) { "feature left-context that we output."); po.Register("right-context", &right_context, "Can be used to truncate the " "feature right-context that we output."); - + po.Register("keep-outputs", &keep_outputs_str, "Comma separated list of " + "output nodes to keep"); + po.Register("remove-zero-deriv-outputs", &remove_zero_deriv_outputs, + "Remove outputs that do not contribute to the objective " + "because of zero deriv-weights"); + po.Register("weights", &weight_str, + "Rspecifier maps the output posterior to each example" + "If provided, the supervision weight for output is scaled." + " Scaling supervision weight is the same as scaling to the derivative during training " + " in case of linear objective." + "The default is one, which means we are not applying per-example weights."); + po.Register("outputs", &output_str, + "Rspecifier maps example old output-name to new output-name in example." 
+ " If provided, the NnetIo with name 'output' in each example " + " is renamed to new output name."); po.Read(argc, argv); @@ -321,29 +416,91 @@ int main(int argc, char *argv[]) { std::string examples_rspecifier = po.GetArg(1); SequentialNnetExampleReader example_reader(examples_rspecifier); + RandomAccessTokenReader output_reader(output_str); + RandomAccessBaseFloatReader egs_weight_reader(weight_str); int32 num_outputs = po.NumArgs() - 1; std::vector example_writers(num_outputs); for (int32 i = 0; i < num_outputs; i++) example_writers[i] = new NnetExampleWriter(po.GetArg(i+2)); + std::vector keep_outputs; + if (!keep_outputs_str.empty()) { + SplitStringToVector(keep_outputs_str, ",:", true, &keep_outputs); + std::sort(keep_outputs.begin(), keep_outputs.end()); + } - int64 num_read = 0, num_written = 0; + int64 num_read = 0, num_written = 0, num_err = 0; for (; !example_reader.Done(); example_reader.Next(), num_read++) { // count is normally 1; could be 0, or possibly >1. int32 count = GetCount(keep_proportion); std::string key = example_reader.Key(); - const NnetExample &eg = example_reader.Value(); + KALDI_VLOG(2) << "Copying eg " << key; + NnetExample eg(example_reader.Value()); + + if (!keep_outputs_str.empty()) { + if (!KeepOutputs(keep_outputs, &eg)) continue; + } + for (int32 c = 0; c < count; c++) { int32 index = (random ? 
Rand() : num_written) % num_outputs; if (frame_str == "" && left_context == -1 && right_context == -1 && frame_shift == 0) { + if (remove_zero_deriv_outputs) + if (!RemoveZeroDerivOutputs(&eg)) continue; + if (!weight_str.empty()) { + if (!egs_weight_reader.HasKey(key)) { + KALDI_WARN << "No weight for example key " << key; + num_err++; + continue; + } + BaseFloat weight = egs_weight_reader.Value(key); + for (int32 i = 0; i < eg.io.size(); i++) + if (eg.io[i].name.find("output") != std::string::npos) + eg.io[i].features.Scale(weight); + } + if (!output_str.empty()) { + if (!output_reader.HasKey(key)) { + KALDI_WARN << "No new output-name for example key " << key; + num_err++; + continue; + } + std::string new_output_name = output_reader.Value(key); + // rename output io name to $new_output_name. + std::string rename_io_names = "output/" + new_output_name; + RenameIoNames(rename_io_names, &eg); + } example_writers[index]->Write(key, eg); num_written++; } else { // the --frame option or context options were set. NnetExample eg_modified; if (SelectFromExample(eg, frame_str, left_context, right_context, frame_shift, &eg_modified)) { + if (remove_zero_deriv_outputs) + if (!RemoveZeroDerivOutputs(&eg_modified)) continue; + if (!weight_str.empty()) { + // scale the supervision weight for egs + if (!egs_weight_reader.HasKey(key)) { + KALDI_WARN << "No weight for example key " << key; + num_err++; + continue; + } + int32 weight = egs_weight_reader.Value(key); + for (int32 i = 0; i < eg_modified.io.size(); i++) + if (eg_modified.io[i].name.find("output") != std::string::npos) + eg_modified.io[i].features.Scale(weight); + } + if (!output_str.empty()) { + if (!output_reader.HasKey(key)) { + KALDI_WARN << "No new output-name for example key " << key; + num_err++; + continue; + } + std::string new_output_name = output_reader.Value(key); + // rename output io name to $new_output_name. 
+ std::string rename_io_names = "output/" + new_output_name; + RenameIoNames(rename_io_names, &eg_modified); + } // this branch of the if statement will almost always be taken (should only // not be taken for shorter-than-normal egs from the end of a file. example_writers[index]->Write(key, eg_modified); diff --git a/src/nnet3bin/nnet3-get-egs-dense-targets.cc b/src/nnet3bin/nnet3-get-egs-dense-targets.cc index 54d607466b5..0e387e19fcb 100644 --- a/src/nnet3bin/nnet3-get-egs-dense-targets.cc +++ b/src/nnet3bin/nnet3-get-egs-dense-targets.cc @@ -31,12 +31,15 @@ namespace kaldi { namespace nnet3 { -static void ProcessFile(const MatrixBase &feats, +static bool ProcessFile(const MatrixBase &feats, const MatrixBase *ivector_feats, int32 ivector_period, + const VectorBase *deriv_weights, const MatrixBase &targets, const std::string &utt_id, bool compress, + int32 input_compress_format, + int32 feats_compress_format, int32 num_targets, UtteranceSplitter *utt_splitter, NnetExampleWriter *example_writer) { @@ -44,7 +47,7 @@ static void ProcessFile(const MatrixBase &feats, if (!utt_splitter->LengthsMatch(utt_id, num_input_frames, targets.NumRows())) { if (targets.NumRows() == 0) - return; + return false; // normally we wouldn't process such an utterance but there may be // situations when a small disagreement is acceptable. KALDI_WARN << " .. processing this utterance anyway."; @@ -59,7 +62,7 @@ static void ProcessFile(const MatrixBase &feats, KALDI_WARN << "Not producing egs for utterance " << utt_id << " because it is too short: " << num_input_frames << " frames."; - return; + return false; } // 'frame_subsampling_factor' is not used in any recipes at the time of @@ -93,6 +96,9 @@ static void ProcessFile(const MatrixBase &feats, // call the regular input "input". eg.io.push_back(NnetIo("input", -chunk.left_context, input_frames)); + if (compress) + eg.io.back().Compress(input_compress_format); + if (ivector_feats != NULL) { // if applicable, add the iVector feature. 
// choose iVector from a random frame in the chunk @@ -131,9 +137,23 @@ static void ProcessFile(const MatrixBase &feats, this_target_dest.CopyFromVec(this_target_src); } - // push this created targets matrix into the eg - eg.io.push_back(NnetIo("output", 0, targets_part)); + if (!deriv_weights) { + // push this created targets matrix into the eg + eg.io.push_back(NnetIo("output", 0, targets_part)); + } else { + Vector this_deriv_weights(num_frames_subsampled); + for (int32 i = 0; i < num_frames_subsampled; i++) { + int32 t = i + start_frame_subsampled; + if (t >= targets.NumRows()) + t = targets.NumRows() - 1; + this_deriv_weights(i) = (*deriv_weights)(t); + } + eg.io.push_back(NnetIo("output", this_deriv_weights, 0, targets_part)); + } + if (compress) + eg.Compress(feats_compress_format); + if (compress) eg.Compress(); @@ -144,9 +164,9 @@ static void ProcessFile(const MatrixBase &feats, example_writer->Write(key, eg); } -} - + return true; +} } // namespace nnet2 } // namespace kaldi @@ -176,16 +196,20 @@ int main(int argc, char *argv[]) { bool compress = true; + int32 input_compress_format = 0, feats_compress_format = 0; int32 num_targets = -1, length_tolerance = 100, online_ivector_period = 1; ExampleGenerationConfig eg_config; // controls num-frames, // left/right-context, etc. - - std::string online_ivector_rspecifier; + std::string online_ivector_rspecifier, deriv_weights_rspecifier; ParseOptions po(usage); eg_config.Register(&po); po.Register("compress", &compress, "If true, write egs in " "compressed format."); + po.Register("compress-format", &feats_compress_format, "Format for " + "compressing all feats in general"); + po.Register("input-compress-format", &input_compress_format, "Format for " + "compressing input feats e.g. 
Use 2 for compressing wave"); po.Register("num-targets", &num_targets, "Output dimension in egs, " "only used to check targets have correct dim if supplied."); po.Register("ivectors", &online_ivector_rspecifier, "Alias for " @@ -197,6 +221,11 @@ int main(int argc, char *argv[]) { "--online-ivectors option"); po.Register("length-tolerance", &length_tolerance, "Tolerance for " "difference in num-frames between feat and ivector matrices"); + po.Register("deriv-weights-rspecifier", &deriv_weights_rspecifier, + "Per-frame weights (only binary - 0 or 1) that specifies " + "whether a frame's gradient must be backpropagated or not. " + "Not specifying this is equivalent to specifying a vector of " + "all 1s."); po.Read(argc, argv); @@ -217,7 +246,8 @@ int main(int argc, char *argv[]) { RandomAccessBaseFloatMatrixReader matrix_reader(matrix_rspecifier); NnetExampleWriter example_writer(examples_wspecifier); RandomAccessBaseFloatMatrixReader online_ivector_reader(online_ivector_rspecifier); - + RandomAccessBaseFloatVectorReader deriv_weights_reader(deriv_weights_rspecifier); + int32 num_err = 0; for (; !feat_reader.Done(); feat_reader.Next()) { @@ -228,10 +258,10 @@ int main(int argc, char *argv[]) { num_err++; } else { const Matrix &target_matrix = matrix_reader.Value(key); - if (target_matrix.NumRows() != feats.NumRows()) { - KALDI_WARN << "Target matrix has wrong size " - << target_matrix.NumRows() - << " versus " << feats.NumRows(); + if ((target_matrix.NumRows() - feats.NumRows()) > length_tolerance) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and target matrix " << target_matrix.NumRows() + << "exceeds tolerance " << length_tolerance; num_err++; continue; } @@ -258,9 +288,34 @@ int main(int argc, char *argv[]) { continue; } - ProcessFile(feats, online_ivector_feats, online_ivector_period, - target_matrix, key, compress, num_targets, - &utt_splitter, &example_writer); + const Vector *deriv_weights = NULL; + if 
(!deriv_weights_rspecifier.empty()) { + if (!deriv_weights_reader.HasKey(key)) { + KALDI_WARN << "No deriv weights for utterance " << key; + num_err++; + continue; + } else { + // this address will be valid until we call HasKey() or Value() + // again. + deriv_weights = &(deriv_weights_reader.Value(key)); + } + } + + if (deriv_weights && + (abs(feats.NumRows() - deriv_weights->Dim()) > length_tolerance + || deriv_weights->Dim() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and deriv weights " << deriv_weights->Dim() + << " exceeds tolerance " << length_tolerance; + num_err++; + continue; + } + + if (!ProcessFile(feats, online_ivector_feats, online_ivector_period, + deriv_weights, target_matrix, key, compress, + input_compress_format, feats_compress_format, num_targets, + &utt_splitter, &example_writer)) + num_err++; } } if (num_err > 0) diff --git a/src/nnet3bin/nnet3-get-egs-multiple-targets.cc b/src/nnet3bin/nnet3-get-egs-multiple-targets.cc new file mode 100644 index 00000000000..63ebce5ab0e --- /dev/null +++ b/src/nnet3bin/nnet3-get-egs-multiple-targets.cc @@ -0,0 +1,529 @@ +// nnet3bin/nnet3-get-egs-multiple-targets.cc + +// Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) +// 2014-2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "hmm/transition-model.h" +#include "hmm/posterior.h" +#include "nnet3/nnet-example.h" +#include "nnet3/nnet-example-utils.h" + +namespace kaldi { +namespace nnet3 { + +bool ToBool(std::string str) { + std::transform(str.begin(), str.end(), str.begin(), ::tolower); + + if ((str.compare("true") == 0) || (str.compare("t") == 0) + || (str.compare("1") == 0)) + return true; + if ((str.compare("false") == 0) || (str.compare("f") == 0) + || (str.compare("0") == 0)) + return false; + KALDI_ERR << "Invalid format for boolean argument [expected true or false]: " + << str; + return false; // never reached +} + +static void ProcessFile( + const MatrixBase &feats, + const MatrixBase *ivector_feats, + const std::vector &output_names, + const std::vector &output_dims, + const std::vector* > &dense_target_matrices, + const std::vector &posteriors, + const std::vector* > &deriv_weights, + const std::string &utt_id, + bool compress_input, + int32 input_compress_format, + const std::vector &compress_targets, + const std::vector &targets_compress_formats, + int32 left_context, + int32 right_context, + int32 frames_per_eg, + std::vector *num_frames_written, + std::vector *num_egs_written, + NnetExampleWriter *example_writer) { + + KALDI_ASSERT(output_names.size() > 0); + + for (int32 t = 0; t < feats.NumRows(); t += frames_per_eg) { + + int32 tot_frames = left_context + frames_per_eg + right_context; + + Matrix input_frames(tot_frames, feats.NumCols(), kUndefined); + + // Set up "input_frames". + for (int32 j = -left_context; j < frames_per_eg + right_context; j++) { + int32 t2 = j + t; + if (t2 < 0) t2 = 0; + if (t2 >= feats.NumRows()) t2 = feats.NumRows() - 1; + SubVector src(feats, t2), + dest(input_frames, j + left_context); + dest.CopyFromVec(src); + } + + NnetExample eg; + + // call the regular input "input". 
+ eg.io.push_back(NnetIo("input", - left_context, + input_frames)); + + if (compress_input) + eg.io.back().Compress(input_compress_format); + + // if applicable, add the iVector feature. + if (ivector_feats) { + int32 actual_frames_per_eg = std::min(frames_per_eg, + feats.NumRows() - t); + // try to get closest frame to middle of window to get + // a representative iVector. + int32 closest_frame = t + (actual_frames_per_eg / 2); + KALDI_ASSERT(ivector_feats->NumRows() > 0); + if (closest_frame >= ivector_feats->NumRows()) + closest_frame = ivector_feats->NumRows() - 1; + Matrix ivector(1, ivector_feats->NumCols()); + ivector.Row(0).CopyFromVec(ivector_feats->Row(closest_frame)); + eg.io.push_back(NnetIo("ivector", 0, ivector)); + } + + int32 num_outputs_added = 0; + + for (int32 n = 0; n < output_names.size(); n++) { + Vector this_deriv_weights(0); + if (deriv_weights[n]) { + // actual_frames_per_eg is the number of frames with actual targets. + // At the end of the file, we pad with the last frame repeated + // so that all examples have the same structure (prevents the need + // for recompilations). + int32 actual_frames_per_eg = std::min( + std::min(frames_per_eg, feats.NumRows() - t), + deriv_weights[n]->Dim() - t); + + this_deriv_weights.Resize(frames_per_eg); + int32 frames_to_copy = std::min(t + actual_frames_per_eg, + deriv_weights[n]->Dim()) - t; + this_deriv_weights.Range(0, frames_to_copy).CopyFromVec( + deriv_weights[n]->Range(t, frames_to_copy)); + } + + if (dense_target_matrices[n]) { + const MatrixBase &targets = *dense_target_matrices[n]; + Matrix targets_dest(frames_per_eg, targets.NumCols()); + + // actual_frames_per_eg is the number of frames with actual targets. + // At the end of the file, we pad with the last frame repeated + // so that all examples have the same structure (prevents the need + // for recompilations). 
+ int32 actual_frames_per_eg = std::min( + std::min(frames_per_eg, feats.NumRows() - t), + targets.NumRows() - t); + + for (int32 i = 0; i < actual_frames_per_eg; i++) { + // Copy the i^th row of the target matrix from the (t+i)^th row of the + // input targets matrix + SubVector this_target_dest(targets_dest, i); + SubVector this_target_src(targets, t+i); + this_target_dest.CopyFromVec(this_target_src); + } + + // Copy the last frame's target to the padded frames + for (int32 i = actual_frames_per_eg; i < frames_per_eg; i++) { + // Copy the i^th row of the target matrix from the last row of the + // input targets matrix + KALDI_ASSERT(t + actual_frames_per_eg - 1 == targets.NumRows() - 1); + SubVector this_target_dest(targets_dest, i); + SubVector this_target_src(targets, + t + actual_frames_per_eg - 1); + this_target_dest.CopyFromVec(this_target_src); + } + + if (deriv_weights[n]) { + eg.io.push_back(NnetIo(output_names[n], this_deriv_weights, + 0, targets_dest)); + } else { + eg.io.push_back(NnetIo(output_names[n], 0, targets_dest)); + } + } else if (posteriors[n]) { + const Posterior &pdf_post = *(posteriors[n]); + + // actual_frames_per_eg is the number of frames with actual targets. + // At the end of the file, we pad with the last frame repeated + // so that all examples have the same structure (prevents the need + // for recompilations). + int32 actual_frames_per_eg = std::min( + std::min(frames_per_eg, feats.NumRows() - t), + static_cast(pdf_post.size()) - t); + + Posterior labels(frames_per_eg); + for (int32 i = 0; i < actual_frames_per_eg; i++) + labels[i] = pdf_post[t + i]; + // remaining posteriors for frames are empty. 
+ + if (deriv_weights[n]) { + eg.io.push_back(NnetIo(output_names[n], this_deriv_weights, + output_dims[n], 0, labels)); + } else { + eg.io.push_back(NnetIo(output_names[n], output_dims[n], 0, labels)); + } + } else + continue; + if (compress_targets[n]) + eg.io.back().Compress(targets_compress_formats[n]); + + num_outputs_added++; + // Actually actual_frames_per_eg, but that depends on the different + // output. For simplification, frames_per_eg is used. + (*num_frames_written)[n] += frames_per_eg; + (*num_egs_written)[n] += 1; + } + + if (num_outputs_added != output_names.size()) continue; + + std::ostringstream os; + os << utt_id << "-" << t; + + std::string key = os.str(); // key is - + + KALDI_ASSERT(NumOutputs(eg) == num_outputs_added); + + example_writer->Write(key, eg); + } +} + + +} // namespace nnet2 +} // namespace kaldi + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet3; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Get frame-by-frame examples of data for nnet3 neural network training.\n" + "This program is similar to nnet3-get-egs, but the targets here are " + "dense matrices instead of posteriors (sparse matrices).\n" + "This is useful when you want the targets to be continuous real-valued " + "with the neural network possibly trained with a quadratic objective\n" + "\n" + "Usage: nnet3-get-egs-multiple-targets [options] " + " ::[:] " + "[ :: ... 
] \n" + "\n" + "Here is any random string for output node name, \n" + " is the rspecifier for either dense targets in matrix format or sparse targets in posterior format,\n" + "and is the target dimension of output node for sparse targets or -1 for dense targets\n" + "\n" + "An example [where $feats expands to the actual features]:\n" + "nnet-get-egs-multiple-targets --left-context=12 \\\n" + "--right-context=9 --num-frames=8 \"$feats\" \\\n" + "output-snr:\"ark:copy-matrix ark:exp/snrs/snr.1.ark ark:- |\":-1 \n" + " ark:- \n"; + + + bool compress_input = true; + int32 input_compress_format = 0; + int32 left_context = 0, right_context = 0, + num_frames = 1, length_tolerance = 2; + + std::string ivector_rspecifier, + targets_compress_formats_str, + compress_targets_str; + std::string output_dims_str; + std::string output_names_str; + + ParseOptions po(usage); + po.Register("compress-input", &compress_input, "If true, write egs in " + "compressed format."); + po.Register("input-compress-format", &input_compress_format, "Format for " + "compressing input feats e.g. 
Use 2 for compressing wave"); + po.Register("compress-targets", &compress_targets_str, "CSL of whether " + "targets must be compressed for each of the outputs"); + po.Register("targets-compress-formats", &targets_compress_formats_str, + "Format for compressing all feats in general"); + po.Register("left-context", &left_context, "Number of frames of left " + "context the neural net requires."); + po.Register("right-context", &right_context, "Number of frames of right " + "context the neural net requires."); + po.Register("num-frames", &num_frames, "Number of frames with labels " + "that each example contains."); + po.Register("ivectors", &ivector_rspecifier, "Rspecifier of ivector " + "features, as matrix."); + po.Register("length-tolerance", &length_tolerance, "Tolerance for " + "difference in num-frames between feat and ivector matrices"); + po.Register("output-dims", &output_dims_str, "CSL of output node dims"); + po.Register("output-names", &output_names_str, "CSL of output node names"); + + po.Read(argc, argv); + + if (po.NumArgs() < 3) { + po.PrintUsage(); + exit(1); + } + + std::string feature_rspecifier = po.GetArg(1), + examples_wspecifier = po.GetArg(po.NumArgs()); + + // Read in all the training files. 
+ SequentialBaseFloatMatrixReader feat_reader(feature_rspecifier); + RandomAccessBaseFloatMatrixReader ivector_reader(ivector_rspecifier); + NnetExampleWriter example_writer(examples_wspecifier); + + int32 num_outputs = (po.NumArgs() - 2) / 2; + KALDI_ASSERT(num_outputs > 0); + + std::vector deriv_weights_readers( + num_outputs, static_cast(NULL)); + std::vector dense_targets_readers( + num_outputs, static_cast(NULL)); + std::vector sparse_targets_readers( + num_outputs, static_cast(NULL)); + + std::vector compress_targets(1, true); + std::vector compress_targets_vector; + + if (!compress_targets_str.empty()) { + SplitStringToVector(compress_targets_str, ":,", + true, &compress_targets_vector); + } + + if (compress_targets_vector.size() == 1 && num_outputs != 1) { + KALDI_WARN << "compress-targets is of size 1. " + << "Extending it to size num-outputs=" << num_outputs; + compress_targets[0] = ToBool(compress_targets_vector[0]); + compress_targets.resize(num_outputs, ToBool(compress_targets_vector[0])); + } else { + if (compress_targets_vector.size() != num_outputs) { + KALDI_ERR << "Mismatch in length of compress-targets and num-outputs; " + << compress_targets_vector.size() << " vs " << num_outputs; + } + for (int32 n = 0; n < num_outputs; n++) { + compress_targets[n] = ToBool(compress_targets_vector[n]); + } + } + + std::vector targets_compress_formats(1, 1); + if (!targets_compress_formats_str.empty()) { + SplitStringToIntegers(targets_compress_formats_str, ":,", + true, &targets_compress_formats); + } + + if (targets_compress_formats.size() == 1 && num_outputs != 1) { + KALDI_WARN << "targets-compress-formats is of size 1. 
" + << "Extending it to size num-outputs=" << num_outputs; + targets_compress_formats.resize(num_outputs, targets_compress_formats[0]); + } + + if (targets_compress_formats.size() != num_outputs) { + KALDI_ERR << "Mismatch in length of targets-compress-formats " + << " and num-outputs; " + << targets_compress_formats.size() << " vs " << num_outputs; + } + + std::vector output_dims(num_outputs); + SplitStringToIntegers(output_dims_str, ":,", + true, &output_dims); + + std::vector output_names(num_outputs); + SplitStringToVector(output_names_str, ":,", true, &output_names); + + std::vector targets_rspecifiers(num_outputs); + std::vector deriv_weights_rspecifiers(num_outputs); + + for (int32 n = 0; n < num_outputs; n++) { + const std::string &targets_rspecifier = po.GetArg(2*n + 2); + const std::string &deriv_weights_rspecifier = po.GetArg(2*n + 3); + + targets_rspecifiers[n] = targets_rspecifier; + deriv_weights_rspecifiers[n] = deriv_weights_rspecifier; + + if (output_dims[n] >= 0) { + sparse_targets_readers[n] = new RandomAccessPosteriorReader( + targets_rspecifier); + } else { + dense_targets_readers[n] = new RandomAccessBaseFloatMatrixReader( + targets_rspecifier); + } + + if (!deriv_weights_rspecifier.empty()) + deriv_weights_readers[n] = new RandomAccessBaseFloatVectorReader( + deriv_weights_rspecifier); + + KALDI_LOG << "output-name=" << output_names[n] + << " target-dim=" << output_dims[n] + << " targets-rspecifier=\"" << targets_rspecifiers[n] << "\"" + << " deriv-weights-rspecifier=\"" + << deriv_weights_rspecifiers[n] << "\"" + << " compress-target=" + << (compress_targets[n] ? 
"true" : "false") + << " target-compress-format=" << targets_compress_formats[n]; + } + + int32 num_done = 0, num_err = 0; + + std::vector num_frames_written(num_outputs, 0); + std::vector num_egs_written(num_outputs, 0); + + for (; !feat_reader.Done(); feat_reader.Next()) { + std::string key = feat_reader.Key(); + const Matrix &feats = feat_reader.Value(); + + const Matrix *ivector_feats = NULL; + if (!ivector_rspecifier.empty()) { + if (!ivector_reader.HasKey(key)) { + KALDI_WARN << "No iVectors for utterance " << key; + continue; + } else { + // this address will be valid until we call HasKey() or Value() + // again. + ivector_feats = &(ivector_reader.Value(key)); + } + } + + if (ivector_feats && + (abs(feats.NumRows() - ivector_feats->NumRows()) > length_tolerance + || ivector_feats->NumRows() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and iVectors " << ivector_feats->NumRows() + << "exceeds tolerance " << length_tolerance; + num_err++; + continue; + } + + std::vector* > dense_targets( + num_outputs, static_cast* >(NULL)); + std::vector sparse_targets( + num_outputs, static_cast(NULL)); + std::vector* > deriv_weights( + num_outputs, static_cast* >(NULL)); + + int32 num_outputs_found = 0; + for (int32 n = 0; n < num_outputs; n++) { + if (dense_targets_readers[n]) { + if (!dense_targets_readers[n]->HasKey(key)) { + KALDI_WARN << "No dense targets matrix for key " << key << " in " + << "rspecifier " << targets_rspecifiers[n] + << " for output " << output_names[n]; + break; + } + const MatrixBase *target_matrix = &( + dense_targets_readers[n]->Value(key)); + + if ((target_matrix->NumRows() - feats.NumRows()) > length_tolerance) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and target matrix " << target_matrix->NumRows() + << "exceeds tolerance " << length_tolerance; + break; + } + + dense_targets[n] = target_matrix; + } else { + if (!sparse_targets_readers[n]->HasKey(key)) { + KALDI_WARN << 
"No sparse target matrix for key " << key << " in " + << "rspecifier " << targets_rspecifiers[n] + << " for output " << output_names[n]; + break; + } + const Posterior *posterior = &(sparse_targets_readers[n]->Value(key)); + + if (abs(static_cast(posterior->size()) - feats.NumRows()) + > length_tolerance + || posterior->size() < feats.NumRows()) { + KALDI_WARN << "Posterior has wrong size " << posterior->size() + << " versus " << feats.NumRows(); + break; + } + + sparse_targets[n] = posterior; + } + + if (deriv_weights_readers[n]) { + if (!deriv_weights_readers[n]->HasKey(key)) { + KALDI_WARN << "No deriv weights for key " << key << " in " + << "rspecifier " << deriv_weights_rspecifiers[n] + << " for output " << output_names[n]; + break; + } else { + // this address will be valid until we call HasKey() or Value() + // again. + deriv_weights[n] = &(deriv_weights_readers[n]->Value(key)); + } + } + + if (deriv_weights[n] + && (abs(feats.NumRows() - deriv_weights[n]->Dim()) + > length_tolerance + || deriv_weights[n]->Dim() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and deriv weights " << deriv_weights[n]->Dim() + << " exceeds tolerance " << length_tolerance; + break; + } + + num_outputs_found++; + } + + if (num_outputs_found != num_outputs) { + KALDI_WARN << "Not all outputs found for key " << key; + num_err++; + continue; + } + + ProcessFile(feats, ivector_feats, output_names, output_dims, + dense_targets, sparse_targets, + deriv_weights, key, + compress_input, input_compress_format, + compress_targets, targets_compress_formats, + left_context, right_context, num_frames, + &num_frames_written, &num_egs_written, + &example_writer); + num_done++; + } + + int64 max_num_egs_written = 0, max_num_frames_written = 0; + for (int32 n = 0; n < num_outputs; n++) { + delete dense_targets_readers[n]; + delete sparse_targets_readers[n]; + delete deriv_weights_readers[n]; + if (num_egs_written[n] == 0) return false; + if (num_egs_written[n] 
> max_num_egs_written) { + max_num_egs_written = num_egs_written[n]; + max_num_frames_written = num_frames_written[n]; + } + } + + KALDI_LOG << "Finished generating examples, " + << "successfully processed " << num_done + << " feature files, wrote at most " << max_num_egs_written + << " examples, " + << " with at most " << max_num_frames_written << " egs in total; " + << num_err << " files had errors."; + + return (num_err > num_done ? 1 : 0); + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} diff --git a/src/nnet3bin/nnet3-get-egs.cc b/src/nnet3bin/nnet3-get-egs.cc index 03623f02a07..490124449ac 100644 --- a/src/nnet3bin/nnet3-get-egs.cc +++ b/src/nnet3bin/nnet3-get-egs.cc @@ -33,9 +33,12 @@ namespace nnet3 { static bool ProcessFile(const MatrixBase &feats, const MatrixBase *ivector_feats, int32 ivector_period, + const VectorBase *deriv_weights, const Posterior &pdf_post, const std::string &utt_id, bool compress, + int32 input_compress_format, + int32 feats_compress_format, int32 num_pdfs, UtteranceSplitter *utt_splitter, NnetExampleWriter *example_writer) { @@ -84,6 +87,9 @@ static bool ProcessFile(const MatrixBase &feats, // call the regular input "input". eg.io.push_back(NnetIo("input", -chunk.left_context, input_frames)); + + if (compress) + eg.io.back().Compress(input_compress_format); if (ivector_feats != NULL) { // if applicable, add the iVector feature. 
@@ -126,10 +132,23 @@ static bool ProcessFile(const MatrixBase &feats, iter->second *= chunk.output_weights[i]; } - eg.io.push_back(NnetIo("output", num_pdfs, 0, labels)); + if (!deriv_weights) { + eg.io.push_back(NnetIo("output", num_pdfs, 0, labels)); + } else { + KALDI_ASSERT(start_frame_subsampled + num_frames_subsampled - 1 < + deriv_weights->Dim()); + Vector this_deriv_weights(num_frames_subsampled); + for (int32 i = 0; i < num_frames_subsampled; i++) { + int32 t = i + start_frame_subsampled; + this_deriv_weights(i) = (*deriv_weights)(t); + } + // Ignore frames that have frame weights 0 + if (this_deriv_weights.Sum() == 0) continue; + eg.io.push_back(NnetIo("output", this_deriv_weights, num_pdfs, 0, labels)); + } if (compress) - eg.Compress(); + eg.Compress(feats_compress_format); std::ostringstream os; os << utt_id << "-" << chunk.first_frame; @@ -171,18 +190,24 @@ int main(int argc, char *argv[]) { bool compress = true; + int32 input_compress_format = 0, feats_compress_format = 0; int32 num_pdfs = -1, length_tolerance = 100, online_ivector_period = 1; - + ExampleGenerationConfig eg_config; // controls num-frames, // left/right-context, etc. std::string online_ivector_rspecifier; + std::string deriv_weights_rspecifier; ParseOptions po(usage); po.Register("compress", &compress, "If true, write egs in " - "compressed format (recommended)."); + "compressed format."); + po.Register("compress-format", &feats_compress_format, "Format for " + "compressing all feats in general"); + po.Register("input-compress-format", &input_compress_format, "Format for " + "compressing input feats e.g. 
Use 2 for compressing wave"); po.Register("num-pdfs", &num_pdfs, "Number of pdfs in the acoustic " "model"); po.Register("ivectors", &online_ivector_rspecifier, "Alias for " @@ -194,6 +219,12 @@ int main(int argc, char *argv[]) { "--online-ivectors option"); po.Register("length-tolerance", &length_tolerance, "Tolerance for " "difference in num-frames between feat and ivector matrices"); + po.Register("deriv-weights-rspecifier", &deriv_weights_rspecifier, + "Per-frame weights (only binary - 0 or 1) that specifies " + "whether a frame's gradient must be backpropagated or not. " + "Not specifying this is equivalent to specifying a vector of " + "all 1s."); + eg_config.Register(&po); po.Read(argc, argv); @@ -219,9 +250,9 @@ int main(int argc, char *argv[]) { NnetExampleWriter example_writer(examples_wspecifier); RandomAccessBaseFloatMatrixReader online_ivector_reader( online_ivector_rspecifier); - + RandomAccessBaseFloatVectorReader deriv_weights_reader(deriv_weights_rspecifier); + int32 num_err = 0; - for (; !feat_reader.Done(); feat_reader.Next()) { std::string key = feat_reader.Key(); const Matrix &feats = feat_reader.Value(); @@ -229,8 +260,9 @@ int main(int argc, char *argv[]) { KALDI_WARN << "No pdf-level posterior for key " << key; num_err++; } else { - const Posterior &pdf_post = pdf_post_reader.Value(key); - if (pdf_post.size() != feats.NumRows()) { + Posterior pdf_post = pdf_post_reader.Value(key); + if (abs(static_cast(pdf_post.size()) - feats.NumRows()) > length_tolerance + || pdf_post.size() < feats.NumRows()) { KALDI_WARN << "Posterior has wrong size " << pdf_post.size() << " versus " << feats.NumRows(); num_err++; @@ -260,8 +292,32 @@ int main(int argc, char *argv[]) { continue; } + const Vector *deriv_weights = NULL; + if (!deriv_weights_rspecifier.empty()) { + if (!deriv_weights_reader.HasKey(key)) { + KALDI_WARN << "No deriv weights for utterance " << key; + num_err++; + continue; + } else { + // this address will be valid until we call HasKey() or 
Value() + // again. + deriv_weights = &(deriv_weights_reader.Value(key)); + } + } + + if (deriv_weights && + (abs(feats.NumRows() - deriv_weights->Dim()) > length_tolerance + || deriv_weights->Dim() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and deriv weights " << deriv_weights->Dim() + << " exceeds tolerance " << length_tolerance; + num_err++; + continue; + } + if (!ProcessFile(feats, online_ivector_feats, online_ivector_period, - pdf_post, key, compress, num_pdfs, + deriv_weights, pdf_post, key, compress, + input_compress_format, feats_compress_format, num_pdfs, &utt_splitter, &example_writer)) num_err++; } diff --git a/src/nnet3bin/nnet3-info.cc b/src/nnet3bin/nnet3-info.cc index 6b7fb2c629e..c722c3b0a85 100644 --- a/src/nnet3bin/nnet3-info.cc +++ b/src/nnet3bin/nnet3-info.cc @@ -20,6 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "nnet3/nnet-nnet.h" +#include "nnet3/nnet-utils.h" int main(int argc, char *argv[]) { try { @@ -36,7 +37,11 @@ int main(int argc, char *argv[]) { " nnet3-info 0.raw\n" "See also: nnet3-am-info\n"; + bool print_detailed_info = false; + ParseOptions po(usage); + po.Register("print-detailed-info", &print_detailed_info, + "Print more detailed info"); po.Read(argc, argv); @@ -50,7 +55,10 @@ int main(int argc, char *argv[]) { Nnet nnet; ReadKaldiObject(raw_nnet_rxfilename, &nnet); - std::cout << nnet.Info(); + if (print_detailed_info) + std::cout << NnetInfo(nnet); + else + std::cout << nnet.Info(); return 0; } catch(const std::exception &e) { diff --git a/src/nnet3bin/nnet3-latgen-faster.cc b/src/nnet3bin/nnet3-latgen-faster.cc index 6bd5cd7c453..f1c18534c2c 100644 --- a/src/nnet3bin/nnet3-latgen-faster.cc +++ b/src/nnet3bin/nnet3-latgen-faster.cc @@ -65,6 +65,8 @@ int main(int argc, char *argv[]) { po.Register("ivectors", &ivector_rspecifier, "Rspecifier for " "iVectors as vectors (i.e. 
not estimated online); per utterance " "by default, or per speaker if you provide the --utt2spk option."); + po.Register("utt2spk", &utt2spk_rspecifier, "Rspecifier for " + "utt2spk option used to get ivectors per speaker"); po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for " "iVectors estimated online, as matrices. If you supply this," " you must set the --online-ivector-period option."); diff --git a/src/nnet3bin/nnet3-merge-egs.cc b/src/nnet3bin/nnet3-merge-egs.cc index db4508a2835..0684dfc1bae 100644 --- a/src/nnet3bin/nnet3-merge-egs.cc +++ b/src/nnet3bin/nnet3-merge-egs.cc @@ -26,11 +26,13 @@ namespace kaldi { namespace nnet3 { -// returns the number of indexes/frames in the NnetIo named "output" in the eg, -// or crashes if it is not there. +// returns the number of indexes/frames in the output NnetIo +// assumes the output name starts with "output" and only looks at the +// first such output to get the indexes size. +// crashes if it there is no such output int32 NumOutputIndexes(const NnetExample &eg) { for (size_t i = 0; i < eg.io.size(); i++) - if (eg.io[i].name == "output") + if (eg.io[i].name.find("output") != std::string::npos) return eg.io[i].indexes.size(); KALDI_ERR << "No output named 'output' in the eg."; return 0; // Suppress compiler warning. 
diff --git a/src/nnet3bin/nnet3-show-progress.cc b/src/nnet3bin/nnet3-show-progress.cc index 10898dc0ca6..785d3d0aa88 100644 --- a/src/nnet3bin/nnet3-show-progress.cc +++ b/src/nnet3bin/nnet3-show-progress.cc @@ -107,17 +107,39 @@ int main(int argc, char *argv[]) { eg_end = examples.end(); for (; eg_iter != eg_end; ++eg_iter) prob_computer.Compute(*eg_iter); - const SimpleObjectiveInfo *objf_info = prob_computer.GetObjective("output"); - double objf_per_frame = objf_info->tot_objective / objf_info->tot_weight; + + double tot_weight = 0.0; + + { + const unordered_map &objf_info = prob_computer.GetAllObjectiveInfo(); + + unordered_map::const_iterator objf_it = objf_info.begin(), + objf_end = objf_info.end(); + + + for (; objf_it != objf_end; ++objf_it) { + double objf_per_frame = objf_it->second.tot_objective / objf_it->second.tot_weight; + + if (objf_it->first == "output") { + KALDI_LOG << "At position " << middle + << ", objf per frame is " << objf_per_frame; + } else { + KALDI_LOG << "At position " << middle + << ", objf per frame for '" << objf_it->first + << "' is " << objf_per_frame; + } + + tot_weight += objf_it->second.tot_weight; + } + } + const Nnet &nnet_gradient = prob_computer.GetDeriv(); - KALDI_LOG << "At position " << middle - << ", objf per frame is " << objf_per_frame; Vector old_dotprod(num_updatable), new_dotprod(num_updatable); ComponentDotProducts(nnet_gradient, nnet1, &old_dotprod); ComponentDotProducts(nnet_gradient, nnet2, &new_dotprod); - old_dotprod.Scale(1.0 / objf_info->tot_weight); - new_dotprod.Scale(1.0 / objf_info->tot_weight); + old_dotprod.Scale(1.0 / tot_weight); + new_dotprod.Scale(1.0 / tot_weight); diff.AddVec(1.0/ num_segments, new_dotprod); diff.AddVec(-1.0 / num_segments, old_dotprod); KALDI_VLOG(1) << "By segment " << s << ", objf change is " diff --git a/src/online2bin/ivector-extract-online2.cc b/src/online2bin/ivector-extract-online2.cc index 3251d93b5dd..f597f66763b 100644 --- a/src/online2bin/ivector-extract-online2.cc 
+++ b/src/online2bin/ivector-extract-online2.cc @@ -55,6 +55,8 @@ int main(int argc, char *argv[]) { g_num_threads = 8; bool repeat = false; + int32 length_tolerance = 0; + std::string frame_weights_rspecifier; po.Register("num-threads", &g_num_threads, "Number of threads to use for computing derived variables " @@ -62,6 +64,12 @@ int main(int argc, char *argv[]) { po.Register("repeat", &repeat, "If true, output the same number of iVectors as input frames " "(including repeated data)."); + po.Register("frame-weights-rspecifier", &frame_weights_rspecifier, + "Archive of frame weights to scale stats"); + po.Register("length-tolerance", &length_tolerance, + "Tolerance on the difference in number of frames " + "for feats and weights"); + po.Read(argc, argv); if (po.NumArgs() != 3) { @@ -82,9 +90,9 @@ int main(int argc, char *argv[]) { SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier); + RandomAccessBaseFloatVectorReader frame_weights_reader(frame_weights_rspecifier); BaseFloatMatrixWriter ivector_writer(ivectors_wspecifier); - for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { std::string spk = spk2utt_reader.Key(); const std::vector &uttlist = spk2utt_reader.Value(); @@ -105,6 +113,31 @@ int main(int argc, char *argv[]) { &matrix_feature); ivector_feature.SetAdaptationState(adaptation_state); + + if (!frame_weights_rspecifier.empty()) { + if (!frame_weights_reader.HasKey(utt)) { + KALDI_WARN << "Did not find weights for utterance " << utt; + num_err++; + continue; + } + const Vector &weights = frame_weights_reader.Value(utt); + + if (std::abs(weights.Dim() - feats.NumRows()) > length_tolerance) { + num_err++; + continue; + } + + std::vector > frame_weights; + for (int32 i = 0; i < feats.NumRows(); i++) { + if (i < weights.Dim()) + frame_weights.push_back(std::make_pair(i, weights(i))); + else + frame_weights.push_back(std::make_pair(i, 0.0)); + } + + + 
ivector_feature.UpdateFrameWeights(frame_weights); + } int32 T = feats.NumRows(), n = (repeat ? 1 : ivector_config.ivector_period), diff --git a/src/segmenter/Makefile b/src/segmenter/Makefile new file mode 100644 index 00000000000..8259de32c1f --- /dev/null +++ b/src/segmenter/Makefile @@ -0,0 +1,18 @@ +all: + +include ../kaldi.mk + +TESTFILES = segmentation-io-test information-bottleneck-clusterable-test + +OBJFILES = segment.o segmentation.o segmentation-utils.o \ + segmentation-post-processor.o #\ + #information-bottleneck-clusterable.o \ + #information-bottleneck-cluster-utils.o + +LIBNAME = kaldi-segmenter + +ADDLIBS = ../tree/kaldi-tree.a ../gmm/kaldi-gmm.a \ + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../thread/kaldi-thread.a + +include ../makefiles/default_rules.mk + diff --git a/src/segmenter/information-bottleneck-cluster-utils.cc b/src/segmenter/information-bottleneck-cluster-utils.cc new file mode 100644 index 00000000000..5ed283da564 --- /dev/null +++ b/src/segmenter/information-bottleneck-cluster-utils.cc @@ -0,0 +1,209 @@ +// segmenter/information-bottleneck-cluster-utils.cc + +// Copyright 2017 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "tree/cluster-utils.h" +#include "segmenter/information-bottleneck-cluster-utils.h" + +namespace kaldi { + +typedef uint16 uint_smaller; +typedef int16 int_smaller; + +class InformationBottleneckBottomUpClusterer : public BottomUpClusterer { + public: + InformationBottleneckBottomUpClusterer( + const std::vector &points, + const InformationBottleneckClustererOptions &opts, + BaseFloat max_merge_thresh, + int32 min_clusters, + std::vector *clusters_out, + std::vector *assignments_out); + + private: + virtual void SetInitialDistances(); + virtual BaseFloat ComputeDistance(int32 i, int32 j); + virtual bool StoppingCriterion() const; + virtual void UpdateClustererStats(int32 i, int32 j); + + virtual BaseFloat MergeThreshold(int32 i, int32 j) { + if (opts_.normalize_by_count) + return max_merge_thresh_ + * ((*clusters_)[i]->Normalizer() + (*clusters_)[j]->Normalizer()); + else if (opts_.normalize_by_entropy) + return -max_merge_thresh_ * (*clusters_)[i]->ObjfPlus(*(*clusters_)[j]); + else + return max_merge_thresh_; + } + + BaseFloat NormalizedMutualInformation() const { + return ((merged_entropy_ - current_entropy_) + / (merged_entropy_ - initial_entropy_)); + } + + const InformationBottleneckClustererOptions &opts_; + + /// Running entropy of the clusters. + BaseFloat current_entropy_; + + /// Some stats computed by the constructor that will be useful for + /// adding stopping criterion. 
+ BaseFloat initial_entropy_; + BaseFloat merged_entropy_; +}; + + +InformationBottleneckBottomUpClusterer::InformationBottleneckBottomUpClusterer( + const std::vector &points, + const InformationBottleneckClustererOptions &opts, + BaseFloat max_merge_thresh, + int32 min_clusters, + std::vector *clusters_out, + std::vector *assignments_out) : + BottomUpClusterer(points, max_merge_thresh, min_clusters, + clusters_out, assignments_out), + opts_(opts), + current_entropy_(0.0), initial_entropy_(0.0), merged_entropy_(0.0) { + if (points.size() == 0) return; + + InformationBottleneckClusterable* ibc = + static_cast(points[0]->Copy()); + initial_entropy_ -= ibc->Objf(1.0, 0.0); + + for (size_t i = 1; i < points.size(); i++) { + InformationBottleneckClusterable *c = + static_cast(points[i]); + ibc->Add(*points[i]); + initial_entropy_ -= c->Objf(1.0, 0.0); + } + + merged_entropy_ = -ibc->Objf(1.0, 0.0); + current_entropy_ = initial_entropy_; +} + +void InformationBottleneckBottomUpClusterer::SetInitialDistances() { + for (int32 i = 0; i < npoints_; i++) { + for (int32 j = 0; j < i; j++) { + BaseFloat dist = ComputeDistance(i, j); + if (dist <= MergeThreshold(i, j)) { + queue_.push(std::make_pair( + dist, std::make_pair(static_cast(i), + static_cast(j)))); + } + if (j == i - 1) + KALDI_VLOG(2) << "Distance(" << i << ", " << j << ") = " << dist; + } + } +} + +BaseFloat InformationBottleneckBottomUpClusterer::ComputeDistance( + int32 i, int32 j) { + const InformationBottleneckClusterable* cluster_i + = static_cast((*clusters_)[i]); + const InformationBottleneckClusterable* cluster_j + = static_cast((*clusters_)[j]); + + BaseFloat dist = (cluster_i->Distance(*cluster_j, opts_.relevance_factor, + opts_.input_factor)); + // / (cluster_i->Normalizer() + cluster_j->Normalizer())); + Distance(i, j) = dist; // set the distance in the array. 
+ return dist; +} + +bool InformationBottleneckBottomUpClusterer::StoppingCriterion() const { + bool flag = (nclusters_ <= min_clust_ || queue_.empty() || + NormalizedMutualInformation() < opts_.stopping_threshold); + if (GetVerboseLevel() < 2 || !flag) return flag; + + if (NormalizedMutualInformation() < opts_.stopping_threshold) { + KALDI_VLOG(2) << "Stopping at " << nclusters_ << " clusters " + << "because NMI = " << NormalizedMutualInformation() + << " < stopping_threshold (" + << opts_.stopping_threshold << ")"; + } else if (nclusters_ < min_clust_) { + KALDI_VLOG(2) << "Stopping at " << nclusters_ << " clusters " + << "<= min-clusters (" << min_clust_ << ")"; + } else if (queue_.empty()) { + KALDI_VLOG(2) << "Stopping at " << nclusters_ << " clusters " + << "because queue is empty."; + } + + return flag; +} + +void InformationBottleneckBottomUpClusterer::UpdateClustererStats( + int32 i, int32 j) { + const InformationBottleneckClusterable* cluster_i + = static_cast((*clusters_)[i]); + current_entropy_ += cluster_i->Distance(*(*clusters_)[j], 1.0, 0.0); + + if (GetVerboseLevel() > 2) { + const InformationBottleneckClusterable* cluster_j + = static_cast((*clusters_)[j]); + std::vector cluster_i_points; + { + std::map::const_iterator it + = cluster_i->Counts().begin(); + for (; it != cluster_i->Counts().end(); ++it) + cluster_i_points.push_back(it->first); + } + + std::vector cluster_j_points; + { + std::map::const_iterator it + = cluster_j->Counts().begin(); + for (; it != cluster_j->Counts().end(); ++it) + cluster_j_points.push_back(it->first); + } + KALDI_VLOG(3) << "Merging clusters " + << "(" << cluster_i_points + << ", " << cluster_j_points + << ").. 
distance=" << Distance(i, j) + << ", num-clusters-after-merge= " << nclusters_ - 1 + << ", NMI= " << NormalizedMutualInformation(); + } +} + +BaseFloat IBClusterBottomUp( + const std::vector &points, + const InformationBottleneckClustererOptions &opts, + BaseFloat max_merge_thresh, + int32 min_clust, + std::vector *clusters_out, + std::vector *assignments_out) { + KALDI_ASSERT(max_merge_thresh >= 0.0 && min_clust >= 0); + KALDI_ASSERT(opts.stopping_threshold >= 0.0); + KALDI_ASSERT(opts.relevance_factor >= 0.0 && opts.input_factor >= 0.0); + + KALDI_ASSERT(!ContainsNullPointers(points)); + int32 npoints = points.size(); + // make sure fits in uint_smaller and does not hit the -1 which is reserved. + KALDI_ASSERT(sizeof(uint_smaller)==sizeof(uint32) || + npoints < static_cast(static_cast(-1))); + + KALDI_VLOG(2) << "Initializing clustering object."; + InformationBottleneckBottomUpClusterer bc( + points, opts, max_merge_thresh, min_clust, + clusters_out, assignments_out); + BaseFloat ans = bc.Cluster(); + if (clusters_out) KALDI_ASSERT(!ContainsNullPointers(*clusters_out)); + return ans; +} + +} // end namespace kaldi diff --git a/src/segmenter/information-bottleneck-cluster-utils.h b/src/segmenter/information-bottleneck-cluster-utils.h new file mode 100644 index 00000000000..82b5c285c65 --- /dev/null +++ b/src/segmenter/information-bottleneck-cluster-utils.h @@ -0,0 +1,74 @@ +// segmenter/information-bottleneck-cluster-utils.h + +// Copyright 2017 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_SEGMENTER_INFORMATION_BOTTLENECK_CLUSTER_UTILS_H_ +#define KALDI_SEGMENTER_INFORMATION_BOTTLENECK_CLUSTER_UTILS_H_ + +#include "base/kaldi-common.h" +#include "tree/cluster-utils.h" +#include "segmenter/information-bottleneck-clusterable.h" +#include "util/common-utils.h" + +namespace kaldi { + +struct InformationBottleneckClustererOptions { + BaseFloat distance_threshold; + int32 num_clusters; + BaseFloat stopping_threshold; + BaseFloat relevance_factor; + BaseFloat input_factor; + bool normalize_by_count; + bool normalize_by_entropy; + + InformationBottleneckClustererOptions() : + distance_threshold(std::numeric_limits::max()), num_clusters(1), + stopping_threshold(0.3), relevance_factor(1.0), input_factor(0.1), + normalize_by_count(false), normalize_by_entropy(false) { } + + + void Register(OptionsItf *opts) { + opts->Register("stopping-threshold", &stopping_threshold, + "Stopping merging/splitting when an objective such as " + "NMI reaches this value."); + opts->Register("relevance-factor", &relevance_factor, + "Weight factor of the entropy of relevant variables " + "in the objective function"); + opts->Register("input-factor", &input_factor, + "Weight factor of the entropy of input variables " + "in the objective function"); + opts->Register("normalize-by-count", &normalize_by_count, + "If provided, normalizes the score (distance) by " + "the count post-merge."); + opts->Register("normalize-by-entropy", &normalize_by_entropy, + "If provided, normalizes the score 
(distance) by " + "the entropy post-merge."); + } +}; + +BaseFloat IBClusterBottomUp( + const std::vector &points, + const InformationBottleneckClustererOptions &opts, + BaseFloat max_merge_thresh, + int32 min_clusters, + std::vector *clusters_out, + std::vector *assignments_out); + +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_INFORMATION_BOTTLENECK_CLUSTER_UTILS_H_ diff --git a/src/segmenter/information-bottleneck-clusterable-test.cc b/src/segmenter/information-bottleneck-clusterable-test.cc new file mode 100644 index 00000000000..ee0358c8f05 --- /dev/null +++ b/src/segmenter/information-bottleneck-clusterable-test.cc @@ -0,0 +1,94 @@ + +#include "base/kaldi-common.h" +#include "segmenter/information-bottleneck-clusterable.h" + +namespace kaldi { + +static void TestClusterable() { + { + Vector a_vec(3); + a_vec(0) = 0.5; + a_vec(1) = 0.5; + int32 a_count = 100; + KALDI_ASSERT(ApproxEqual(a_vec.Sum(), 1.0)); + + Vector b_vec(3); + b_vec(1) = 0.333; + b_vec(2) = 0.667; + int32 b_count = 100; + KALDI_ASSERT(ApproxEqual(b_vec.Sum(), 1.0)); + + InformationBottleneckClusterable a(1, a_count, a_vec); + InformationBottleneckClusterable b(2, b_count, b_vec); + + Vector sum_vec(a_vec.Dim()); + sum_vec.AddVec(a_count, a_vec); + sum_vec.AddVec(b_count, b_vec); + sum_vec.Scale(1.0 / (a_count + b_count)); + KALDI_ASSERT(ApproxEqual(sum_vec.Sum(), 1.0)); + + InformationBottleneckClusterable sum(3); + InformationBottleneckClusterable c(3); + + sum.Add(a); + sum.Add(b); + + c.AddStats(1, a_count, a_vec); + c.AddStats(2, b_count, b_vec); + + KALDI_ASSERT(c.Counts() == sum.Counts()); + KALDI_ASSERT(ApproxEqual(c.Objf(), sum.Objf())); + KALDI_ASSERT(ApproxEqual(-c.Objf() + a.Objf() + b.Objf(), a.Distance(b))); + KALDI_ASSERT(sum_vec.ApproxEqual(c.RelevanceDist())); + KALDI_ASSERT(sum_vec.ApproxEqual(sum.RelevanceDist())); + } + + for (int32 i = 0; i < 100; i++) { + int32 dim = RandInt(2, 10); + + Vector a_vec(dim); + a_vec.SetRandn(); + a_vec.ApplyPowAbs(1.0); + 
a_vec.Scale(1 / a_vec.Sum()); + KALDI_ASSERT(ApproxEqual(a_vec.Sum(), 1.0)); + int32 a_count = RandInt(1, 100); + InformationBottleneckClusterable a(1, a_count, a_vec); + + Vector b_vec(dim); + b_vec.SetRandn(); + b_vec.ApplyPowAbs(1.0); + b_vec.Scale(1 / b_vec.Sum()); + KALDI_ASSERT(ApproxEqual(b_vec.Sum(), 1.0)); + int32 b_count = RandInt(1, 100); + InformationBottleneckClusterable b(2, b_count, b_vec); + + Vector sum_vec(a_vec.Dim()); + sum_vec.AddVec(a_count, a_vec); + sum_vec.AddVec(b_count, b_vec); + sum_vec.Scale(1.0 / (a_count + b_count)); + KALDI_ASSERT(ApproxEqual(sum_vec.Sum(), 1.0)); + + InformationBottleneckClusterable sum(dim); + InformationBottleneckClusterable c(dim); + + sum.Add(a); + sum.Add(b); + + c.AddStats(1, a_count, a_vec); + c.AddStats(2, b_count, b_vec); + + KALDI_ASSERT(c.Counts() == sum.Counts()); + KALDI_ASSERT(ApproxEqual(c.Objf(), sum.Objf())); + KALDI_ASSERT(ApproxEqual(-c.Objf() + a.Objf() + b.Objf(), a.Distance(b))); + KALDI_ASSERT(sum_vec.ApproxEqual(c.RelevanceDist())); + KALDI_ASSERT(sum_vec.ApproxEqual(sum.RelevanceDist())); + } +} + +} // end namespace kaldi + +int main() { + using namespace kaldi; + + TestClusterable(); +} diff --git a/src/segmenter/information-bottleneck-clusterable.cc b/src/segmenter/information-bottleneck-clusterable.cc new file mode 100644 index 00000000000..05850c1eebc --- /dev/null +++ b/src/segmenter/information-bottleneck-clusterable.cc @@ -0,0 +1,231 @@ +// segmenter/information-bottleneck-clusterable.cc + +// Copyright 2017 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "segmenter/information-bottleneck-clusterable.h" + +namespace kaldi { + +void InformationBottleneckClusterable::AddStats( + int32 id, BaseFloat count, + const VectorBase &relevance_dist) { + std::map::iterator it = counts_.find(id); + KALDI_ASSERT(it == counts_.end() || it->first != id); + counts_.insert(it, std::make_pair(id, count)); + + double sum = relevance_dist.Sum(); + KALDI_ASSERT (sum != 0.0); + + p_yp_c_.Scale(total_count_); + p_yp_c_.AddVec(count / sum, relevance_dist); + total_count_ += count; + p_yp_c_.Scale(1.0 / total_count_); +} + +BaseFloat InformationBottleneckClusterable::Objf( + BaseFloat relevance_factor, BaseFloat input_factor) const { + double relevance_entropy = 0.0, count = 0.0; + for (int32 i = 0; i < p_yp_c_.Dim(); i++) { + if (p_yp_c_(i) > 1e-20) { + relevance_entropy -= p_yp_c_(i) * Log(p_yp_c_(i)); + count += p_yp_c_(i); + } + } + relevance_entropy = total_count_ * (relevance_entropy / count - Log(count)); + + double input_entropy = total_count_ * Log(total_count_); + for (std::map::const_iterator it = counts_.begin(); + it != counts_.end(); ++it) { + input_entropy -= it->second * Log(it->second); + } + + BaseFloat objf = -relevance_factor * relevance_entropy + + input_factor * input_entropy; + return objf; +} + +void InformationBottleneckClusterable::Add(const Clusterable &other_in) { + KALDI_ASSERT(other_in.Type() == "information-bottleneck"); + const InformationBottleneckClusterable *other = + static_cast (&other_in); + + for 
(std::map::const_iterator it = other->counts_.begin(); + it != other->counts_.end(); ++it) { + std::map::iterator hint_it = counts_.lower_bound( + it->first); + if (hint_it != counts_.end() && hint_it->first == it->first) { + KALDI_ERR << "Duplicate segment id " << it->first; + } + counts_.insert(hint_it, *it); + } + + p_yp_c_.Scale(total_count_); + p_yp_c_.AddVec(other->total_count_, other->p_yp_c_); + total_count_ += other->total_count_; + p_yp_c_.Scale(1.0 / total_count_); +} + +void InformationBottleneckClusterable::Sub(const Clusterable &other_in) { + KALDI_ASSERT(other_in.Type() == "information-bottleneck"); + const InformationBottleneckClusterable *other = + static_cast (&other_in); + + for (std::map::const_iterator it = other->counts_.begin(); + it != other->counts_.end(); ++it) { + std::map::iterator hint_it = counts_.lower_bound( + it->first); + KALDI_ASSERT (hint_it->first == it->first); + counts_.erase(hint_it); + } + + p_yp_c_.Scale(total_count_); + p_yp_c_.AddVec(-other->total_count_, other->p_yp_c_); + total_count_ -= other->total_count_; + p_yp_c_.Scale(1.0 / total_count_); +} + +Clusterable* InformationBottleneckClusterable::Copy() const { + InformationBottleneckClusterable *ans = + new InformationBottleneckClusterable(RelevanceDim()); + ans->Add(*this); + return ans; +} + +void InformationBottleneckClusterable::Scale(BaseFloat f) { + KALDI_ASSERT(f >= 0.0); + for (std::map::iterator it = counts_.begin(); + it != counts_.end(); ++it) { + it->second *= f; + } + total_count_ *= f; +} + +void InformationBottleneckClusterable::Write( + std::ostream &os, bool binary) const { + WriteToken(os, binary, "IBCL"); // magic string. 
+ WriteBasicType(os, binary, static_cast<int32>(counts_.size())); + BaseFloat total_count = 0.0; + for (std::map<int32, BaseFloat>::const_iterator it = counts_.begin(); + it != counts_.end(); ++it) { + WriteBasicType(os, binary, it->first); + WriteBasicType(os, binary, it->second); + total_count += it->second; + } + KALDI_ASSERT(ApproxEqual(total_count_, total_count)); + WriteToken(os, binary, ""); + p_yp_c_.Write(os, binary); +} + +Clusterable* InformationBottleneckClusterable::ReadNew( + std::istream &is, bool binary) const { + InformationBottleneckClusterable *ibc = + new InformationBottleneckClusterable(); + ibc->Read(is, binary); + return ibc; +} + +void InformationBottleneckClusterable::Read(std::istream &is, bool binary) { + ExpectToken(is, binary, "IBCL"); // magic string. + int32 size; + ReadBasicType(is, binary, &size); + + // Write() emits exactly 'size' (id, count) pairs, so read 'size' pairs + // (not 2 * size, which would read past the serialized stats). + for (int32 i = 0; i < size; i++) { + int32 id; + BaseFloat count; + ReadBasicType(is, binary, &id); + ReadBasicType(is, binary, &count); + std::pair<std::map<int32, BaseFloat>::iterator, bool> ret; + ret = counts_.insert(std::make_pair(id, count)); + if (!ret.second) { + KALDI_ERR << "Duplicate element " << id << " when reading counts"; + } + total_count_ += count; + } + + ExpectToken(is, binary, ""); + p_yp_c_.Read(is, binary); +} + +BaseFloat InformationBottleneckClusterable::ObjfPlus( + const Clusterable &other, BaseFloat relevance_factor, + BaseFloat input_factor) const { + InformationBottleneckClusterable *copy = static_cast<InformationBottleneckClusterable*>(Copy()); + copy->Add(other); + BaseFloat ans = copy->Objf(relevance_factor, input_factor); + delete copy; + return ans; +} + +BaseFloat InformationBottleneckClusterable::ObjfMinus( + const Clusterable &other, BaseFloat relevance_factor, + BaseFloat input_factor) const { + InformationBottleneckClusterable *copy = static_cast<InformationBottleneckClusterable*>(Copy()); + // Subtract (not Add): ObjfMinus is the objective of this minus other. + copy->Sub(other); + BaseFloat ans = copy->Objf(relevance_factor, input_factor); + delete copy; + return ans; +} + +BaseFloat InformationBottleneckClusterable::Distance( + const Clusterable &other_in, BaseFloat relevance_factor, + BaseFloat 
input_factor) const { + KALDI_ASSERT(other_in.Type() == "information-bottleneck"); + const InformationBottleneckClusterable *other = + static_cast (&other_in); + + BaseFloat normalizer = this->Normalizer() + other->Normalizer(); + BaseFloat pi_i = this->Normalizer() / normalizer; + BaseFloat pi_j = other->Normalizer() / normalizer; + + // Compute the distribution q_Y(y) = p(y|{c_i} + {c_j}) + Vector relevance_dist(this->RelevanceDim()); + relevance_dist.AddVec(pi_i, this->RelevanceDist()); + relevance_dist.AddVec(pi_j, other->RelevanceDist()); + + BaseFloat relevance_divergence + = pi_i * KLDivergence(this->RelevanceDist(), relevance_dist) + + pi_j * KLDivergence(other->RelevanceDist(), relevance_dist); + + BaseFloat input_divergence + = Log(normalizer) - pi_i * Log(this->Normalizer()) + - pi_j * Log(other->Normalizer()); + + KALDI_ASSERT(relevance_divergence > -1e-4); + KALDI_ASSERT(input_divergence > -1e-4); + + double ans = (normalizer * (relevance_factor * relevance_divergence + - input_factor * input_divergence)); + KALDI_ASSERT(input_factor != 0.0 || ans > -1e-4); + return ans; +} + +BaseFloat KLDivergence(const VectorBase &p1, + const VectorBase &p2) { + KALDI_ASSERT(p1.Dim() == p2.Dim()); + + double ans = 0.0, sum = 0.0; + for (int32 i = 0; i < p1.Dim(); i++) { + if (p1(i) > 1e-20) { + ans += p1(i) * Log(p1(i) / p2(i)); + sum += p1(i); + } + } + return ans / sum - Log(sum); +} + +} // end namespace kaldi diff --git a/src/segmenter/information-bottleneck-clusterable.h b/src/segmenter/information-bottleneck-clusterable.h new file mode 100644 index 00000000000..cb88d1221f7 --- /dev/null +++ b/src/segmenter/information-bottleneck-clusterable.h @@ -0,0 +1,163 @@ +// segmenter/information-bottleneck-clusterable.h + +// Copyright 2017 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance 
with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_SEGMENTER_INFORMATION_BOTTLENECK_CLUSTERABLE_H_ +#define KALDI_SEGMENTER_INFORMATION_BOTTLENECK_CLUSTERABLE_H_ + +#include "base/kaldi-common.h" +#include "matrix/kaldi-matrix.h" +#include "itf/clusterable-itf.h" + +namespace kaldi { + +class InformationBottleneckClusterable: public Clusterable { + public: + /// Constructor used for creating empty object e.g. when reading from file. + InformationBottleneckClusterable(): total_count_(0.0) { } + + /// Constructor initializing the relevant variable dimension. + /// Used for making Copy() of object. + InformationBottleneckClusterable(int32 relevance_dim) : + total_count_(0.0), p_yp_c_(relevance_dim) { } + + /// Constructor initializing from input stats corresponding to a + /// segment. + InformationBottleneckClusterable(int32 id, BaseFloat count, + const VectorBase &relevance_dist): + total_count_(0.0), p_yp_c_(relevance_dist.Dim()) { + AddStats(id, count, relevance_dist); + } + + /// Return a copy of this object. 
+ virtual Clusterable* Copy() const; + + /// Return the objective function, which is + /// N(c) * (-r * H(Y|c) + ibeta * H(X|c)) + /// where N(c) is the total count in the cluster + /// H(Y|c) is the conditional entropy of the relevance + /// variable distribution + /// H(X|c) is the conditional entropy of the input variable + /// distribution + /// r is the weight on the relevant variables + /// ibeta is the weight on the input variables + virtual BaseFloat Objf(BaseFloat relevance_factor, + BaseFloat input_factor) const; + + /// Return the objective function with the default values + /// for relevant_factor (1.0) and input_factor (0.1) + virtual BaseFloat Objf() const { return Objf(1.0, 0.1); } + + /// Return the count in this cluster. + virtual BaseFloat Normalizer() const { return total_count_; } + + /// Set stats to empty. + virtual void SetZero() { + counts_.clear(); + p_yp_c_.Resize(0); + total_count_ = 0.0; + } + + /// Add stats to this object + virtual void AddStats(int32 id, BaseFloat count, + const VectorBase &relevance_dist); + + /// Add other stats. + virtual void Add(const Clusterable &other); + /// Subtract other stats. + virtual void Sub(const Clusterable &other); + /// Scale the stats by a positive number f. + virtual void Scale(BaseFloat f); + + /// Return a string that describes the clusterable type. + virtual std::string Type() const { return "information-bottleneck"; } + + /// Write data to stream. + virtual void Write(std::ostream &os, bool binary) const; + + /// Read data from a stream and return the corresponding object (const + /// function; it's a class member because we need access to the vtable + /// so generic code can read derived types). + virtual Clusterable* ReadNew(std::istream &is, bool binary) const; + + /// Read data from stream + virtual void Read(std::istream &is, bool binary); + + /// Return the objective function of the combined object this + other. 
+ virtual BaseFloat ObjfPlus(const Clusterable &other, + BaseFloat relevance_factor, + BaseFloat input_factor) const; + + /// Same as the above function, but using default values for + /// relevance_factor (1.0) and input_factor (0.1) + virtual BaseFloat ObjfPlus(const Clusterable &other) const { + return ObjfPlus(other, 1.0, 0.1); + } + + /// Return the objective function of the combined object this + other. + virtual BaseFloat ObjfMinus(const Clusterable &other, + BaseFloat relevance_factor, + BaseFloat input_factor) const; + + /// Same as the above function, but using default values for + /// relevance_factor (1.0) and input_factor (0.1) + virtual BaseFloat ObjfMinus(const Clusterable &other) const { + return ObjfMinus(other, 1.0, 0.1); + } + + /// Return the objective function decrease from merging the two + /// clusters. + /// Always a non-negative number. + virtual BaseFloat Distance(const Clusterable &other, + BaseFloat relevance_factor, + BaseFloat input_factor) const; + + /// Same as the above function, but using default values for + /// relevance_factor (1.0) and input_factor (0.1) + virtual BaseFloat Distance(const Clusterable &other) const { + return Distance(other, 1.0, 0.1); + } + + virtual ~InformationBottleneckClusterable() {} + + /// Public accessors + virtual const Vector& RelevanceDist() const { return p_yp_c_; } + virtual int32 RelevanceDim() const { return p_yp_c_.Dim(); } + + virtual const std::map& Counts() const { return counts_; } + + private: + /// A list of the original segments this cluster contains along with + /// their corresponding counts. + std::map counts_; + + /// Total count in this cluster. + BaseFloat total_count_; + + /// Relevant variable distribution. + /// TODO: Make sure that this is a valid probability distribution. + Vector p_yp_c_; +}; + +/// Returns the KL Divergence between two probability distributions. 
+BaseFloat KLDivergence(const VectorBase &p1, + const VectorBase &p2); + +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_INFORMATION_BOTTLENECK_CLUSTERABLE_H_ diff --git a/src/segmenter/segment.cc b/src/segmenter/segment.cc new file mode 100644 index 00000000000..65a91a39264 --- /dev/null +++ b/src/segmenter/segment.cc @@ -0,0 +1,42 @@ +#include "segmenter/segment.h" + +namespace kaldi { +namespace segmenter { + +void Segment::Write(std::ostream &os, bool binary) const { + if (binary) { + os.write(reinterpret_cast(&start_frame), sizeof(start_frame)); + os.write(reinterpret_cast(&end_frame), sizeof(start_frame)); + os.write(reinterpret_cast(&class_id), sizeof(class_id)); + } else { + WriteBasicType(os, binary, start_frame); + WriteBasicType(os, binary, end_frame); + WriteBasicType(os, binary, Label()); + } +} + +void Segment::Read(std::istream &is, bool binary) { + if (binary) { + is.read(reinterpret_cast(&start_frame), sizeof(start_frame)); + is.read(reinterpret_cast(&end_frame), sizeof(end_frame)); + is.read(reinterpret_cast(&class_id), sizeof(class_id)); + } else { + ReadBasicType(is, binary, &start_frame); + ReadBasicType(is, binary, &end_frame); + int32 label; + ReadBasicType(is, binary, &label); + SetLabel(label); + } + + KALDI_ASSERT(end_frame >= start_frame && start_frame >= 0); +} + +std::ostream& operator<<(std::ostream& os, const Segment &seg) { + os << "[ "; + seg.Write(os, false); + os << "]"; + return os; +} + +} // end namespace segmenter +} // end namespace kaldi diff --git a/src/segmenter/segment.h b/src/segmenter/segment.h new file mode 100644 index 00000000000..b172fa854a8 --- /dev/null +++ b/src/segmenter/segment.h @@ -0,0 +1,105 @@ +// segmenter/segment.h" + +// Copyright 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_SEGMENTER_SEGMENT_H_ +#define KALDI_SEGMENTER_SEGMENT_H_ + +#include "base/kaldi-common.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { +namespace segmenter { + +/** + * This structure defines a single segment. It consists of the following basic + * properties: + * 1) start_frame : This is the frame index of the first frame in the + * segment. + * 2) end_frame : This is the frame index of the last frame in the segment. + * Note that the end_frame is included in the segment. + * 3) class_id : This is the class corresponding to the segments. For e.g., + * could be 0, 1 or 2 depending on whether the segment is + * silence, speech or noise. In general, it can be any + * integer class label. +**/ + +struct Segment { + int32 start_frame; + int32 end_frame; + int32 class_id; + + // Accessors for labels or class id. This is useful in the future when + // we might change the type of label. + inline int32 Label() const { return class_id; } + inline void SetLabel(int32 label) { class_id = label; } + inline int32 Length() const { return end_frame - start_frame + 1; } + + // This is the default constructor that sets everything to undefined values. + Segment() : start_frame(-1), end_frame(-1), class_id(-1) { } + + // This constructor initializes the segmented with the provided start and end + // frames and the segment label. This is the main constructor. 
+ Segment(int32 start, int32 end, int32 label) : + start_frame(start), end_frame(end), class_id(label) { } + + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); + + // This is a function that returns the size of the elements in the structure. + // It is used during I/O in binary mode, which checks for the total size + // required to store the segment. + static size_t SizeInBytes() { + return (sizeof(int32) + sizeof(int32) + sizeof(int32)); + } + + void Reset() { + start_frame = -1; + end_frame = -1; + class_id = -1; + } +}; + +/** + * Comparator to order segments based on start frame +**/ + +class SegmentComparator { + public: + bool operator() (const Segment &lhs, const Segment &rhs) const { + return lhs.start_frame < rhs.start_frame; + } +}; + +/** + * Comparator to order segments based on length +**/ + +class SegmentLengthComparator { + public: + bool operator() (const Segment &lhs, const Segment &rhs) const { + return lhs.Length() < rhs.Length(); + } +}; + +std::ostream& operator<<(std::ostream& os, const Segment &seg); + +} // end namespace segmenter +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_SEGMENT_H_ diff --git a/src/segmenter/segmentation-io-test.cc b/src/segmenter/segmentation-io-test.cc new file mode 100644 index 00000000000..f019a653a4a --- /dev/null +++ b/src/segmenter/segmentation-io-test.cc @@ -0,0 +1,63 @@ +// segmenter/segmentation-io-test.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "segmenter/segmentation.h" + +namespace kaldi { +namespace segmenter { + +void UnitTestSegmentationIo() { + Segmentation seg; + int32 max_length = RandInt(0, 1000), + max_segment_length = max_length / 10, + num_classes = RandInt(0, 3); + + if (max_segment_length == 0) + max_segment_length = 1; + + seg.GenRandomSegmentation(max_length, max_segment_length, num_classes); + + bool binary = ( RandInt(0,1) == 0 ); + std::ostringstream os; + + seg.Write(os, binary); + + Segmentation seg2; + std::istringstream is(os.str()); + seg2.Read(is, binary); + + std::ostringstream os2; + seg2.Write(os2, binary); + + KALDI_ASSERT(os2.str() == os.str()); +} + +} // namespace segmenter +} // namespace kaldi + +int main() { + using namespace kaldi; + using namespace kaldi::segmenter; + + for (int32 i = 0; i < 100; i++) + UnitTestSegmentationIo(); + return 0; +} + + diff --git a/src/segmenter/segmentation-post-processor.cc b/src/segmenter/segmentation-post-processor.cc new file mode 100644 index 00000000000..1bec12360fc --- /dev/null +++ b/src/segmenter/segmentation-post-processor.cc @@ -0,0 +1,201 @@ +// segmenter/segmentation-post-processor.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "segmenter/segmentation-utils.h" +#include "segmenter/segmentation-post-processor.h" + +namespace kaldi { +namespace segmenter { + +static inline bool IsMergingLabelsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (!opts.merge_labels_csl.empty() || opts.merge_dst_label != -1); +} + +static inline bool IsPaddingSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.pad_label != -1 || opts.pad_length != -1); +} + +static inline bool IsShrinkingSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.shrink_label != -1 || opts.shrink_length != -1); +} + +static inline bool IsBlendingShortSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.blend_short_segments_class != -1 || opts.max_blend_length != -1); +} + +static inline bool IsRemovingSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (!opts.remove_labels_csl.empty()); +} + +static inline bool IsMergingAdjacentSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.merge_adjacent_segments); +} + +static inline bool IsSplittingSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.max_segment_length != -1); +} + + +SegmentationPostProcessor::SegmentationPostProcessor( + const SegmentationPostProcessingOptions &opts) : opts_(opts) { + if (!opts_.remove_labels_csl.empty()) { + if (!SplitStringToIntegers(opts_.remove_labels_csl, ":", + 
false, &remove_labels_)) { + KALDI_ERR << "Bad value for --remove-labels option: " + << opts_.remove_labels_csl; + } + std::sort(remove_labels_.begin(), remove_labels_.end()); + } + + if (!opts_.merge_labels_csl.empty()) { + if (!SplitStringToIntegers(opts_.merge_labels_csl, ":", + false, &merge_labels_)) { + KALDI_ERR << "Bad value for --merge-labels option: " + << opts_.merge_labels_csl; + } + std::sort(merge_labels_.begin(), merge_labels_.end()); + } + + Check(); +} + +void SegmentationPostProcessor::Check() const { + if (IsPaddingSegmentsToBeDone(opts_) && opts_.pad_label < 0) { + KALDI_ERR << "Invalid value " << opts_.pad_label << " for option " + << "--pad-label. It must be non-negative."; + } + + if (IsPaddingSegmentsToBeDone(opts_) && opts_.pad_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.pad_length << " for option " + << "--pad-length. It must be positive."; + } + + if (IsShrinkingSegmentsToBeDone(opts_) && opts_.shrink_label < 0) { + KALDI_ERR << "Invalid value " << opts_.shrink_label << " for option " + << "--shrink-label. It must be non-negative."; + } + + if (IsShrinkingSegmentsToBeDone(opts_) && opts_.shrink_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.shrink_length << " for option " + << "--shrink-length. It must be positive."; + } + + if (IsBlendingShortSegmentsToBeDone(opts_) && + opts_.blend_short_segments_class < 0) { + KALDI_ERR << "Invalid value " << opts_.blend_short_segments_class + << " for option " << "--blend-short-segments-class. " + << "It must be non-negative."; + } + + if (IsBlendingShortSegmentsToBeDone(opts_) && opts_.max_blend_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.max_blend_length << " for option " + << "--max-blend-length. It must be positive."; + } + + if (IsRemovingSegmentsToBeDone(opts_) && + (remove_labels_[0] < -1 || + (remove_labels_.size() > 1 && remove_labels_[0] == -1))) { + KALDI_ERR << "Invalid value " << opts_.remove_labels_csl + << " for option " << "--remove-labels. 
" + << "The labels must be non-negative."; + } + + if (IsMergingAdjacentSegmentsToBeDone(opts_) && + opts_.max_intersegment_length < 0) { + KALDI_ERR << "Invalid value " << opts_.max_intersegment_length + << " for option " + << "--max-intersegment-length. It must be non-negative."; + } + + if (IsSplittingSegmentsToBeDone(opts_) && opts_.max_segment_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.max_segment_length + << " for option " + << "--max-segment-length. It must be positive."; + } + + if (opts_.post_process_label != -1 && opts_.post_process_label < 0) { + KALDI_ERR << "Invalid value " << opts_.post_process_label << " for option " + << "--post-process-label. It must be non-negative."; + } +} + +bool SegmentationPostProcessor::PostProcess(Segmentation *seg) const { + DoMergingLabels(seg); + DoPaddingSegments(seg); + DoShrinkingSegments(seg); + DoBlendingShortSegments(seg); + DoRemovingSegments(seg); + DoMergingAdjacentSegments(seg); + DoSplittingSegments(seg); + + return true; +} + +void SegmentationPostProcessor::DoMergingLabels(Segmentation *seg) const { + if (!IsMergingLabelsToBeDone(opts_)) return; + MergeLabels(merge_labels_, opts_.merge_dst_label, seg); +} + +void SegmentationPostProcessor::DoPaddingSegments(Segmentation *seg) const { + if (!IsPaddingSegmentsToBeDone(opts_)) return; + PadSegments(opts_.pad_label, opts_.pad_length, seg); +} + +void SegmentationPostProcessor::DoShrinkingSegments(Segmentation *seg) const { + if (!IsShrinkingSegmentsToBeDone(opts_)) return; + ShrinkSegments(opts_.shrink_label, opts_.shrink_length, seg); +} + +void SegmentationPostProcessor::DoBlendingShortSegments( + Segmentation *seg) const { + if (!IsBlendingShortSegmentsToBeDone(opts_)) return; + BlendShortSegmentsWithNeighbors(opts_.blend_short_segments_class, + opts_.max_blend_length, + opts_.max_intersegment_length, seg); +} + +void SegmentationPostProcessor::DoRemovingSegments(Segmentation *seg) const { + if (!IsRemovingSegmentsToBeDone(opts_)) return; + 
RemoveSegments(remove_labels_, opts_.max_remove_length, + seg); +} + +void SegmentationPostProcessor::DoMergingAdjacentSegments( + Segmentation *seg) const { + if (!IsMergingAdjacentSegmentsToBeDone(opts_)) return; + MergeAdjacentSegments(opts_.max_intersegment_length, seg); +} + +void SegmentationPostProcessor::DoSplittingSegments(Segmentation *seg) const { + if (!IsSplittingSegmentsToBeDone(opts_)) return; + SplitSegments(opts_.max_segment_length, + opts_.max_segment_length / 2, + opts_.overlap_length, + opts_.post_process_label, seg); +} + +} // end namespace segmenter +} // end namespace kaldi diff --git a/src/segmenter/segmentation-post-processor.h b/src/segmenter/segmentation-post-processor.h new file mode 100644 index 00000000000..040d6c44383 --- /dev/null +++ b/src/segmenter/segmentation-post-processor.h @@ -0,0 +1,174 @@ +// segmenter/segmentation-post-processor.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_SEGMENTER_SEGMENTATION_POST_PROCESSOR_H_ +#define KALDI_SEGMENTER_SEGMENTATION_POST_PROCESSOR_H_ + +#include "base/kaldi-common.h" +#include "itf/options-itf.h" +#include "segmenter/segmentation.h" + +namespace kaldi { +namespace segmenter { + +/** + * Structure for some common options related to segmentation that would be used + * in multiple segmentation programs. Some of the operations include merging, + * filtering etc. +**/ + +struct SegmentationPostProcessingOptions { + std::string merge_labels_csl; + int32 merge_dst_label; + + int32 pad_label; + int32 pad_length; + + int32 shrink_label; + int32 shrink_length; + + int32 blend_short_segments_class; + int32 max_blend_length; + + std::string remove_labels_csl; + int32 max_remove_length; + + bool merge_adjacent_segments; + int32 max_intersegment_length; + + int32 max_segment_length; + int32 overlap_length; + + int32 post_process_label; + + SegmentationPostProcessingOptions() : + merge_dst_label(-1), + pad_label(-1), pad_length(-1), + shrink_label(-1), shrink_length(-1), + blend_short_segments_class(-1), max_blend_length(-1), + max_remove_length(-1), + merge_adjacent_segments(false), + max_intersegment_length(0), + max_segment_length(-1), overlap_length(0), + post_process_label(-1) { } + + void Register(OptionsItf *opts) { + opts->Register("merge-labels", &merge_labels_csl, "Merge labels into a " + "single label defined by merge-dst-label. " + "The labels are specified as a colon-separated list. " + "Refer to the MergeLabels() code for details. " + "Used in conjunction with the option --merge-dst-label"); + opts->Register("merge-dst-label", &merge_dst_label, + "Merge labels specified by merge-labels into this label. " + "Refer to the MergeLabels() code for details. " + "Used in conjunction with the option --merge-labels."); + opts->Register("pad-label", &pad_label, + "Pad segments of this label by pad_length frames." + "Refer to the PadSegments() code for details. 
" + "Used in conjunction with the option --pad-length."); + opts->Register("pad-length", &pad_length, "Pad segments by this many " + "frames on either side. " + "Refer to the PadSegments() code for details. " + "Used in conjunction with the option --pad-label."); + opts->Register("shrink-label", &shrink_label, + "Shrink segments of this label by shrink_length frames. " + "Refer to the ShrinkSegments() code for details. " + "Used in conjunction with the option --shrink-length."); + opts->Register("shrink-length", &shrink_length, "Shrink segments by this " + "many frames on either side. " + "Refer to the ShrinkSegments() code for details. " + "Used in conjunction with the option --shrink-label."); + opts->Register("blend-short-segments-class", &blend_short_segments_class, + "The label for which the short segments are to be " + "blended with the neighboring segments that are less than " + "max_intersegment_length frames away. " + "Refer to BlendShortSegments() code for details. " + "Used in conjunction with the option --max-blend-length " + "and --max-intersegment-length."); + opts->Register("max-blend-length", &max_blend_length, + "The maximum length of segment in number of frames that " + "will be blended with the neighboring segments provided " + "they both have the same label. " + "Refer to BlendShortSegments() code for details. " + "Used in conjunction with the option " + "--blend-short-segments-class"); + opts->Register("remove-labels", &remove_labels_csl, + "Remove any segment whose label is contained in " + "remove_labels_csl. " + "Refer to the RemoveLabels() code for details."); + opts->Register("max-remove-length", &max_remove_length, + "If provided, specifies the maximum length of segments " + "that will be removed by --remove-labels option"); + opts->Register("merge-adjacent-segments", &merge_adjacent_segments, + "Merge adjacent segments of the same label if they are " + "within max-intersegment-length distance. 
" + "Refer to the MergeAdjacentSegments() code for details. " + "Used in conjunction with the option " + "--max-intersegment-length\n"); + opts->Register("max-intersegment-length", &max_intersegment_length, + "The maximum intersegment length that is allowed for " + "two adjacent segments to be merged. " + "Refer to the MergeAdjacentSegments() code for details. " + "Used in conjunction with the option " + "--merge-adjacent-segments or " + "--blend-short-segments-class\n"); + opts->Register("max-segment-length", &max_segment_length, + "If segment is longer than this length, split it into " + "pieces with less than these many frames. " + "Refer to the SplitSegments() code for details. " + "Used in conjunction with the option --overlap-length."); + opts->Register("overlap-length", &overlap_length, + "When splitting segments longer than max-segment-length, " + "have the pieces overlap by these many frames. " + "Refer to the SplitSegments() code for details. " + "Used in conjunction with the option --max-segment-length."); + opts->Register("post-process-label", &post_process_label, + "Do post processing only on this label. 
This option is " + "applicable to only a few operations including " + "SplitSegments"); + } +}; + +class SegmentationPostProcessor { + public: + explicit SegmentationPostProcessor( + const SegmentationPostProcessingOptions &opts); + + bool PostProcess(Segmentation *seg) const; + + void DoMergingLabels(Segmentation *seg) const; + void DoPaddingSegments(Segmentation *seg) const; + void DoShrinkingSegments(Segmentation *seg) const; + void DoBlendingShortSegments(Segmentation *seg) const; + void DoRemovingSegments(Segmentation *seg) const; + void DoMergingAdjacentSegments(Segmentation *seg) const; + void DoSplittingSegments(Segmentation *seg) const; + + private: + const SegmentationPostProcessingOptions &opts_; + std::vector merge_labels_; + std::vector remove_labels_; + + void Check() const; +}; + +} // end namespace segmenter +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_SEGMENTATION_POST_PROCESSOR_H_ diff --git a/src/segmenter/segmentation-test.cc b/src/segmenter/segmentation-test.cc new file mode 100644 index 00000000000..7654b23b119 --- /dev/null +++ b/src/segmenter/segmentation-test.cc @@ -0,0 +1,226 @@ +// segmenter/segmentation-test.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+
+#include "segmenter/segmentation.h"
+
+namespace kaldi {
+namespace segmenter {
+
+void GenerateRandomSegmentation(int32 max_length, int32 num_classes,
+                                Segmentation *segmentation) {
+  segmentation->Clear();
+  int32 s = max_length;
+  int32 e = max_length;
+
+  while (s >= 0) {
+    int32 chunk_size = rand() % (max_length / 10);
+    s = e - chunk_size + 1;
+    int32 k = rand() % num_classes;
+
+    if (k != 0) {
+      segmentation->Emplace(s, e, k);
+    }
+    e = s - 1;
+  }
+  segmentation->Check();
+}
+
+
+int32 GenerateRandomAlignment(int32 max_length, int32 num_classes,
+                              std::vector<int32> *ali) {
+  int32 N = RandInt(1, max_length);
+  int32 C = RandInt(1, num_classes);
+
+  ali->clear();
+
+  int32 len = 0;
+  while (len < N) {
+    int32 c = RandInt(0, C-1);
+    int32 n = std::min(RandInt(1, N), N - len);
+    ali->insert(ali->begin() + len, n, c);
+    len += n;
+  }
+  KALDI_ASSERT(ali->size() == N && len == N);
+
+  int32 state = -1, num_segments = 0;
+  for (std::vector<int32>::const_iterator it = ali->begin();
+       it != ali->end(); ++it) {
+    if (*it != state) num_segments++;
+    state = *it;
+  }
+
+  return num_segments;
+}
+
+void TestConversionToAlignment() {
+  std::vector<int32> ali;
+  int32 max_length = 1000, num_classes = 3;
+  int32 num_segments = GenerateRandomAlignment(max_length, num_classes, &ali);
+
+  Segmentation seg;
+  KALDI_ASSERT(num_segments == seg.InsertFromAlignment(ali, 0));
+
+  std::vector<int32> out_ali;
+  {
+    seg.ConvertToAlignment(&out_ali);
+    KALDI_ASSERT(ali == out_ali);
+  }
+
+  {
+    seg.ConvertToAlignment(&out_ali, num_classes, max_length * 2);
+    std::vector<int32> tmp_ali(out_ali.begin(), out_ali.begin() + ali.size());
+    KALDI_ASSERT(ali == tmp_ali);
+    for (std::vector<int32>::const_iterator it = out_ali.begin() + ali.size();
+         it != out_ali.end(); ++it) {
+      KALDI_ASSERT(*it == num_classes);
+    }
+  }
+
+  seg.Clear();
+  KALDI_ASSERT(num_segments == seg.InsertFromAlignment(ali, max_length));
+  {
+    seg.ConvertToAlignment(&out_ali, num_classes, max_length * 2);
+
+    for (std::vector<int32>::const_iterator it = out_ali.begin();
+         it != out_ali.begin() +
max_length; ++it) { + KALDI_ASSERT(*it == num_classes); + } + std::vector tmp_ali(out_ali.begin() + max_length, out_ali.begin() + max_length + ali.size()); + KALDI_ASSERT(tmp_ali == ali); + + for (std::vector::const_iterator it = out_ali.begin() + max_length + ali.size(); + it != out_ali.end(); ++it) { + KALDI_ASSERT(*it == num_classes); + } + } +} + +void TestRemoveSegments() { + std::vector ali; + int32 max_length = 1000, num_classes = 10; + int32 num_segments = GenerateRandomAlignment(max_length, num_classes, &ali); + + Segmentation seg; + KALDI_ASSERT(num_segments == seg.InsertFromAlignment(ali, 0)); + + for (int32 i = 0; i < num_classes; i++) { + Segmentation out_seg(seg); + out_seg.RemoveSegments(i); + std::vector out_ali; + out_seg.ConvertToAlignment(&out_ali, i, ali.size()); + KALDI_ASSERT(ali == out_ali); + } + + { + std::vector classes; + for (int32 i = 0; i < 3; i++) + classes.push_back(RandInt(0, num_classes - 1)); + std::sort(classes.begin(), classes.end()); + + Segmentation out_seg1(seg); + out_seg1.RemoveSegments(classes); + + Segmentation out_seg2(seg); + for (std::vector::const_iterator it = classes.begin(); + it != classes.end(); ++it) + out_seg2.RemoveSegments(*it); + + std::vector out_ali1, out_ali2; + out_seg1.ConvertToAlignment(&out_ali1); + out_seg2.ConvertToAlignment(&out_ali2); + + KALDI_ASSERT(out_ali1 == out_ali2); + } +} + +void TestIntersectSegments() { + int32 max_length = 100, num_classes = 3; + + std::vector primary_ali; + GenerateRandomAlignment(max_length, num_classes, &primary_ali); + + std::vector secondary_ali; + GenerateRandomAlignment(max_length, num_classes, &secondary_ali); + + Segmentation primary_seg; + primary_seg.InsertFromAlignment(primary_ali); + + Segmentation secondary_seg; + secondary_seg.InsertFromAlignment(secondary_ali); + + { + Segmentation out_seg; + primary_seg.IntersectSegments(secondary_seg, &out_seg, num_classes); + + std::vector out_ali; + out_seg.ConvertToAlignment(&out_ali); + + std::vector 
oracle_ali(primary_ali.size()); + + for (size_t i = 0; i < oracle_ali.size(); i++) { + int32 p = (i < primary_ali.size()) ? primary_ali[i] : -1; + int32 s = (i < secondary_ali.size()) ? secondary_ali[i] : -2; + + oracle_ali[i] = (p == s) ? p : num_classes; + } + + KALDI_ASSERT(oracle_ali == out_ali); + } + + { + Segmentation out_seg; + primary_seg.IntersectSegments(secondary_seg, &out_seg); + + std::vector out_ali; + out_seg.ConvertToAlignment(&out_ali, num_classes); + + std::vector oracle_ali(out_ali.size()); + + for (size_t i = 0; i < oracle_ali.size(); i++) { + int32 p = (i < primary_ali.size()) ? primary_ali[i] : -1; + int32 s = (i < secondary_ali.size()) ? secondary_ali[i] : -2; + + oracle_ali[i] = (p == s) ? p : num_classes; + } + + KALDI_ASSERT(oracle_ali == out_ali); + } + +} + +void UnitTestSegmentation() { + TestConversionToAlignment(); + TestRemoveSegments(); + TestIntersectSegments(); +} + +} // namespace segmenter +} // namespace kaldi + +int main() { + using namespace kaldi; + using namespace kaldi::segmenter; + + for (int32 i = 0; i < 10; i++) + UnitTestSegmentation(); + return 0; +} + + + diff --git a/src/segmenter/segmentation-utils.cc b/src/segmenter/segmentation-utils.cc new file mode 100644 index 00000000000..4d76afba0b8 --- /dev/null +++ b/src/segmenter/segmentation-utils.cc @@ -0,0 +1,768 @@ +// segmenter/segmentation-utils.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "segmenter/segmentation-utils.h" + +namespace kaldi { +namespace segmenter { + +void MergeLabels(const std::vector &merge_labels, + int32 dest_label, + Segmentation *segmentation) { + KALDI_ASSERT(segmentation); + + // Check if sorted and unique + KALDI_ASSERT(std::adjacent_find(merge_labels.begin(), + merge_labels.end(), std::greater()) + == merge_labels.end()); + + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + if (std::binary_search(merge_labels.begin(), merge_labels.end(), + it->Label())) { + it->SetLabel(dest_label); + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +void RelabelSegmentsUsingMap(const unordered_map &label_map, + Segmentation *segmentation) { + int32 default_label = -1; + unordered_map::const_iterator it = label_map.find(-1); + if (it != label_map.end()) { + default_label = it->second; + KALDI_ASSERT(default_label != -1); + } + + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + unordered_map::const_iterator map_it = label_map.find( + it->Label()); + int32 dest_label = -100; + if (map_it == label_map.end()) { + if (default_label == -1) + KALDI_ERR << "Could not find label " << it->Label() + << " in label map."; + else + dest_label = default_label; + } else { + dest_label = map_it->second; + } + + if (dest_label == -1) { + // Remove segments that will be mapped to label -1. 
+ it = segmentation->Erase(it); + continue; + } + it->SetLabel(dest_label); + ++it; + } +} + +void RelabelAllSegments(int32 label, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) + it->SetLabel(label); +} + +void ScaleFrameShift(BaseFloat factor, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + it->start_frame *= factor; + it->end_frame *= factor; + } +} + +void RemoveSegments(int32 label, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it->Label() == label) { + it = segmentation->Erase(it); + } else { + ++it; + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +void RemoveSegments(const std::vector &labels, + int32 max_remove_length, + Segmentation *segmentation) { + // Check if sorted and unique + KALDI_ASSERT(std::adjacent_find(labels.begin(), + labels.end(), std::greater()) == labels.end()); + + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (max_remove_length < 0) { + if (std::binary_search(labels.begin(), labels.end(), + it->Label())) + it = segmentation->Erase(it); + else + ++it; + } else if (it->Length() < max_remove_length) { + if (std::binary_search(labels.begin(), labels.end(), + it->Label()) || + (labels.size() == 1 && labels[0] == -1)) + it = segmentation->Erase(it); + else + ++it; + } else { + ++it; + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +// Opposite of RemoveSegments() +void KeepSegments(int32 label, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it->Label() != label) { + it = segmentation->Erase(it); + } else { + ++it; + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +// TODO(Vimal): Write test function for this. 
+void SplitInputSegmentation(const Segmentation &in_segmentation,
+                            int32 segment_length,
+                            Segmentation *out_segmentation) {
+  out_segmentation->Clear();
+  for (SegmentList::const_iterator it = in_segmentation.Begin();
+       it != in_segmentation.End(); ++it) {
+    int32 length = it->Length();
+
+    // Since ceil is used, this results in all pieces to be smaller than
+    // segment_length rather than being larger.
+    int32 num_chunks = std::ceil(static_cast<BaseFloat>(length)
+                                 / segment_length);
+    int32 actual_segment_length = static_cast<BaseFloat>(length) / num_chunks;
+
+    int32 start_frame = it->start_frame;
+    for (int32 j = 0; j < num_chunks; j++) {
+      int32 end_frame = std::min(start_frame + actual_segment_length - 1,
+                                 it->end_frame);
+      out_segmentation->EmplaceBack(start_frame, end_frame, it->Label());
+      start_frame = end_frame + 1;
+    }
+  }
+#ifdef KALDI_PARANOID
+  out_segmentation->Check();
+#endif
+}
+
+// TODO(Vimal): Write test function for this.
+void SplitSegments(int32 segment_length, int32 min_remainder,
+                   int32 overlap_length, int32 segment_label,
+                   Segmentation *segmentation) {
+  KALDI_ASSERT(segmentation);
+  KALDI_ASSERT(segment_length > 0 && min_remainder > 0);
+  KALDI_ASSERT(overlap_length >= 0);
+
+  KALDI_ASSERT(overlap_length < segment_length);
+  for (SegmentList::iterator it = segmentation->Begin();
+       it != segmentation->End(); ++it) {
+    if (segment_label != -1 && it->Label() != segment_label) continue;
+
+    int32 start_frame = it->start_frame;
+    int32 length = it->Length();
+
+    if (length > segment_length + min_remainder) {
+      // Split segment
+      // To show what this is doing, consider the following example, where it is
+      // currently pointing to B.
+      // A <--> B <--> C
+
+      // Modify the start_frame of the current frame. This prepares the current
+      // segment to be used as the "next segment" when we move the iterator in
+      // the next statement.
+      // In the example, the start_frame for B has just been modified.
+ it->start_frame = start_frame + segment_length - overlap_length; + + // Create a new segment and add it to the where the current iterator is. + // The statement below results in this: + // A <--> B1 <--> B <--> C + // with the iterator it pointing at B1. So when the iterator is + // incremented in the for loop, it will point to B again, but whose + // start_frame had been modified. + it = segmentation->Emplace(it, start_frame, + start_frame + segment_length - 1, + it->Label()); + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +// TODO(Vimal): Write test code for this +void SplitSegmentsUsingAlignment(int32 segment_length, + int32 segment_label, + const std::vector &ali, + int32 ali_label, + int32 min_silence_length, + Segmentation *segmentation) { + KALDI_ASSERT(segmentation); + KALDI_ASSERT(segment_length > 0); + + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End();) { + // Safety check. In practice, should never fail. + KALDI_ASSERT(segmentation->Dim() <= ali.size()); + + if (segment_label != -1 && it->Label() != segment_label) { + ++it; + continue; + } + + int32 start_frame = it->start_frame; + int32 length = it->Length(); + int32 label = it->Label(); + + if (length <= segment_length) { + ++it; + continue; + } + + // Split segment + // To show what this is doing, consider the following example, where it is + // currently pointing to B. 
+ // A <--> B <--> C + + Segmentation ali_segmentation; + InsertFromAlignment(ali, start_frame, + start_frame + length, + 0, &ali_segmentation, NULL); + KeepSegments(ali_label, &ali_segmentation); + MergeAdjacentSegments(0, &ali_segmentation); + + // Get largest alignment chunk where label == ali_label + SegmentList::iterator s_it = ali_segmentation.MaxElement(); + + if (s_it == ali_segmentation.End() || s_it->Length() < min_silence_length) { + ++it; + continue; + } + + KALDI_ASSERT(s_it->start_frame >= start_frame); + KALDI_ASSERT(s_it->end_frame <= start_frame + length); + + // Modify the start_frame of the current frame. This prepares the current + // segment to be used as the "next segment" when we move the iterator in + // the next statement. + // In the example, the start_frame for B has just been modified. + int32 end_frame; + if (s_it->Length() > 1) { + end_frame = s_it->start_frame + s_it->Length() / 2 - 2; + it->start_frame = end_frame + 2; + } else { + end_frame = s_it->start_frame - 1; + it->start_frame = s_it->end_frame + 1; + } + + // end_frame is within this current segment + KALDI_ASSERT(end_frame < start_frame + length); + // The first new segment length is smaller than the old segment length + KALDI_ASSERT(end_frame - start_frame + 1 < length); + + // The second new segment length is smaller than the old segment length + KALDI_ASSERT(it->end_frame - end_frame - 1 < length); + + if (it->Length() < 0) { + // This is possible when the beginning of the segment is silence + it = segmentation->Erase(it); + } + + // Create a new segment and add it to the where the current iterator is. + // The statement below results in this: + // A <--> B1 <--> B <--> C + // with the iterator it pointing at B1. 
+ if (end_frame >= start_frame) { + it = segmentation->Emplace(it, start_frame, end_frame, label); + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +// TODO(Vimal): Write test code for this +void IntersectSegmentationAndAlignment(const Segmentation &in_segmentation, + const std::vector &alignment, + int32 ali_label, + int32 min_align_chunk_length, + Segmentation *out_segmentation) { + KALDI_ASSERT(out_segmentation); + + for (SegmentList::const_iterator it = in_segmentation.Begin(); + it != in_segmentation.End(); ++it) { + Segmentation filter_segmentation; + InsertFromAlignment(alignment, it->start_frame, + std::min(it->end_frame + 1, + static_cast(alignment.size())), + 0, &filter_segmentation, NULL); + + for (SegmentList::const_iterator f_it = filter_segmentation.Begin(); + f_it != filter_segmentation.End(); ++f_it) { + if (f_it->Length() < min_align_chunk_length) continue; + if (ali_label != -1 && f_it->Label() != ali_label) continue; + out_segmentation->EmplaceBack(f_it->start_frame, f_it->end_frame, + it->Label()); + } + } +} + +void SubSegmentUsingNonOverlappingSegments( + const Segmentation &primary_segmentation, + const Segmentation &secondary_segmentation, int32 secondary_label, + int32 subsegment_label, int32 unmatched_label, + Segmentation *out_segmentation) { + KALDI_ASSERT(out_segmentation); + KALDI_ASSERT(secondary_segmentation.Dim() > 0); + + std::vector alignment; + ConvertToAlignment(secondary_segmentation, -1, -1, 0, &alignment); + + for (SegmentList::const_iterator it = primary_segmentation.Begin(); + it != primary_segmentation.End(); ++it) { + if (it->end_frame >= alignment.size()) { + alignment.resize(it->end_frame + 1, -1); + } + Segmentation filter_segmentation; + InsertFromAlignment(alignment, it->start_frame, it->end_frame + 1, + 0, &filter_segmentation, NULL); + + for (SegmentList::const_iterator f_it = filter_segmentation.Begin(); + f_it != filter_segmentation.End(); ++f_it) { + int32 label = (unmatched_label >= 0 ? 
unmatched_label : it->Label()); + if (f_it->Label() == secondary_label) { + if (subsegment_label >= 0) { + label = subsegment_label; + } else { + label = f_it->Label(); + } + } + out_segmentation->EmplaceBack(f_it->start_frame, f_it->end_frame, + label); + } + } +} + +// TODO(Vimal): Write test code for this +void MergeAdjacentSegments(int32 max_intersegment_length, + Segmentation *segmentation) { + SegmentList::iterator it = segmentation->Begin(), + prev_it = segmentation->Begin(); + + while (it != segmentation->End()) { + KALDI_ASSERT(it->start_frame >= prev_it->start_frame); + + if (it != segmentation->Begin() && + it->Label() == prev_it->Label() && + prev_it->end_frame + max_intersegment_length + 1 >= it->start_frame) { + // merge segments + if (prev_it->end_frame < it->end_frame) { + // If the previous segment end before the current segment, then + // extend the previous segment to the end_frame of the current + // segment and remove the current segment. + prev_it->end_frame = it->end_frame; + } // else simply remove the current segment. 
+ it = segmentation->Erase(it); + } else { + // no merging of segments + prev_it = it; + ++it; + } + } + +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +void PadSegments(int32 label, int32 length, Segmentation *segmentation) { + KALDI_ASSERT(segmentation); + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + if (it->Label() != label) continue; + + it->start_frame -= length; + it->end_frame += length; + + if (it->start_frame < 0) it->start_frame = 0; + } +} + +void WidenSegments(int32 label, int32 length, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + if (it->Label() == label) { + if (it != segmentation->Begin()) { + // it is not the beginning of the segmentation, so we can widen it on + // the start_frame side + SegmentList::iterator prev_it = it; + --prev_it; + it->start_frame -= length; + if (prev_it->Label() == label && it->start_frame < prev_it->end_frame) { + // After widening this segment, it overlaps the previous segment that + // also has the same class_id. Then turn this segment into a composite + // one + it->start_frame = prev_it->start_frame; + // and remove the previous segment from the list. + segmentation->Erase(prev_it); + } else if (prev_it->Label() != label && + it->start_frame < prev_it->end_frame) { + // Previous segment is not the same class_id, so we cannot turn this + // into a composite segment. + if (it->start_frame <= prev_it->start_frame) { + // The extended segment absorbs the previous segment into it + // So remove the previous segment + segmentation->Erase(prev_it); + } else { + // The extended segment reduces the length of the previous + // segment. But does not completely overlap it. 
+ prev_it->end_frame -= length; + if (prev_it->end_frame < prev_it->start_frame) + segmentation->Erase(prev_it); + } + } + if (it->start_frame < 0) it->start_frame = 0; + } else { + it->start_frame -= length; + if (it->start_frame < 0) it->start_frame = 0; + } + + SegmentList::iterator next_it = it; + ++next_it; + + if (next_it != segmentation->End()) + // We do not know the length of the file. + // So we don't want to extend the last one. + it->end_frame += length; // Line (1) + } else { // if (it->Label() != label) + if (it != segmentation->Begin()) { + SegmentList::iterator prev_it = it; + --prev_it; + if (prev_it->end_frame >= it->end_frame) { + // The extended previous segment in Line (1) completely + // overlaps the current segment. So remove the current segment. + it = segmentation->Erase(it); + // So that we can increment in the for loop + --it; // TODO(Vimal): This is buggy. + } else if (prev_it->end_frame >= it->start_frame) { + // The extended previous segment in Line (1) reduces the length of + // this segment. 
+ it->start_frame = prev_it->end_frame + 1; + } + } + } + } +} + +void ShrinkSegments(int32 label, int32 length, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it->Label() == label) { + if (it->Length() <= 2 * length) { + it = segmentation->Erase(it); + } else { + it->start_frame += length; + it->end_frame -= length; + ++it; + } + } else { + ++it; + } + } + +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +void BlendShortSegmentsWithNeighbors(int32 label, int32 max_length, + int32 max_intersegment_length, + Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it == segmentation->Begin()) { + // Can't blend the first segment + ++it; + continue; + } + + SegmentList::iterator next_it = it; + ++next_it; + + if (next_it == segmentation->End()) // End of segmentation + break; + + SegmentList::iterator prev_it = it; + --prev_it; + + // If the previous and current segments have different labels, + // then ensure that they are not overlapping + KALDI_ASSERT(it->start_frame >= prev_it->start_frame && + (prev_it->Label() == it->Label() || + prev_it->end_frame < it->start_frame)); + + KALDI_ASSERT(next_it->start_frame >= it->start_frame && + (it->Label() == next_it->Label() || + it->end_frame < next_it->start_frame)); + + if (next_it->Label() != prev_it->Label() || it->Label() != label || + it->Length() >= max_length || + next_it->start_frame - it->end_frame - 1 > max_intersegment_length || + it->start_frame - prev_it->end_frame - 1 > max_intersegment_length) { + ++it; + continue; + } + + prev_it->end_frame = next_it->end_frame; + segmentation->Erase(it); + it = segmentation->Erase(next_it); + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +bool ConvertToAlignment(const Segmentation &segmentation, + int32 default_label, int32 length, + int32 tolerance, + std::vector *alignment) { + 
KALDI_ASSERT(alignment); + alignment->clear(); + + if (length != -1) { + KALDI_ASSERT(length >= 0); + alignment->resize(length, default_label); + } + + SegmentList::const_iterator it = segmentation.Begin(); + for (; it != segmentation.End(); ++it) { + if (length != -1 && it->end_frame >= length + tolerance) { + KALDI_WARN << "End frame (" << it->end_frame << ") " + << ">= length (" << length + << ") + tolerance (" << tolerance << ")." + << "Conversion failed."; + return false; + } + + int32 end_frame = it->end_frame; + if (length == -1) { + alignment->resize(it->end_frame + 1, default_label); + } else { + if (it->end_frame >= length) + end_frame = length - 1; + } + + KALDI_ASSERT(end_frame < alignment->size()); + for (int32 i = it->start_frame; i <= end_frame; i++) { + (*alignment)[i] = it->Label(); + } + } + return true; +} + +int32 InsertFromAlignment(const std::vector &alignment, + int32 start, int32 end, + int32 start_time_offset, + Segmentation *segmentation, + std::map *frame_counts_per_class) { + KALDI_ASSERT(segmentation); + + if (end <= start) return 0; // nothing to insert + + // Correct boundaries + if (end > alignment.size()) end = alignment.size(); + if (start < 0) start = 0; + + KALDI_ASSERT(end > start); // This is possible if end was originally + // greater than alignment.size(). + // The user must resize alignment appropriately + // before passing to this function. + + int32 num_segments = 0; + int32 state = -100, start_frame = -1; + for (int32 i = start; i < end; i++) { + KALDI_ASSERT(alignment[i] >= -1); + if (alignment[i] != state) { + // Change of state i.e. a different class id. + // So the previous segment has ended. + if (start_frame != -1) { + // start_frame == -1 in the beginning of the alignment. That is just + // initialization step and hence no creation of segment. 
+ segmentation->EmplaceBack(start_frame + start_time_offset, + i-1 + start_time_offset, state); + num_segments++; + + if (frame_counts_per_class) + (*frame_counts_per_class)[state] += i - start_frame; + } + start_frame = i; + state = alignment[i]; + } + } + + KALDI_ASSERT(state >= -1 && start_frame >= 0 && start_frame < end); + segmentation->EmplaceBack(start_frame + start_time_offset, + end-1 + start_time_offset, state); + num_segments++; + if (frame_counts_per_class) + (*frame_counts_per_class)[state] += end - start_frame; + +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif + + return num_segments; +} + +int32 InsertFromSegmentation( + const Segmentation &in_segmentation, int32 start_time_offset, + bool sort, + Segmentation *out_segmentation, + std::vector *frame_counts_per_class) { + KALDI_ASSERT(out_segmentation); + + if (in_segmentation.Dim() == 0) return 0; // nothing to insert + + int32 num_segments = 0; + + for (SegmentList::const_iterator it = in_segmentation.Begin(); + it != in_segmentation.End(); ++it) { + out_segmentation->EmplaceBack(it->start_frame + start_time_offset, + it->end_frame + start_time_offset, + it->Label()); + num_segments++; + if (frame_counts_per_class) { + if (frame_counts_per_class->size() <= it->Label()) { + frame_counts_per_class->resize(it->Label() + 1, 0); + } + (*frame_counts_per_class)[it->Label()] += it->Length(); + } + } + + if (sort) out_segmentation->Sort(); + +#ifdef KALDI_PARANOID + out_segmentation->Check(); +#endif + + return num_segments; +} + +void ExtendSegmentation(const Segmentation &in_segmentation, + bool sort, + Segmentation *segmentation) { + InsertFromSegmentation(in_segmentation, 0, sort, segmentation, NULL); +} + +bool GetClassCountsPerFrame( + const Segmentation &segmentation, + int32 length, int32 tolerance, + std::vector > *class_counts_per_frame) { + KALDI_ASSERT(class_counts_per_frame); + + if (length != -1) { + KALDI_ASSERT(length >= 0); + class_counts_per_frame->resize(length, std::map()); + } + 
+ SegmentList::const_iterator it = segmentation.Begin(); + for (; it != segmentation.End(); ++it) { + if (length != -1 && it->end_frame >= length + tolerance) { + KALDI_WARN << "End frame (" << it->end_frame << ") " + << ">= length + tolerance (" << length + tolerance << ")." + << "Conversion failed."; + return false; + } + + int32 end_frame = it->end_frame; + if (length == -1) { + class_counts_per_frame->resize(it->end_frame + 1, + std::map()); + } else { + if (it->end_frame >= length) + end_frame = length - 1; + } + + KALDI_ASSERT(end_frame < class_counts_per_frame->size()); + for (int32 i = it->start_frame; i <= end_frame; i++) { + std::map &this_class_counts = (*class_counts_per_frame)[i]; + std::map::iterator c_it = this_class_counts.lower_bound( + it->Label()); + if (c_it == this_class_counts.end() || it->Label() < c_it->first) { + this_class_counts.insert(c_it, std::make_pair(it->Label(), 1)); + } else { + c_it->second++; + } + } + } + + return true; +} + +bool IsNonOverlapping(const Segmentation &segmentation) { + std::vector vec; + for (SegmentList::const_iterator it = segmentation.Begin(); + it != segmentation.End(); ++it) { + vec.resize(it->end_frame + 1, false); + for (int32 i = it->start_frame; i <= it->end_frame; i++) { + if (vec[i]) return false; + vec[i] = true; + } + } + return true; +} + +void Sort(Segmentation *segmentation) { + segmentation->Sort(); +} + +void TruncateToLength(int32 length, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it->start_frame >= length) { + it = segmentation->Erase(it); + continue; + } + + if (it->end_frame >= length) + it->end_frame = length - 1; + ++it; + } +} + +} // end namespace segmenter +} // end namespace kaldi diff --git a/src/segmenter/segmentation-utils.h b/src/segmenter/segmentation-utils.h new file mode 100644 index 00000000000..16e63710c6a --- /dev/null +++ b/src/segmenter/segmentation-utils.h @@ -0,0 +1,344 @@ +// 
segmenter/segmentation-utils.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_SEGMENTER_SEGMENTATION_UTILS_H_ +#define KALDI_SEGMENTER_SEGMENTATION_UTILS_H_ + +#include "segmenter/segmentation.h" + +namespace kaldi { +namespace segmenter { + +/** + * This function is very straight forward. It just merges the labels in + * merge_labels to the class-id dest_label. This means any segment that + * originally had the class-id as any of the labels in merge_labels would end + * up having the class-id dest_label. + **/ +void MergeLabels(const std::vector &merge_labels, + int32 dest_label, Segmentation *segmentation); + +// Relabel segments using a map from old to new label. +// If segment label is not found in the map, the function exits with +// an error. +void RelabelSegmentsUsingMap(const unordered_map &label_map, + Segmentation *segmentation); + +// Relabel all segments to class-id label +void RelabelAllSegments(int32 label, Segmentation *segmentation); + +// Scale frame shift by this factor. +// Usually frame length is 0.01 and frame shift 0.015. But sometimes +// the alignments are obtained using a subsampling factor of 3. This +// function can be used to maintain consistency among different +// alignments and segmentations. 
+void ScaleFrameShift(BaseFloat factor, Segmentation *segmentation); + +/** + * This is very straight forward. It removes all segments of label "label" +**/ +void RemoveSegments(int32 label, Segmentation *segmentation); + +/** + * This removes any segment whose label is + * contained in the vector "labels" and has a length smaller than + * max_remove_length. max_remove_length can be provided -1 to + * specify a value of +infinity i.e. to remove segments + * based on only the labels and irrespective of their lengths. +**/ +void RemoveSegments(const std::vector &labels, + int32 max_remove_length, + Segmentation *segmentation); + +void RemoveShortSegments(int32 label, int32 min_length, + Segmentation *segmentation); + +// Keep only segments of label "label" +void KeepSegments(int32 label, Segmentation *segmentation); + +/** + * This function splits an input segmentation in_segmentation into pieces of + * approximately segment_length. Each piece is given the same class id as the + * original segment. + * + * The way this function is written is that it first figures out the number of + * pieces that the segment must be broken into. Then it creates that many pieces + * of equal size (actual_segment_length). This mimics some of the approaches + * used at script level +**/ +void SplitInputSegmentation(const Segmentation &in_segmentation, + int32 segment_length, + Segmentation *out_segmentation); + +/** + * This function splits the segments in the the segmentation + * into pieces of segment_length. + * But if the last remaining piece is smaller than min_remainder, then the last + * piece is merged to the piece before it, resulting in a piece that is of + * length < segment_length + min_remainder. + * If overlap_length > 0, then the created pieces overlap by these many frames. + * If segment_label == -1, then all segments are split. + * Otherwise, only the segments with this label are split. 
+ * + * The way this function works it is it looks at the current segment length and + * checks if it is larger than segment_length + min_remainder. If it is larger, + * then it must be split. To do this, it first modifies the start_frame of + * the current frame to start_frame + segment_length - overlap. + * It then creates a new segment of length segment_length from the original + * start_frame to start_frame + segment_length - 1 and adds it just before the + * current segment. So in the next iteration, we would actually be back to the + * same segment, but whose start_frame had just been modified. +**/ +void SplitSegments(int32 segment_length, + int32 min_remainder, int32 overlap_length, + int32 segment_label, + Segmentation *segmentation); + +/** + * Split this segmentation into pieces of size segment_length, + * but only if possible by creating split points at the + * middle of the chunk where alignment == ali_label and + * the chunk is at least min_segment_length frames long + * + * min_remainder, segment_label serve the same purpose as in the + * above SplitSegments function. +**/ +void SplitSegmentsUsingAlignment(int32 segment_length, + int32 segment_label, + const std::vector &alignment, + int32 alignment_label, + int32 min_align_chunk_length, + Segmentation *segmentation); + +/** + * This function is a standard intersection of the set of times represented by + * the segmentation in_segmentation and the set of times of where + * alignment contains ali_label for at least min_align_chunk_length + * consecutive frames +**/ +void IntersectSegmentationAndAlignment(const Segmentation &in_segmentation, + const std::vector &alignment, + int32 ali_label, + int32 min_align_chunk_length, + Segmentation *out_segmentation); + +/** + * This function is a little complicated in what it does. But this is required + * for one of the applications. 
+ * This function creates a new segmentation by sub-segmenting an arbitrary + * "primary_segmentation" and assign new label "subsegment_label" to regions + * where the "primary_segmentation" intersects the non-overlapping + * "secondary_segmentation" segments with label "secondary_label". + * This is similar to the function "IntersectSegments", but instead of keeping + * only the filtered subsegments, all the subsegments are kept, while only + * changing the class_id of the filtered sub-segments. + * The label for the newly created subsegments is determined as follows: + * if secondary segment's label == secondary_label: + * if subsegment_label >= 0: + * label = subsegment_label + * else: + * label = secondary_label + * else: + * if unmatched_label >= 0: + * label = unmatched_label + * else: + * label = primary_label +**/ +void SubSegmentUsingNonOverlappingSegments( + const Segmentation &primary_segmentation, + const Segmentation &secondary_segmentation, int32 secondary_label, + int32 subsegment_label, int32 unmatched_label, + Segmentation *out_segmentation); + +/** + * This function is used to merge segments next to each other in the SegmentList + * and within a distance of max_intersegment_length frames from each other, + * provided the segments are of the same label. + * This function requires the segmentation to be sorted before passing it. + **/ +void MergeAdjacentSegments(int32 max_intersegment_length, + Segmentation *segmentation); + +/** + * This function is used to pad segments of label "label" by "length" + * frames on either side of the segment. + * This is useful to pad segments of speech. +**/ +void PadSegments(int32 label, int32 length, Segmentation *segmentation); + +/** + * This function is used to widen segments of label "label" by "length" + * frames on either side of the segment. + * This is similar to PadSegments, but while widening, it also reduces the + * length of the segment adjacent to it. 
+ * This may not be required in some applications, but it is ok for speech / + * silence. By this process, we are calling frames within a "length" number of + * frames near the speech segment as speech and hence we reduce the width of the + * silence segment before it. +**/ +void WidenSegments(int32 label, int32 length, Segmentation *segmentation); + +/** + * This function is used to shrink segments of class_id "label" by "length" + * frames on either side of the segment. + * If the whole segment is smaller than 2*length, then the segment is + * removed entirely. +**/ +void ShrinkSegments(int32 label, int32 length, Segmentation *segmentation); + +/** + * This function blends segments of label "label" that are shorter than + * "max_length" frames, provided the segments before and after it are of the + * same label "other_label" and the distance to the neighbor is less than + * "max_intersegment_distance". + * After blending, the three segments have the same label "other_label" and + * hence can be merged into a composite segment. + * An example where this is useful is when there is a short segment of silence + * with speech segments on either sides. Then the short segment of silence is + * removed and called speech instead. The three continguous segments of speech + * are merged into a single composite segment. +**/ +void BlendShortSegmentsWithNeighbors(int32 label, int32 max_length, + int32 max_intersegment_distance, + Segmentation *segmentation); + +/** + * This function is used to convert the segmentation into frame-level alignment + * with the label for each frame begin the class_id of segment the frame belongs + * to. + * The arguments are used to provided extended functionality that are required + * for most cases. + * default_label : the label that is used as filler in regions where the frame + * is not in any of the segments. In most applications, certain + * segments are removed, such as the ones that are silence. 
Then + * the segments would not span the entire duration of the file. + * e.g. + * 10 35 1 + * 41 190 2 + * ... + * Here there is no segment from 36-40. These frames are + * filled with default_label. + * length : the number of frames required in the alignment. + * If set to -1, then this length is ignored. + * In most applications, the length of the alignment required is + * known. Usually it must match the length of the features + * (obtained using feat-to-len). Then the alignment is resized + * to this length and filled with default_label. The segments + * are then read and the frames corresponding to the segments + * are relabeled with the class_id of the respective segments. + * tolerance : the tolerance in number of frames that we allow for the + * frame index corresponding to the end_frame of the last + * segment. Applicable when length != -1. + * Since, we use 25 ms widows with 10 ms frame shift, + * it is possible that the features length is 2 frames less than + * the end of the last segment. So the user can set the + * tolerance to 2 in order to avoid returning with error in this + * function. + * Function returns true is successful. +**/ +bool ConvertToAlignment(const Segmentation &segmentation, + int32 default_label, int32 length, + int32 tolerance, + std::vector *alignment); + +/** + * Insert segments created from alignment starting from frame index "start" + * until and excluding frame index "end". + * The inserted segments are shifted by "start_time_offset". + * "start_time_offset" is useful when the "alignment" is per-utterance, in which + * case the start time of the utterance can be provided as the + * "start_time_offset" + * The function returns the number of segments created. + * If "frame_counts_per_class" is provided, then the number of frames per class + * is accumulated there. 
+**/ +int32 InsertFromAlignment(const std::vector &alignment, + int32 start, int32 end, + int32 start_time_offset, + Segmentation *segmentation, + std::map *frame_counts_per_class = NULL); + +/** + * Insert segments from in_segmentation, but shift them by + * start_time offset. + * If sort is true, then the final segmentation is sorted. + * It is useful in some applications to set sort to false. + * Returns number of segments inserted. +**/ +int32 InsertFromSegmentation(const Segmentation &in_segmentation, + int32 start_time_offset, bool sort, + Segmentation *segmentation, + std::vector *frame_counts_per_class = NULL); + +/** + * Extend a segmentation by adding another one. + * If "sort" is set to true, then resultant segmentation would be sorted. + * If its known that the other segmentation must all be after this segmentation, + * then the user may set "sort" false. +**/ +void ExtendSegmentation(const Segmentation &in_segmentation, bool sort, + Segmentation *segmentation); + +/** + * This function is used to get per-frame count of number of classes. + * The output is in the format of a vector of maps. + * class_counts_per_frame: A pointer to a vector of maps used to get the output. + * The size of the vector is the number of frames. + * For each frame, there is a map from the "class_id" + * to the number of segments where the label the + * corresponding "class_id". + * The size of the map gives the number of unique + * labels in this frame e.g. number of speakers. + * The count for each "class_id" is the number + * of segments with that "class_id" at that frame. + * length : the number of frames required in the output. + * In most applications, this length is known. + * Usually it must match the length of the features (obtained + * using feat-to-len). Then the output is resized to this + * length. The map is empty for frames where no segments are + * seen. 
+ * tolerance : the tolerance in number of frames that we allow for the + * frame index corresponding to the end_frame of the last + * segment. Since, we use 25 ms widows with 10 ms frame shift, + * it is possible that the features length is 2 frames less than + * the end of the last segment. So the user can set the + * tolerance to 2 in order to avoid returning an error in this + * function. + * Function returns true is successful. +**/ +bool GetClassCountsPerFrame( + const Segmentation &segmentation, + int32 length, int32 tolerance, + std::vector > *class_counts_per_frame); + +// Checks if segmentation is non-overlapping +bool IsNonOverlapping(const Segmentation &segmentation); + +// Sorts segments on start frame. +void Sort(Segmentation *segmentation); + +// Truncate segmentation to "length". +// Removes any segments with "start_time" >= "length" +// and truncates any segments with "end_time" >= "length" +void TruncateToLength(int32 length, Segmentation *segmentation); + +} // end namespace segmenter +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_SEGMENTATION_UTILS_H_ diff --git a/src/segmenter/segmentation.cc b/src/segmenter/segmentation.cc new file mode 100644 index 00000000000..01f8b0e8057 --- /dev/null +++ b/src/segmenter/segmentation.cc @@ -0,0 +1,217 @@ +// segmenter/segmentation.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "segmenter/segmentation.h" +#include + +namespace kaldi { +namespace segmenter { + +void Segmentation::PushBack(const Segment &seg) { + dim_++; + segments_.push_back(seg); +} + +SegmentList::iterator Segmentation::Insert(SegmentList::iterator it, + const Segment &seg) { + dim_++; + return segments_.insert(it, seg); +} + +void Segmentation::EmplaceBack(int32 start_frame, int32 end_frame, + int32 class_id) { + dim_++; + Segment seg(start_frame, end_frame, class_id); + segments_.push_back(seg); +} + +SegmentList::iterator Segmentation::Emplace(SegmentList::iterator it, + int32 start_frame, int32 end_frame, + int32 class_id) { + dim_++; + Segment seg(start_frame, end_frame, class_id); + return segments_.insert(it, seg); +} + +SegmentList::iterator Segmentation::Erase(SegmentList::iterator it) { + dim_--; + return segments_.erase(it); +} + +void Segmentation::Clear() { + segments_.clear(); + dim_ = 0; +} + +void Segmentation::Read(std::istream &is, bool binary) { + Clear(); + + if (binary) { + int32 sz = is.peek(); + if (sz == Segment::SizeInBytes()) { + is.get(); + } else { + KALDI_ERR << "Segmentation::Read: expected to see Segment of size " + << Segment::SizeInBytes() << ", saw instead " << sz + << ", at file position " << is.tellg(); + } + + int32 segmentssz; + is.read(reinterpret_cast(&segmentssz), sizeof(segmentssz)); + if (is.fail() || segmentssz < 0) + KALDI_ERR << "Segmentation::Read: read failure at file position " + << is.tellg(); + + for (int32 i = 0; i < segmentssz; i++) { + Segment seg; + seg.Read(is, binary); + segments_.push_back(seg); + } + dim_ = segmentssz; + } else { + Segment seg; + while (1) { + int i = is.peek(); + if (i == -1) { + KALDI_ERR << "Unexpected EOF"; + } else if (static_cast(i) == '\n') { + if (seg.start_frame != -1) { + KALDI_ERR << "No semicolon before newline (wrong format)"; + } else { + is.get(); + break; 
+ } + } else if (std::isspace(i)) { + is.get(); + } else if (static_cast(i) == ';') { + if (seg.start_frame != -1) { + segments_.push_back(seg); + dim_++; + seg.Reset(); + } else { + is.get(); + KALDI_ASSERT(static_cast(is.peek()) == '\n'); + is.get(); + break; + } + is.get(); + } else { + seg.Read(is, false); + } + } + } +#ifdef KALDI_PARANOID + Check(); +#endif +} + +void Segmentation::Write(std::ostream &os, bool binary) const { +#ifdef KALDI_PARANOID + Check(); +#endif + + SegmentList::const_iterator it = Begin(); + if (binary) { + char sz = Segment::SizeInBytes(); + os.write(&sz, 1); + + int32 segmentssz = static_cast(Dim()); + KALDI_ASSERT((size_t)segmentssz == Dim()); + + os.write(reinterpret_cast(&segmentssz), sizeof(segmentssz)); + + for (; it != End(); ++it) { + it->Write(os, binary); + } + } else { + if (Dim() == 0) { + os << ";"; + } + for (; it != End(); ++it) { + it->Write(os, binary); + os << "; "; + } + os << std::endl; + } +} + +void Segmentation::Check() const { + int32 dim = 0; + for (SegmentList::const_iterator it = Begin(); it != End(); ++it, dim++) { + KALDI_ASSERT(it->start_frame >= 0); + KALDI_ASSERT(it->end_frame >= 0); + KALDI_ASSERT(it->Label() >= 0); + } + KALDI_ASSERT(dim == dim_); +} + +void Segmentation::Sort() { + segments_.sort(SegmentComparator()); +} + +void Segmentation::SortByLength() { + segments_.sort(SegmentLengthComparator()); +} + +SegmentList::iterator Segmentation::MinElement() { + return std::min_element(segments_.begin(), segments_.end(), + SegmentLengthComparator()); +} + +SegmentList::iterator Segmentation::MaxElement() { + return std::max_element(segments_.begin(), segments_.end(), + SegmentLengthComparator()); +} + +Segmentation::Segmentation() { + Clear(); +} + + +void Segmentation::GenRandomSegmentation(int32 max_length, + int32 max_segment_length, + int32 num_classes) { + Clear(); + int32 st = 0; + int32 end = 0; + + while (st < max_length) { + int32 segment_length = RandInt(1, max_segment_length); + + end = st + 
segment_length - 1; + + // Choose random class id + int32 k = RandInt(-1, num_classes - 1); + + if (k >= 0) { + Segment seg(st, end, k); + segments_.push_back(seg); + dim_++; + } + + // Choose random shift i.e. the distance between two adjacent segments + int32 shift = RandInt(0, max_segment_length); + st = end + shift; + } + + Check(); +} + +} // namespace segmenter +} // namespace kaldi diff --git a/src/segmenter/segmentation.h b/src/segmenter/segmentation.h new file mode 100644 index 00000000000..aa408374751 --- /dev/null +++ b/src/segmenter/segmentation.h @@ -0,0 +1,144 @@ +// segmenter/segmentation.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_SEGMENTER_SEGMENTATION_H_ +#define KALDI_SEGMENTER_SEGMENTATION_H_ + +#include +#include "base/kaldi-common.h" +#include "matrix/kaldi-matrix.h" +#include "util/kaldi-table.h" +#include "segmenter/segment.h" + +namespace kaldi { +namespace segmenter { + +// Segments are stored as a doubly-linked-list. This could be changed later +// if needed. Hence defining a typedef SegmentList. +typedef std::list SegmentList; + +// Declare class +class SegmentationPostProcessor; + +/** + * The main class to store segmentation and do operations on it. 
The segments + * are stored in the structure SegmentList, which is currently a doubly-linked + * list. + * See the .cc file for details of implementation of the different functions. + * This file gives only a small description of the functions. +**/ + +class Segmentation { + public: + // Inserts the segment at the back of the list. + void PushBack(const Segment &seg); + + // Inserts the segment before the segment at the position specified by the + // iterator "it". + SegmentList::iterator Insert(SegmentList::iterator it, + const Segment &seg); + + // The following function is a wrapper to the + // emplace_back functionality of a STL list of Segments + // and inserts a new segment to the back of the list. + void EmplaceBack(int32 start_frame, int32 end_frame, int32 class_id); + + // The following function is a wrapper to the + // emplace functionality of a STL list of segments + // and inserts a segment at the position specified by the iterator "it". + // Returns an iterator to the inserted segment. + SegmentList::iterator Emplace(SegmentList::iterator it, + int32 start_frame, int32 end_frame, + int32 class_id); + + // Call erase operation on the SegmentList and returns the iterator pointing + // to the next segment in the SegmentList and also decrements dim_. + SegmentList::iterator Erase(SegmentList::iterator it); + + // Reset segmentation i.e. clear all values + void Clear(); + + // Read segmentation object from input stream + void Read(std::istream &is, bool binary); + + // Write segmentation object to output stream + void Write(std::ostream &os, bool binary) const; + + // Check if all segments have class_id >=0 and if dim_ matches the number of + // segments. 
+ void Check() const; + + // Sort the segments on the start_frame + void Sort(); + + // Sort the segments on the length + void SortByLength(); + + // Returns an iterator to the smallest segment akin to std::min_element + SegmentList::iterator MinElement(); + + // Returns an iterator to the largest segment akin to std::max_element + SegmentList::iterator MaxElement(); + + // Generate a random segmentation for debugging purposes. + // Arguments: + // max_length: The maximum length of the random segmentation to be + // generated. + // max_segment_length: Maximum length of a segment in the segmentation + // num_classes: Maximum number of classes in the generated segmentation + void GenRandomSegmentation(int32 max_length, int32 max_segment_length, + int32 num_classes); + + // Public accessors + inline int32 Dim() const { return dim_; } + SegmentList::iterator Begin() { return segments_.begin(); } + SegmentList::const_iterator Begin() const { return segments_.begin(); } + SegmentList::iterator End() { return segments_.end(); } + SegmentList::const_iterator End() const { return segments_.end(); } + + Segment& Back() { return segments_.back(); } + const Segment& Back() const { return segments_.back(); } + + const SegmentList* Data() const { return &segments_; } + + // Default constructor + Segmentation(); + + private: + // number of segments in the segmentation + int32 dim_; + + // list of segments in the segmentation + SegmentList segments_; + + friend class SegmentationPostProcessor; +}; + +typedef TableWriter > SegmentationWriter; +typedef SequentialTableReader > + SequentialSegmentationReader; +typedef RandomAccessTableReader > + RandomAccessSegmentationReader; +typedef RandomAccessTableReaderMapped > + RandomAccessSegmentationReaderMapped; + +} // end namespace segmenter +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_SEGMENTATION_H_ diff --git a/src/segmenterbin/Makefile b/src/segmenterbin/Makefile new file mode 100644 index 00000000000..6e0036c6fb7 --- 
/dev/null +++ b/src/segmenterbin/Makefile @@ -0,0 +1,42 @@ + +all: + +EXTRA_CXXFLAGS = -Wno-sign-compare +include ../kaldi.mk + +BINFILES = segmentation-copy segmentation-get-stats \ + segmentation-init-from-ali segmentation-to-ali \ + segmentation-init-from-segments segmentation-to-segments \ + segmentation-combine-segments segmentation-merge-recordings \ + segmentation-create-subsegments segmentation-intersect-ali \ + segmentation-to-rttm segmentation-post-process \ + segmentation-merge segmentation-split-segments \ + segmentation-remove-segments \ + segmentation-init-from-lengths \ + segmentation-combine-segments-to-recordings \ + segmentation-create-overlapped-subsegments \ + segmentation-intersect-segments \ + segmentation-init-from-additive-signals-info \ + class-counts-per-frame-to-labels \ + agglomerative-cluster-ib \ + intersect-int-vectors \ + gmm-global-init-models-from-feats \ + segmentation-cluster-adjacent-segments \ + ib-scoring-dense #\ + gmm-acc-pdf-stats-segmentation \ + gmm-est-segmentation gmm-update-segmentation \ + segmentation-init-from-diarization \ + segmentation-compute-class-ctm-conf \ + combine-vector-segments + +OBJFILES = + + + +TESTFILES = + +ADDLIBS = ../hmm/kaldi-hmm.a ../gmm/kaldi-gmm.a ../segmenter/kaldi-segmenter.a ../tree/kaldi-tree.a \ + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../thread/kaldi-thread.a + +include ../makefiles/default_rules.mk + diff --git a/src/segmenterbin/agglomerative-cluster-ib.cc b/src/segmenterbin/agglomerative-cluster-ib.cc new file mode 100644 index 00000000000..489b24c24bc --- /dev/null +++ b/src/segmenterbin/agglomerative-cluster-ib.cc @@ -0,0 +1,160 @@ +// segmenterbin/agglomerative-cluster-ib.cc + +// Copyright 2017 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/cluster-utils.h" +#include "segmenter/information-bottleneck-cluster-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Cluster per-utterance probability distributions of " + "relevance variables using Information Bottleneck principle.\n" + "Usage: agglomerative-cluster-ib [options] " + " \n" + " e.g.: agglomerative-cluster-ib ark:avg_post.1.ark " + "ark,t:data/dev/reco2utt ark,t:labels.txt"; + + ParseOptions po(usage); + + InformationBottleneckClustererOptions opts; + + std::string reco2num_clusters_rspecifier; + std::string counts_rspecifier; + int32 junk_label = -2; + BaseFloat max_merge_thresh = std::numeric_limits::max(); + int32 min_clusters = 1; + + po.Register("reco2num-clusters-rspecifier", &reco2num_clusters_rspecifier, + "If supplied, clustering creates exactly this many clusters " + "for the corresponding recording."); + po.Register("counts-rspecifier", &counts_rspecifier, + "The counts for each of the initial segments. 
If not specified " + "the count is taken to be 1 for each segment."); + po.Register("junk-label", &junk_label, + "Assign this label to utterances that could not be clustered"); + po.Register("max-merge-thresh", &max_merge_thresh, + "Threshold on cost change from merging clusters; clusters " + "won't be merged if the cost is more than this."); + po.Register("min-clusters", &min_clusters, + "Mininum number of clusters desired; we'll stop merging " + "after reaching this number."); + + opts.Register(&po); + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string relevance_prob_rspecifier = po.GetArg(1), + reco2utt_rspecifier = po.GetArg(2), + label_wspecifier = po.GetArg(3); + + RandomAccessBaseFloatVectorReader relevance_prob_reader( + relevance_prob_rspecifier); + SequentialTokenVectorReader reco2utt_reader(reco2utt_rspecifier); + RandomAccessInt32Reader reco2num_clusters_reader( + reco2num_clusters_rspecifier); + Int32Writer label_writer(label_wspecifier); + RandomAccessBaseFloatReader counts_reader(counts_rspecifier); + + int32 count = 1, num_utt_err = 0, num_reco_err = 0, num_done = 0, + num_reco = 0; + + for (; !reco2utt_reader.Done(); reco2utt_reader.Next()) { + const std::vector &uttlist = reco2utt_reader.Value(); + const std::string &reco = reco2utt_reader.Key(); + + std::vector points; + points.reserve(uttlist.size()); + + int32 id = 0; + for (std::vector::const_iterator it = uttlist.begin(); + it != uttlist.end(); ++it, id++) { + if (!relevance_prob_reader.HasKey(*it)) { + KALDI_WARN << "Could not find relevance probability distribution " + << "for utterance " << *it << " in archive " + << relevance_prob_rspecifier; + num_utt_err++; + continue; + } + + if (!counts_rspecifier.empty()) { + if (!counts_reader.HasKey(*it)) { + KALDI_WARN << "Could not find counts for utterance " << *it; + num_utt_err++; + continue; + } + count = counts_reader.Value(*it); + } + + const Vector& relevance_prob = + 
relevance_prob_reader.Value(*it); + + points.push_back( + new InformationBottleneckClusterable(id, count, relevance_prob)); + num_done++; + } + + std::vector clusters_out; + std::vector assignments_out; + + int32 this_num_clusters = min_clusters; + + if (!reco2num_clusters_rspecifier.empty()) { + if (!reco2num_clusters_reader.HasKey(reco)) { + KALDI_WARN << "Could not find num-clusters for recording " + << reco; + num_reco_err++; + } else { + this_num_clusters = reco2num_clusters_reader.Value(reco); + } + } + + IBClusterBottomUp(points, opts, max_merge_thresh, this_num_clusters, + NULL, &assignments_out); + + for (int32 i = 0; i < points.size(); i++) { + InformationBottleneckClusterable* point + = static_cast (points[i]); + int32 id = point->Counts().begin()->first; + const std::string &utt = uttlist[id]; + label_writer.Write(utt, assignments_out[i] + 1); + } + + DeletePointers(&points); + num_reco++; + } + + KALDI_LOG << "Clustered " << num_done << " segments from " + << num_reco << " recordings; failed with " + << num_utt_err << " segments and " + << num_reco_err << " recordings."; + + return (num_done > 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/segmenterbin/class-counts-per-frame-to-labels.cc b/src/segmenterbin/class-counts-per-frame-to-labels.cc new file mode 100644 index 00000000000..85676794e95 --- /dev/null +++ b/src/segmenterbin/class-counts-per-frame-to-labels.cc @@ -0,0 +1,115 @@ +// segmenterbin/class-counts-per-frame-to-labels.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "hmm/posterior.h"
+
+
+// Converts per-frame class counts (stored as a Posterior, i.e. one list of
+// (class-id, count) pairs per frame) into a single integer label per frame:
+//   junk-label : frame contains the junk class, or no classes at all;
+//   0          : only silence (class 0) seen;
+//   1          : total non-silence count is 1 (single speaker);
+//   2          : total non-silence count > 1 (overlapping speakers).
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+
+    const char *usage =
+        "Converts class-counts-per-frame in the format of vectors of vectors of "
+        "integers into labels for overlapping SAD.\n"
+        "If there is a junk-label in the classes in the frame, then the label "
+        "for the frame is set to the junk-label no matter what other labels "
+        "are present.\n"
+        "If there is only a 0 (silence) in the classes in the frame, then the "
+        "label for the frame is set to 0.\n"
+        "If there is only one non-zero non-junk class, then the label is set "
+        "to 1.\n"
+        "Otherwise, the label is set to 2 (overlapping speakers)\n"
+        "\n"
+        "Usage: class-counts-per-frame-to-labels [options] "
+        "<class-counts-rspecifier> <labels-wspecifier>\n";
+
+    int32 junk_label = -1;
+    ParseOptions po(usage);
+
+    po.Register("junk-label", &junk_label,
+                "The label used for segments that are junk. If a frame has "
+                "a junk label, it will be considered junk segment, no matter "
+                "what other labels the frame contains. Also frames with no "
+                "classes seen are labeled junk.");
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 2) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    std::string in_fn = po.GetArg(1),
+        out_fn = po.GetArg(2);
+
+    int32 num_done = 0;  // was plain 'int'; int32 for consistency with file.
+    Int32VectorWriter writer(out_fn);
+    SequentialPosteriorReader reader(in_fn);
+    for (; !reader.Done(); reader.Next(), num_done++) {
+      const Posterior &class_counts_per_frame = reader.Value();
+      // One label per frame; default to junk until proven otherwise.
+      std::vector<int32> labels(class_counts_per_frame.size(), junk_label);
+
+      for (size_t i = 0; i < class_counts_per_frame.size(); i++) {
+        const std::vector<std::pair<int32, BaseFloat> > &class_counts =
+            class_counts_per_frame[i];
+
+        if (class_counts.empty()) {
+          labels[i] = junk_label;  // No classes seen for this frame.
+        } else {
+          bool silence_found = false, junk_found = false;
+          std::vector<std::pair<int32, BaseFloat> >::const_iterator it =
+              class_counts.begin();
+          // Sum of counts over non-silence classes in this frame.
+          int32 class_counts_in_frame = 0;
+          for (; it != class_counts.end(); ++it) {
+            KALDI_ASSERT(it->second > 0);
+            if (it->first == 0) {
+              silence_found = true;
+            } else {
+              class_counts_in_frame += static_cast<int32>(it->second);
+              if (it->first == junk_label) {
+                junk_found = true;  // Junk overrides everything else.
+                break;
+              }
+            }
+          }
+
+          // BUGFIX: previously the count-based relabeling below ran even
+          // after junk was found, overwriting the junk label with 1 or 2.
+          if (junk_found) {
+            labels[i] = junk_label;
+          } else if (class_counts_in_frame == 0) {
+            KALDI_ASSERT(silence_found);
+            labels[i] = 0;
+          } else if (class_counts_in_frame == 1) {
+            labels[i] = 1;
+          } else {
+            labels[i] = 2;
+          }
+        }
+      }
+      writer.Write(reader.Key(), labels);
+    }
+    KALDI_LOG << "Copied " << num_done << " items.";
+    return (num_done != 0 ? 0 : 1);
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
+
+
+
diff --git a/src/segmenterbin/gmm-global-init-models-from-feats.cc b/src/segmenterbin/gmm-global-init-models-from-feats.cc
new file mode 100644
index 00000000000..c323306df83
--- /dev/null
+++ b/src/segmenterbin/gmm-global-init-models-from-feats.cc
@@ -0,0 +1,485 @@
+// gmmbin/gmm-global-init-models-from-feats.cc
+
+// Copyright 2013  Johns Hopkins University (author: Daniel Povey)
+//           2016  Vimal Manohar
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "gmm/model-common.h"
+#include "gmm/full-gmm.h"
+#include "gmm/diag-gmm.h"
+#include "gmm/mle-full-gmm.h"
+
+namespace kaldi {
+
+// We initialize the GMM parameters by setting the variance to the global
+// variance of the features, and the means to distinct randomly chosen frames.
+void InitGmmFromRandomFrames(const MatrixBase &feats, DiagGmm *gmm) { + int32 num_gauss = gmm->NumGauss(), num_frames = feats.NumRows(), + dim = feats.NumCols(); + KALDI_ASSERT(num_frames >= 10 * num_gauss && "Too few frames to train on"); + Vector mean(dim), var(dim); + for (int32 i = 0; i < num_frames; i++) { + mean.AddVec(1.0 / num_frames, feats.Row(i)); + var.AddVec2(1.0 / num_frames, feats.Row(i)); + } + var.AddVec2(-1.0, mean); + if (var.Max() <= 0.0) + KALDI_ERR << "Features do not have positive variance " << var; + + DiagGmmNormal gmm_normal(*gmm); + + std::set used_frames; + for (int32 g = 0; g < num_gauss; g++) { + int32 random_frame = RandInt(0, num_frames - 1); + while (used_frames.count(random_frame) != 0) + random_frame = RandInt(0, num_frames - 1); + used_frames.insert(random_frame); + gmm_normal.weights_(g) = 1.0 / num_gauss; + gmm_normal.means_.Row(g).CopyFromVec(feats.Row(random_frame)); + gmm_normal.vars_.Row(g).CopyFromVec(var); + } + gmm->CopyFromNormal(gmm_normal); + gmm->ComputeGconsts(); +} + +void MleDiagGmmSharedVarsUpdate(const MleDiagGmmOptions &config, + const AccumDiagGmm &diag_gmm_acc, + GmmFlagsType flags, + DiagGmm *gmm, + BaseFloat *obj_change_out, + BaseFloat *count_out, + int32 *floored_elements_out = NULL, + int32 *floored_gaussians_out = NULL, + int32 *removed_gaussians_out = NULL) { + KALDI_ASSERT(gmm != NULL); + + if (flags & ~diag_gmm_acc.Flags()) + KALDI_ERR << "Flags in argument do not match the active accumulators"; + + KALDI_ASSERT(diag_gmm_acc.NumGauss() == gmm->NumGauss() && + diag_gmm_acc.Dim() == gmm->Dim()); + + int32 num_gauss = gmm->NumGauss(); + double occ_sum = diag_gmm_acc.occupancy().Sum(); + + int32 elements_floored = 0, gauss_floored = 0; + + // remember old objective value + gmm->ComputeGconsts(); + BaseFloat obj_old = MlObjective(*gmm, diag_gmm_acc); + + // First get the gmm in "normal" representation (not the exponential-model + // form). 
+ DiagGmmNormal ngmm(*gmm); + + Vector shared_var(gmm->Dim()); + + std::vector to_remove; + for (int32 i = 0; i < num_gauss; i++) { + double occ = diag_gmm_acc.occupancy()(i); + double prob; + if (occ_sum > 0.0) + prob = occ / occ_sum; + else + prob = 1.0 / num_gauss; + + if (occ > static_cast(config.min_gaussian_occupancy) + && prob > static_cast(config.min_gaussian_weight)) { + + ngmm.weights_(i) = prob; + + // copy old mean for later normalizations + Vector old_mean(ngmm.means_.Row(i)); + + // update mean, then variance, as far as there are accumulators + if (diag_gmm_acc.Flags() & (kGmmMeans|kGmmVariances)) { + Vector mean(diag_gmm_acc.mean_accumulator().Row(i)); + mean.Scale(1.0 / occ); + // transfer to estimate + ngmm.means_.CopyRowFromVec(mean, i); + } + + if (diag_gmm_acc.Flags() & kGmmVariances) { + KALDI_ASSERT(diag_gmm_acc.Flags() & kGmmMeans); + Vector var(diag_gmm_acc.variance_accumulator().Row(i)); + var.Scale(1.0 / occ); + var.AddVec2(-1.0, ngmm.means_.Row(i)); // subtract squared means. + + // if we intend to only update the variances, we need to compensate by + // adding the difference between the new and old mean + if (!(flags & kGmmMeans)) { + old_mean.AddVec(-1.0, ngmm.means_.Row(i)); + var.AddVec2(1.0, old_mean); + } + shared_var.AddVec(occ, var); + } + } else { // Insufficient occupancy. + if (config.remove_low_count_gaussians && + static_cast(to_remove.size()) < num_gauss-1) { + // remove the component, unless it is the last one. + KALDI_WARN << "Too little data - removing Gaussian (weight " + << std::fixed << prob + << ", occupation count " << std::fixed << diag_gmm_acc.occupancy()(i) + << ", vector size " << gmm->Dim() << ")"; + to_remove.push_back(i); + } else { + KALDI_WARN << "Gaussian has too little data but not removing it because" + << (config.remove_low_count_gaussians ? 
+ " it is the last Gaussian: i = " + : " remove-low-count-gaussians == false: g = ") << i + << ", occ = " << diag_gmm_acc.occupancy()(i) << ", weight = " << prob; + ngmm.weights_(i) = + std::max(prob, static_cast(config.min_gaussian_weight)); + } + } + } + + if (diag_gmm_acc.Flags() & kGmmVariances) { + int32 floored; + if (config.variance_floor_vector.Dim() != 0) { + floored = shared_var.ApplyFloor(config.variance_floor_vector); + } else { + floored = shared_var.ApplyFloor(config.min_variance); + } + if (floored != 0) { + elements_floored += floored; + gauss_floored++; + } + + shared_var.Scale(1.0 / occ_sum); + for (int32 i = 0; i < num_gauss; i++) { + ngmm.vars_.CopyRowFromVec(shared_var, i); + } + } + + // copy to natural representation according to flags + ngmm.CopyToDiagGmm(gmm, flags); + + gmm->ComputeGconsts(); // or MlObjective will fail. + BaseFloat obj_new = MlObjective(*gmm, diag_gmm_acc); + + if (obj_change_out) + *obj_change_out = (obj_new - obj_old); + if (count_out) *count_out = occ_sum; + if (floored_elements_out) *floored_elements_out = elements_floored; + if (floored_gaussians_out) *floored_gaussians_out = gauss_floored; + + if (to_remove.size() > 0) { + gmm->RemoveComponents(to_remove, true /*renormalize weights*/); + gmm->ComputeGconsts(); + } + if (removed_gaussians_out != NULL) *removed_gaussians_out = to_remove.size(); + + if (gauss_floored > 0) + KALDI_VLOG(2) << gauss_floored << " variances floored in " << gauss_floored + << " Gaussians."; +} + + +void TrainOneIter(const MatrixBase &feats, + const MleDiagGmmOptions &gmm_opts, + int32 iter, + int32 num_threads, + bool share_covars, + DiagGmm *gmm) { + AccumDiagGmm gmm_acc(*gmm, kGmmAll); + + Vector frame_weights(feats.NumRows(), kUndefined); + frame_weights.Set(1.0); + + double tot_like; + tot_like = gmm_acc.AccumulateFromDiagMultiThreaded(*gmm, feats, frame_weights, + num_threads); + + KALDI_LOG << "Likelihood per frame on iteration " << iter + << " was " << (tot_like / feats.NumRows()) << 
" over " + << feats.NumRows() << " frames."; + + BaseFloat objf_change, count; + if (share_covars) { + MleDiagGmmSharedVarsUpdate(gmm_opts, gmm_acc, kGmmAll, gmm, + &objf_change, &count); + } else { + MleDiagGmmUpdate(gmm_opts, gmm_acc, kGmmAll, gmm, &objf_change, &count); + } + + KALDI_LOG << "Objective-function change on iteration " << iter << " was " + << (objf_change / count) << " over " << count << " frames."; +} + +void TrainGmm(const MatrixBase &feats, + const MleDiagGmmOptions &gmm_opts, + int32 num_gauss, int32 num_gauss_init, int32 num_iters, + int32 num_threads, bool share_covars, DiagGmm *gmm) { + KALDI_LOG << "Initializing GMM means from random frames to " + << num_gauss_init << " Gaussians."; + InitGmmFromRandomFrames(feats, gmm); + + // we'll increase the #Gaussians by splitting, + // till halfway through training. + int32 cur_num_gauss = num_gauss_init, + gauss_inc = (num_gauss - num_gauss_init) / (num_iters / 2); + + for (int32 iter = 0; iter < num_iters; iter++) { + TrainOneIter(feats, gmm_opts, iter, num_threads, share_covars, gmm); + + int32 next_num_gauss = std::min(num_gauss, cur_num_gauss + gauss_inc); + if (next_num_gauss > gmm->NumGauss()) { + KALDI_LOG << "Splitting to " << next_num_gauss << " Gaussians."; + gmm->Split(next_num_gauss, 0.1); + cur_num_gauss = next_num_gauss; + } + } +} + +} // namespace kaldi + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "This program initializes a single diagonal GMM and does multiple iterations of\n" + "training from features stored in memory.\n" + "Usage: gmm-global-init-from-feats [options] \n" + "e.g.: gmm-global-init-from-feats scp:train.scp ark:1.ark\n"; + + ParseOptions po(usage); + MleDiagGmmOptions gmm_opts; + + bool binary = true; + int32 num_gauss = 100; + int32 num_gauss_init = 0; + int32 max_gauss = 0; + int32 min_gauss = 0; + int32 num_iters = 50; + int32 num_frames = 200000; + int32 srand_seed = 0; + int32 num_threads = 4; + BaseFloat 
num_gauss_fraction = -1; + bool share_covars = false; + std::string spk2utt_rspecifier; + + po.Register("binary", &binary, "Write output in binary mode"); + po.Register("num-gauss", &num_gauss, "Number of Gaussians in the model"); + po.Register("num-gauss-init", &num_gauss_init, "Number of Gaussians in " + "the model initially (if nonzero and less than num_gauss, " + "we'll do mixture splitting)"); + po.Register("num-iters", &num_iters, "Number of iterations of training"); + po.Register("num-frames", &num_frames, "Number of feature vectors to store in " + "memory and train on (randomly chosen from the input features)"); + po.Register("srand", &srand_seed, "Seed for random number generator "); + po.Register("num-threads", &num_threads, "Number of threads used for " + "statistics accumulation"); + po.Register("spk2utt-rspecifier", &spk2utt_rspecifier, + "If specified, estimates models per-speaker"); + po.Register("num-gauss-fraction", &num_gauss_fraction, + "If specified, chooses the number of gaussians to be " + "num-gauss-fraction * min(num-frames-available, num-frames). " + "This number is expected to be in the range(0, 0.1)."); + po.Register("max-gauss", &max_gauss, "Maximum number of Gaussians allowed " + "in the model. Applicable when num_gauss_fraction is specified."); + po.Register("min-gauss", &min_gauss, "Minimum number of Gaussians allowed " + "in the model. 
Applicable when num_gauss_fraction is specified."); + po.Register("share-covars", &share_covars, "If true, then the variances " + "of the Gaussian components are tied."); + + gmm_opts.Register(&po); + + po.Read(argc, argv); + + srand(srand_seed); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + if (num_gauss_fraction != -1) { + KALDI_ASSERT(num_gauss_fraction > 0 && num_gauss_fraction < 0.1); + } + + KALDI_ASSERT(max_gauss >= 0 && min_gauss >= 0 && max_gauss >= min_gauss); + + std::string feature_rspecifier = po.GetArg(1), + model_wspecifier = po.GetArg(2); + + DiagGmmWriter gmm_writer(model_wspecifier); + + KALDI_ASSERT(num_frames > 0); + + if (spk2utt_rspecifier.empty()) { + KALDI_LOG << "Reading features (will keep " << num_frames << " frames " + << "per utterance.)"; + } else { + KALDI_LOG << "Reading features (will keep " << num_frames << " frames " + << "per speaker.)"; + } + + int32 dim = 0; + + if (spk2utt_rspecifier.empty()) { + SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); + for (; !feature_reader.Done(); feature_reader.Next()) { + const Matrix &this_feats = feature_reader.Value(); + if (dim == 0) { + dim = this_feats.NumCols(); + } else if (this_feats.NumCols() != dim) { + KALDI_ERR << "Features have inconsistent dims " + << this_feats.NumCols() << " vs. 
" << dim + << " (current utt is) " << feature_reader.Key(); + } + + Matrix feats(num_frames, dim); + int64 num_read = 0; + + for (int32 t = 0; t < this_feats.NumRows(); t++) { + num_read++; + if (num_read <= num_frames) { + feats.Row(num_read - 1).CopyFromVec(this_feats.Row(t)); + } else { + BaseFloat keep_prob = num_frames / static_cast(num_read); + if (WithProb(keep_prob)) { // With probability "keep_prob" + feats.Row(RandInt(0, num_frames - 1)).CopyFromVec(this_feats.Row(t)); + } + } + } + + KALDI_ASSERT(num_read > 0); + + if (num_read < num_frames) { + KALDI_WARN << "For utterance " << feature_reader.Key() << ", " + << "number of frames read " << num_read << " was less than " + << "target number " << num_frames << ", using all we read."; + feats.Resize(num_read, dim, kCopyData); + } else { + BaseFloat percent = num_frames * 100.0 / num_read; + KALDI_LOG << "For utterance " << feature_reader.Key() << ", " + << "kept " << num_frames << " out of " << num_read + << " input frames = " << percent << "%."; + } + + int32 this_num_gauss_init = num_gauss_init; + int32 this_num_gauss = num_gauss; + + if (num_gauss_fraction != -1) { + this_num_gauss = feats.NumRows() * num_gauss_fraction; + if (this_num_gauss > max_gauss) + this_num_gauss = max_gauss; + if (this_num_gauss < min_gauss) + this_num_gauss = min_gauss; + } + + if (this_num_gauss_init <= 0 || this_num_gauss_init > this_num_gauss) + this_num_gauss_init = this_num_gauss; + + DiagGmm gmm(this_num_gauss_init, dim); + TrainGmm(feats, gmm_opts, this_num_gauss, this_num_gauss_init, + num_iters, num_threads, share_covars, &gmm); + + gmm_writer.Write(feature_reader.Key(), gmm); + } + KALDI_LOG << "Done initializing GMMs."; + } else { + SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); + RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier); + + int32 num_err = 0; + for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { + const std::vector &uttlist = spk2utt_reader.Value(); + + Matrix feats; 
+ int64 num_read = 0; + + for (std::vector::const_iterator it = uttlist.begin(); + it != uttlist.end(); ++it) { + if (!feature_reader.HasKey(*it)) { + KALDI_WARN << "Could not find features for utterance " << *it; + num_err++; + } + + const Matrix &this_feats = feature_reader.Value(*it); + if (feats.NumCols() == 0) { + dim = this_feats.NumCols(); + feats.Resize(num_frames, dim); + } else if (this_feats.NumCols() != dim) { + KALDI_ERR << "Features have inconsistent dims " + << this_feats.NumCols() << " vs. " << dim + << " (current utt is) " << *it; + } + + for (int32 t = 0; t < this_feats.NumRows(); t++) { + num_read++; + if (num_read <= num_frames) { + feats.Row(num_read - 1).CopyFromVec(this_feats.Row(t)); + } else { + BaseFloat keep_prob = num_frames / static_cast(num_read); + if (WithProb(keep_prob)) { // With probability "keep_prob" + feats.Row(RandInt(0, num_frames - 1)).CopyFromVec(this_feats.Row(t)); + } + } + } + } + + KALDI_ASSERT(num_read > 0); + + if (num_read < num_frames) { + KALDI_WARN << "For speaker " << spk2utt_reader.Key() << ", " + << "number of frames read " << num_read << " was less than " + << "target number " << num_frames << ", using all we read."; + feats.Resize(num_read, dim, kCopyData); + } else { + BaseFloat percent = num_frames * 100.0 / num_read; + KALDI_LOG << "For spekear " << spk2utt_reader.Key() << ", " + << "kept " << num_frames << " out of " << num_read + << " input frames = " << percent << "%."; + } + + int32 this_num_gauss_init = num_gauss_init; + int32 this_num_gauss = num_gauss; + + if (num_gauss_fraction != -1) { + this_num_gauss = feats.NumRows() * num_gauss_fraction; + if (this_num_gauss > max_gauss) + this_num_gauss = max_gauss; + if (this_num_gauss < min_gauss) + this_num_gauss = min_gauss; + } + + if (this_num_gauss_init <= 0 || this_num_gauss_init > this_num_gauss) + this_num_gauss_init = this_num_gauss; + + DiagGmm gmm(this_num_gauss_init, dim); + TrainGmm(feats, gmm_opts, this_num_gauss, this_num_gauss_init, + 
num_iters, num_threads, share_covars, &gmm); + + gmm_writer.Write(spk2utt_reader.Key(), gmm); + } + + KALDI_LOG << "Done initializing GMMs. Failed getting features for " + << num_err << "utterances"; + } + + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/segmenterbin/intersect-int-vectors.cc b/src/segmenterbin/intersect-int-vectors.cc new file mode 100644 index 00000000000..0611dd513e1 --- /dev/null +++ b/src/segmenterbin/intersect-int-vectors.cc @@ -0,0 +1,160 @@ +// segmenterbin/intersect-int-vectors.cc + +// Copyright 2017 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+
+// Intersects two integer vectors element-wise: each output id is a unique
+// integer assigned to the pair (id1, id2) observed at that position, either
+// from a user-supplied --mapping-in table or assigned on the fly.
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+
+    const char *usage =
+        "Intersect two integer vectors and create a new integer vector "
+        "whose ids are defined as the cross-products of the integer "
+        "ids from the two vectors.\n"
+        "\n"
+        "Usage: intersect-int-vectors [options] "
+        "<ali-rspecifier1> <ali-rspecifier2> <ali-wspecifier>\n"
+        " e.g.: intersect-int-vectors ark:1.ali ark:2.ali ark:-\n"
+        "See also: segmentation-init-from-segments, "
+        "segmentation-combine-segments\n";
+
+    ParseOptions po(usage);
+
+    std::string mapping_rxfilename, mapping_wxfilename;
+    int32 length_tolerance = 0;
+
+    po.Register("mapping-in", &mapping_rxfilename,
+                "A file with three columns that define the mapping from "
+                "a pair of integers to a third one.");
+    po.Register("mapping-out", &mapping_wxfilename,
+                "Write a mapping in the same format as --mapping-in, "
+                "but let the program decide the mapping to unique integer "
+                "ids.");
+    po.Register("length-tolerance", &length_tolerance,
+                "Tolerance this number of frames of mismatch between the "
+                "two integer vector pairs.");
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 3) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    std::string ali_rspecifier1 = po.GetArg(1),
+        ali_rspecifier2 = po.GetArg(2),
+        ali_wspecifier = po.GetArg(3);
+
+    // Mapping from a pair of input ids to the output id.
+    std::map<std::pair<int32, int32>, int32> mapping;
+    if (!mapping_rxfilename.empty()) {
+      Input ki(mapping_rxfilename);
+      std::string line;
+      while (std::getline(ki.Stream(), line)) {
+        std::vector<std::string> parts;
+        SplitStringToVector(line, " ", true, &parts);
+        KALDI_ASSERT(parts.size() == 3);
+
+        std::pair<int32, int32> id_pair = std::make_pair(
+            std::atoi(parts[0].c_str()), std::atoi(parts[1].c_str()));
+        int32 id_new = std::atoi(parts[2].c_str());
+        KALDI_ASSERT(id_new >= 0);
+
+        std::map<std::pair<int32, int32>, int32>::iterator it =
+            mapping.lower_bound(id_pair);
+        // Duplicate pairs in the mapping file are an error.
+        KALDI_ASSERT(it == mapping.end() || it->first != id_pair);
+
+        mapping.insert(it, std::make_pair(id_pair, id_new));
+      }
+    }
+
+    SequentialInt32VectorReader ali_reader1(ali_rspecifier1);
+    RandomAccessInt32VectorReader ali_reader2(ali_rspecifier2);
+
+    Int32VectorWriter ali_writer(ali_wspecifier);
+
+    int32 num_ids = 0, num_err = 0, num_done = 0;
+
+    for (; !ali_reader1.Done(); ali_reader1.Next()) {
+      const std::string &key = ali_reader1.Key();
+
+      if (!ali_reader2.HasKey(key)) {
+        KALDI_WARN << "Could not find second alignment for key " << key
+                   << " in " << ali_rspecifier2;
+        num_err++;
+        continue;
+      }
+
+      const std::vector<int32> &alignment1 = ali_reader1.Value();
+      const std::vector<int32> &alignment2 = ali_reader2.Value(key);
+
+      // BUGFIX: the original check was one-sided and missed the case where
+      // the second alignment is longer than the first; use the absolute
+      // difference.  The pair is still processed up to the common length.
+      if (std::abs(static_cast<int32>(alignment1.size())
+                   - static_cast<int32>(alignment2.size()))
+          > length_tolerance) {
+        KALDI_WARN << "Mismatch in length of alignments in "
+                   << ali_rspecifier1 << " and " << ali_rspecifier2
+                   << "; " << alignment1.size() << " vs "
+                   << alignment2.size();
+        num_err++;
+      }
+
+      int32 min_length = std::min(static_cast<int32>(alignment1.size()),
+                                  static_cast<int32>(alignment2.size()));
+      std::vector<int32> alignment_out(min_length);
+
+      for (int32 i = 0; i < min_length; i++) {
+        std::pair<int32, int32> id_pair = std::make_pair(
+            alignment1[i], alignment2[i]);
+
+        std::map<std::pair<int32, int32>, int32>::iterator it =
+            mapping.lower_bound(id_pair);
+
+        int32 id_new = -1;
+        if (!mapping_rxfilename.empty()) {
+          // Fixed mapping supplied by the user; unseen pairs are fatal.
+          if (it == mapping.end() || it->first != id_pair) {
+            KALDI_ERR << "Could not find id-pair (" << id_pair.first
+                      << ", " << id_pair.second
+                      << ") in mapping " << mapping_rxfilename;
+          }
+          id_new = it->second;
+        } else {
+          // Assign fresh ids on the fly for previously unseen pairs.
+          if (it == mapping.end() || it->first != id_pair) {
+            id_new = ++num_ids;
+            mapping.insert(it, std::make_pair(id_pair, id_new));
+          } else {
+            id_new = it->second;
+          }
+        }
+
+        alignment_out[i] = id_new;
+      }
+
+      ali_writer.Write(key, alignment_out);
+      num_done++;
+    }
+
+    // BUGFIX: --mapping-out was registered but never honored; write the
+    // (possibly newly assigned) mapping in the --mapping-in format.
+    if (!mapping_wxfilename.empty()) {
+      Output ko(mapping_wxfilename, false /* text mode */);
+      for (std::map<std::pair<int32, int32>, int32>::const_iterator it =
+               mapping.begin(); it != mapping.end(); ++it) {
+        ko.Stream() << it->first.first << ' ' << it->first.second
+                    << ' ' << it->second << '\n';
+      }
+    }
+
+    KALDI_LOG << "Intersected " << num_done << " int vector pairs; "
+              << "failed with " << num_err;
+
+    return ((num_done > 0 && num_err < num_done) ? 0 : 1);
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
+
diff --git a/src/segmenterbin/segmentation-cluster-adjacent-segments.cc b/src/segmenterbin/segmentation-cluster-adjacent-segments.cc
new file mode 100644
index 00000000000..fde13cd7ead
--- /dev/null
+++ b/src/segmenterbin/segmentation-cluster-adjacent-segments.cc
@@ -0,0 +1,333 @@
+// segmenterbin/segmentation-cluster-adjacent-segments.cc
+
+// Copyright 2017   Vimal Manohar (Johns Hopkins University)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" +#include "tree/clusterable-classes.h" + +namespace kaldi { +namespace segmenter { + +BaseFloat Distance(const Segment &seg1, const Segment &seg2, + const MatrixBase &feats, + BaseFloat var_floor, + int32 length_tolerance = 2) { + int32 start1 = seg1.start_frame; + int32 end1 = seg1.end_frame; + + int32 start2 = seg2.start_frame; + int32 end2 = seg2.end_frame; + + if (end1 > feats.NumRows() + length_tolerance) { + KALDI_ERR << "Segment end > feature length; " << end1 + << " vs " << feats.NumRows(); + } + + GaussClusterable stats1(feats.NumCols(), var_floor); + for (int32 i = start1; i < std::min(end1, feats.NumRows()); i++) { + stats1.AddStats(feats.Row(i)); + } + Vector means1(stats1.x_stats()); + means1.Scale(1.0 / stats1.count()); + Vector vars1(stats1.x2_stats()); + vars1.Scale(1.0 / stats1.count()); + vars1.AddVec2(-1.0, means1); + vars1.ApplyFloor(var_floor); + + GaussClusterable stats2(feats.NumCols(), var_floor); + for (int32 i = start2; i < std::min(end2, feats.NumRows()); i++) { + stats2.AddStats(feats.Row(i)); + } + Vector means2(stats2.x_stats()); + means2.Scale(1.0 / stats2.count()); + Vector vars2(stats2.x2_stats()); + vars2.Scale(1.0 / stats2.count()); + vars2.AddVec2(-1.0, means2); + vars2.ApplyFloor(var_floor); + + double ans = 0.0; + for (int32 i = 0; i < feats.NumCols(); i++) { + ans += (vars1(i) / vars2(i) + vars2(i) / vars1(i) + + (means2(i) - means1(i)) * (means2(i) - means1(i)) + * (1.0 / vars1(i) + 1.0 / vars2(i))); + } + + return ans; +} + +int32 ClusterAdjacentSegments(const MatrixBase &feats, + BaseFloat absolute_distance_threshold, + BaseFloat delta_distance_threshold, + BaseFloat var_floor, + int32 length_tolerance, + Segmentation *segmentation) { + if (segmentation->Dim() <= 3) { + // Very unusual case. + // TODO: Do something more reasonable. 
+ return 1; + } + + + SegmentList::iterator it = segmentation->Begin(), + next_it = segmentation->Begin(); + ++next_it; + + // Vector storing for each segment, whether there is a change point at the + // beginning of the segment. + std::vector is_change_point(segmentation->Dim(), false); + is_change_point[0] = true; + + Vector distances(segmentation->Dim() - 1); + int32 i = 0; + + for (; next_it != segmentation->End(); ++it, ++next_it, i++) { + // Distance between segment i and i + 1 + distances(i) = Distance(*it, *next_it, feats, + var_floor, length_tolerance); + + if (i > 2) { + if (distances(i-1) - distances(i-2) > delta_distance_threshold && + distances(i) - distances(i-1) < -delta_distance_threshold) { + is_change_point[i-1] = true; + } + } else { + if (distances(i) - distances(i-1) > absolute_distance_threshold) + is_change_point[i] = true; + } + } + + int32 num_classes = 0; + for (i = 0, it = segmentation->Begin(); + it != segmentation->End(); ++it, i++) { + if (is_change_point[i]) { + num_classes++; + } + it->SetLabel(num_classes); + } + + return num_classes; + /* + BaseFloat prev_dist = Distance(*it, *next_it, feats, + var_floor, length_tolerance); + + if (segmentation->Dim() == 2) { + it->SetLabel(1); + if (prev_dist < absolute_distance_threshold * feats.NumCols() + && next_it->start_frame <= it->end_frame) { + // Similar segments merged. + next_it->SetLabel(it->Label()); + } else { + // Segments not merged. + next_it->SetLabel(it->Label() + 1); + } + + return next_it->Label();; + } + + // The algorithm is a simple peak detection. + // Consider three segments that are pointed by the iterators + // prev_it, it, next_it. + // If Distance(prev_it, it) > Consider + ++it; + ++next_it; + bool next_segment_is_new_cluster = false; + + for (; next_it != segmentation->End(); ++it, ++next_it) { + SegmentList::iterator prev_it(it); + --prev_it; + + // Compute distance between this and next segment. 
+ BaseFloat dist = Distance(*it, *next_it, feats, var_floor, + length_tolerance); + + // Possibly merge current segment if previous. + if (next_segment_is_new_cluster || + (prev_it->end_frame + 1 >= it->start_frame && + prev_dist < absolute_distance_threshold * feats.NumCols())) { + // Previous and current segment are next to each other. + // Merge current segment with previous. + it->SetLabel(prev_it->Label()); + + KALDI_VLOG(3) << "Merging clusters " << *prev_it << " and " << *it + << " ; dist = " << prev_dist; + } else { + it->SetLabel(prev_it->Label() + 1); + KALDI_VLOG(3) << "Not merging merging cluster " << *prev_it + << " and " << *it << " ; dist = " << prev_dist; + } + + // Decide if the current segment must be merged with next. + if (prev_it->end_frame + 1 >= it->start_frame && + it->end_frame + 1 >= next_it->start_frame) { + // All 3 segments are adjacent. + if (dist - prev_dist > delta_distance_threshold * feats.NumCols()) { + // Next segment is very different from the current and previous segment. + // So create a new cluster for the next segment. + next_segment_is_new_cluster = true; + } else { + next_segment_is_new_cluster = false; + } + } + + prev_dist = dist; + } + + SegmentList::iterator prev_it(it); + --prev_it; + if (next_segment_is_new_cluster || + (prev_it->end_frame + 1 >= it->start_frame && + prev_dist < absolute_distance_threshold * feats.NumCols())) { + // Merge current segment with previous. 
+ it->SetLabel(prev_it->Label()); + + KALDI_VLOG(3) << "Merging clusters " << *prev_it << " and " << *it + << " ; dist = " << prev_dist; + } else { + it->SetLabel(prev_it->Label() + 1); + } + + return it->Label(); + */ +} + +} // end segmenter +} // end kaldi + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Merge adjacent segments that are similar to each other.\n" + "\n" + "Usage: segmentation-cluster-adjacent-segments [options] " + " \n" + " e.g.: segmentation-cluster-adjacent-segments ark:foo.seg ark:feats.ark ark,t:-\n" + "See also: segmentation-merge, segmentation-merge-recordings, " + "segmentation-post-process --merge-labels\n"; + + bool binary = true; + int32 length_tolerance = 2; + BaseFloat var_floor = 0.01; + BaseFloat absolute_distance_threshold = 3.0; + BaseFloat delta_distance_threshold = 0.0002; + + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("length-tolerance", &length_tolerance, + "Tolerate length difference between segmentation and " + "features if its less than this many frames."); + po.Register("variance-floor", &var_floor, + "Variance floor of Gaussians used in computing distances " + "for clustering."); + po.Register("absolute-distance-threshold", &absolute_distance_threshold, + "Maximum per-dim distance below which segments will not be " + "be merged."); + po.Register("delta-distance-threshold", &delta_distance_threshold, + "If the delta-distance is below this value, then it will " + "be treated as 0."); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + feats_in_fn = po.GetArg(2), + segmentation_out_fn = po.GetArg(3); + + // all these "fn"'s are either rspecifiers or filenames. 
+ bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + if (!in_is_rspecifier) { + Segmentation segmentation; + ReadKaldiObject(segmentation_in_fn, &segmentation); + + Matrix feats; + ReadKaldiObject(feats_in_fn, &feats); + + Sort(&segmentation); + int32 num_clusters = ClusterAdjacentSegments( + feats, absolute_distance_threshold, delta_distance_threshold, + var_floor, length_tolerance, + &segmentation); + + KALDI_LOG << "Clustered segments; got " << num_clusters << " clusters."; + WriteKaldiObject(segmentation, segmentation_out_fn, binary); + + return 0; + } else { + int32 num_done = 0, num_err = 0; + + SequentialSegmentationReader segmentation_reader(segmentation_in_fn); + RandomAccessBaseFloatMatrixReader feats_reader(feats_in_fn); + SegmentationWriter segmentation_writer(segmentation_out_fn); + + for (; !segmentation_reader.Done(); segmentation_reader.Next()) { + Segmentation segmentation(segmentation_reader.Value()); + const std::string &key = segmentation_reader.Key(); + + if (!feats_reader.HasKey(key)) { + KALDI_WARN << "Could not find key " << key << " in " + << "feats-rspecifier " << feats_in_fn; + num_err++; + continue; + } + + const MatrixBase &feats = feats_reader.Value(key); + + Sort(&segmentation); + int32 num_clusters = ClusterAdjacentSegments( + feats, absolute_distance_threshold, delta_distance_threshold, + var_floor, length_tolerance, + &segmentation); + KALDI_VLOG(2) << "For key " << key << ", got " << num_clusters + << " clusters."; + + segmentation_writer.Write(key, segmentation); + num_done++; + } + + KALDI_LOG << "Clustered segments from " << num_done << " recordings " + << "failed with " << num_err; + return (num_done != 0 ? 
0 : 1);
+    }
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
+
+
diff --git a/src/segmenterbin/segmentation-combine-segments-to-recordings.cc b/src/segmenterbin/segmentation-combine-segments-to-recordings.cc
new file mode 100644
index 00000000000..acf71265577
--- /dev/null
+++ b/src/segmenterbin/segmentation-combine-segments-to-recordings.cc
@@ -0,0 +1,114 @@
+// segmenterbin/segmentation-combine-segments-to-recordings.cc
+
+// Copyright 2015-16  Vimal Manohar (Johns Hopkins University)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "segmenter/segmentation-utils.h"
+
+// Combines single-segment utterance segmentations (as produced by
+// segmentation-init-from-segments) into one segmentation per recording,
+// using a reco2utt map to group utterances.
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+    using namespace segmenter;
+
+    const char *usage =
+        "Combine kaldi segments in segmentation format to "
+        "recording-level segmentation\n"
+        "A reco2utt file is used to specify which utterances are contained "
+        "in a recording.\n"
+        "This program expects the input segmentation to be a kaldi segment "
+        "converted to segmentation using segmentation-init-from-segments. "
+        "For other segmentations, the user can use the binary "
+        "segmentation-combine-segments instead.\n"
+        "\n"
+        "Usage: segmentation-combine-segments-to-recording [options] "
+        "<segmentation-rspecifier> <reco2utt-rspecifier> "
+        "<segmentation-wspecifier>\n"
+        " e.g.: segmentation-combine-segments-to-recording \\\n"
+        "'ark:segmentation-init-from-segments --shift-to-zero=false "
+        "data/dev/segments ark:- |' ark,t:data/dev/reco2utt ark:file.seg\n"
+        "See also: segmentation-combine-segments, "
+        "segmentation-merge, segmentation-merge-recordings, "
+        "segmentation-post-process --merge-adjacent-segments\n";
+
+    ParseOptions po(usage);
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 3) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    std::string segmentation_rspecifier = po.GetArg(1),
+        reco2utt_rspecifier = po.GetArg(2),
+        segmentation_wspecifier = po.GetArg(3);
+
+    SequentialTokenVectorReader reco2utt_reader(reco2utt_rspecifier);
+    RandomAccessSegmentationReader segmentation_reader(
+        segmentation_rspecifier);
+    SegmentationWriter segmentation_writer(segmentation_wspecifier);
+
+    int32 num_done = 0, num_segmentations = 0, num_err = 0;
+
+    for (; !reco2utt_reader.Done(); reco2utt_reader.Next()) {
+      const std::vector<std::string> &utts = reco2utt_reader.Value();
+      const std::string &reco_id = reco2utt_reader.Key();
+
+      Segmentation out_segmentation;
+
+      for (std::vector<std::string>::const_iterator it = utts.begin();
+           it != utts.end(); ++it) {
+        if (!segmentation_reader.HasKey(*it)) {
+          KALDI_WARN << "Could not find utterance " << *it << " in "
+                     << "segments segmentation "
+                     << segmentation_rspecifier;
+          num_err++;
+          continue;
+        }
+
+        const Segmentation &segmentation = segmentation_reader.Value(*it);
+        // Each input must be exactly one kaldi segment in segmentation form.
+        if (segmentation.Dim() != 1) {
+          KALDI_ERR << "Segments segmentation for utt " << *it << " is not "
+                    << "kaldi segment converted to segmentation format "
+                    << "in " << segmentation_rspecifier;
+        }
+        const Segment &segment = *(segmentation.Begin());
+
+        out_segmentation.PushBack(segment);
+
+        num_done++;
+      }
+
+      // Segments arrive in utterance order; sort by start time.
+      Sort(&out_segmentation);
+      segmentation_writer.Write(reco_id, out_segmentation);
+      num_segmentations++;
+    }
+
+    KALDI_LOG << "Combined " << num_done << " utterance-level segments "
+              << "into " << num_segmentations
+              << " recording-level segmentations; failed with "
+              << num_err << " utterances.";
+
+    return ((num_done > 0 && num_err < num_done) ? 0 : 1);
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
+
diff --git a/src/segmenterbin/segmentation-combine-segments.cc b/src/segmenterbin/segmentation-combine-segments.cc
new file mode 100644
index 00000000000..1d745ca91f9
--- /dev/null
+++ b/src/segmenterbin/segmentation-combine-segments.cc
@@ -0,0 +1,142 @@
+// segmenterbin/segmentation-combine-segments.cc
+
+// Copyright 2015-16  Vimal Manohar (Johns Hopkins University)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Combine utterance-level segmentations in an archive to " + "recording-level segmentations using the kaldi segments to map " + "utterances to their positions in the recordings.\n" + "A reco2utt file is used to specify which utterances belong to each " + "recording.\n" + "\n" + "Usage: segmentation-combine-segments [options] " + " " + " " + " \n" + " e.g.: segmentation-combine-segments ark:utt.seg " + "'ark:segmentation-init-from-segments --shift-to-zero=false " + "data/dev/segments ark:- |' ark,t:data/dev/reco2utt ark:file.seg\n" + "See also: segmentation-combine-segments-to-recording, " + "segmentation-merge, segmentatin-merge-recordings, " + "segmentation-post-process --merge-adjacent-segments\n"; + + bool include_missing = false; + + ParseOptions po(usage); + + po.Register("include-missing-utt-level-segmentations", &include_missing, + "If true, then the segmentations missing in " + "utt-level-segmentation-rspecifier is included in the " + "final output with the label taken from the " + "kaldi-segments-segmentation-rspecifier"); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + std::string utt_segmentation_rspecifier = po.GetArg(1), + segments_segmentation_rspecifier = po.GetArg(2), + reco2utt_rspecifier = po.GetArg(3), + segmentation_wspecifier = po.GetArg(4); + + SequentialTokenVectorReader reco2utt_reader(reco2utt_rspecifier); + RandomAccessSegmentationReader segments_segmentation_reader( + segments_segmentation_rspecifier); + RandomAccessSegmentationReader utt_segmentation_reader( + utt_segmentation_rspecifier); + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_done = 0, num_segmentations = 0, num_err = 0; + int64 num_segments = 0; + + for (; 
!reco2utt_reader.Done(); reco2utt_reader.Next()) { + const std::vector &utts = reco2utt_reader.Value(); + const std::string &reco_id = reco2utt_reader.Key(); + + Segmentation out_segmentation; + + for (std::vector::const_iterator it = utts.begin(); + it != utts.end(); ++it) { + if (!segments_segmentation_reader.HasKey(*it)) { + KALDI_WARN << "Could not find utterance " << *it << " in " + << "segments segmentation " + << segments_segmentation_rspecifier; + num_err++; + continue; + } + + const Segmentation &segments_segmentation = + segments_segmentation_reader.Value(*it); + if (segments_segmentation.Dim() != 1) { + KALDI_ERR << "Segments segmentation for utt " << *it << " is not " + << "kaldi segment converted to segmentation format " + << "in " << segments_segmentation_rspecifier; + } + const Segment &segment = *(segments_segmentation.Begin()); + + if (!utt_segmentation_reader.HasKey(*it)) { + KALDI_WARN << "Could not find utterance " << *it << " in " + << "segmentation " << utt_segmentation_rspecifier + << (include_missing ? "; using default segmentation": ""); + if (!include_missing) { + num_err++; + } else { + out_segmentation.PushBack(segment); + num_segments++; + } + continue; + } + + const Segmentation &utt_segmentation + = utt_segmentation_reader.Value(*it); + num_segments += InsertFromSegmentation(utt_segmentation, + segment.start_frame, false, + &out_segmentation, NULL); + num_done++; + } + + Sort(&out_segmentation); + segmentation_writer.Write(reco_id, out_segmentation); + num_segmentations++; + } + + KALDI_LOG << "Combined " << num_done << " utterance-level segmentations " + << "into " << num_segmentations + << " recording-level segmentations; failed with " + << num_err << " utterances; " + << "wrote a total of " << num_segments << " segments."; + + return ((num_done > 0 && num_err < num_done) ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-copy.cc b/src/segmenterbin/segmentation-copy.cc new file mode 100644 index 00000000000..b7e215b55f8 --- /dev/null +++ b/src/segmenterbin/segmentation-copy.cc @@ -0,0 +1,257 @@ +// segmenterbin/segmentation-copy.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Copy segmentation or archives of segmentation.\n" + "If label-map is supplied, then apply the mapping to the labels \n" + "when copying.\n" + "If utt2label-map-rspecifier is supplied, then an utterance-specific " + "mapping is applied on the original labels\n" + "\n" + "Usage: segmentation-copy [options] " + "\n" + " e.g.: segmentation-copy ark:1.seg ark,t:-\n" + " or \n" + " segmentation-copy [options] " + "\n" + " e.g.: segmentation-copy --binary=false foo -\n"; + + bool binary = true; + std::string label_map_rxfilename, utt2label_map_rspecifier; + std::string include_rxfilename, exclude_rxfilename; + int32 keep_label = -1; + BaseFloat frame_subsampling_factor = 1; + + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("label-map", &label_map_rxfilename, + "File with mapping from old to new labels. " + "If new label is -1, then that segment is removed."); + po.Register("frame-subsampling-factor", &frame_subsampling_factor, + "Change frame rate by this factor"); + po.Register("utt2label-map-rspecifier", &utt2label_map_rspecifier, + "Utterance-specific mapping from old to new labels. " + "The first column is the utterance id. The next columns are " + "pairs :. If is -1, then " + "that represents the default label map. i.e. 
Any old label " + "for which the mapping is not defined, will be mapped to the " + "label corresponding to old-label -1."); + po.Register("keep-label", &keep_label, + "If supplied, only segments of this label are written out"); + po.Register("include", &include_rxfilename, + "Text file, the first field of each" + " line being interpreted as an " + "utterance-id whose features will be included"); + po.Register("exclude", &exclude_rxfilename, + "Text file, the first field of each " + "line being interpreted as an utterance-id" + " whose features will be excluded"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + // all these "fn"'s are either rspecifiers or filenames. + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + // Read mapping from old to new labels + unordered_map label_map; + if (!label_map_rxfilename.empty()) { + Input ki(label_map_rxfilename); + std::string line; + while (std::getline(ki.Stream(), line)) { + std::vector splits; + SplitStringToVector(line, " ", true, &splits); + + if (splits.size() != 2) + KALDI_ERR << "Invalid format of line " << line + << " in " << label_map_rxfilename; + + label_map[std::atoi(splits[0].c_str())] = std::atoi(splits[1].c_str()); + } + } + + unordered_set include_set; + if (include_rxfilename != "") { + if (exclude_rxfilename != "") { + KALDI_ERR << "should not have both --exclude and --include option!"; + } + Input ki(include_rxfilename); + std::string line; + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + SplitStringToVector(line, " \t\r", true, &split_line); + KALDI_ASSERT(!split_line.empty() && + "Empty line encountered in input from --include option"); + include_set.insert(split_line[0]); + } + } + + unordered_set exclude_set; + if (exclude_rxfilename != "") { + if (include_rxfilename != "") { + KALDI_ERR << "should not have both --exclude and --include option!"; + } + Input ki(exclude_rxfilename); + std::string 
line; + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + SplitStringToVector(line, " \t\r", true, &split_line); + KALDI_ASSERT(!split_line.empty() && + "Empty line encountered in input from --exclude option"); + exclude_set.insert(split_line[0]); + } + } + + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + + if (!label_map_rxfilename.empty()) + RelabelSegmentsUsingMap(label_map, &segmentation); + + if (keep_label != -1) + KeepSegments(keep_label, &segmentation); + + if (frame_subsampling_factor != 1.0) { + ScaleFrameShift(frame_subsampling_factor, &segmentation); + } + + if (!utt2label_map_rspecifier.empty()) + KALDI_ERR << "It makes no sense to specify utt2label-map-rspecifier " + << "when not reading segmentation archives."; + + Output ko(segmentation_out_fn, binary); + segmentation.Write(ko.Stream(), binary); + + KALDI_LOG << "Copied segmentation to " << segmentation_out_fn; + return 0; + } else { + RandomAccessTokenVectorReader utt2label_map_reader( + utt2label_map_rspecifier); + + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + + for (; !reader.Done(); reader.Next()) { + const std::string &key = reader.Key(); + + if (include_rxfilename != "" && include_set.count(key) == 0) { + continue; + } + + if (exclude_rxfilename != "" && include_set.count(key) > 0) { + continue; + } + + if (label_map_rxfilename.empty() && + frame_subsampling_factor == 1.0 && + utt2label_map_rspecifier.empty() && + keep_label == -1) { + 
writer.Write(key, reader.Value()); + } else { + Segmentation segmentation = reader.Value(); + if (!label_map_rxfilename.empty()) + RelabelSegmentsUsingMap(label_map, &segmentation); + + if (!utt2label_map_rspecifier.empty()) { + if (!utt2label_map_reader.HasKey(key)) { + KALDI_WARN << "Utterance " << key + << " not found in utt2label_map " + << utt2label_map_rspecifier; + num_err++; + continue; + } + + unordered_map utt_label_map; + + const std::vector &utt_label_map_vec = + utt2label_map_reader.Value(key); + std::vector::const_iterator it = + utt_label_map_vec.begin(); + + for (; it != utt_label_map_vec.end(); ++it) { + std::vector vec; + SplitStringToFloats(*it, ":", false, &vec); + if (vec.size() != 2) { + KALDI_ERR << "Invalid utt-label-map " << *it; + } + utt_label_map[static_cast(vec[0])] = + static_cast(vec[1]); + } + + RelabelSegmentsUsingMap(utt_label_map, &segmentation); + } + + if (keep_label != -1) + KeepSegments(keep_label, &segmentation); + + if (frame_subsampling_factor != 1.0) + ScaleFrameShift(frame_subsampling_factor, &segmentation); + + writer.Write(key, segmentation); + } + + num_done++; + } + + KALDI_LOG << "Copied " << num_done << " segmentation; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-create-subsegments.cc b/src/segmenterbin/segmentation-create-subsegments.cc new file mode 100644 index 00000000000..9d7f4c08b6d --- /dev/null +++ b/src/segmenterbin/segmentation-create-subsegments.cc @@ -0,0 +1,175 @@ +// segmenterbin/segmentation-create-subsegments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Create sub-segmentation of a segmentation by intersecting with " + "segments from a 'filter' segmentation. \n" + "The labels for the new subsegments are decided " + "depending on whether the label of 'filter' segment " + "matches the specified 'filter_label' or not:\n" + " if filter segment's label == filter_label: \n" + " if subsegment_label is specified:\n" + " label = subsegment_label\n" + " else: \n" + " label = filter_label \n" + " else: \n" + " if unmatched_label is specified:\n" + " label = unmatched_label\n" + " else\n:" + " label = primary_label\n" + "See the function SubSegmentUsingNonOverlappingSegments() " + "for more details.\n" + "\n" + "Usage: segmentation-create-subsegments [options] " + " " + " \n" + " or : segmentation-create-subsegments [options] " + " " + " \n" + " e.g.: segmentation-create-subsegments --binary=false " + "--filter-label=1 --subsegment-label=1000 foo bar -\n" + " segmentation-create-subsegments --filter-label=1 " + "--subsegment-label=1000 ark:1.foo ark:1.bar ark:-\n"; + + bool binary = true, ignore_missing = false; + int32 filter_label = -1, subsegment_label = -1, unmatched_label = -1; + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + 
po.Register("filter-label", &filter_label, + "The label on which filtering is done."); + po.Register("subsegment-label", &subsegment_label, + "If non-negative, change the class-id of the matched regions " + "in the intersection of the two segmentations to this label."); + po.Register("unmatched-label", &unmatched_label, + "If non-negative, change the class-id of the unmatched " + "regions in the intersection of the two segmentations " + "to this label."); + po.Register("ignore-missing", &ignore_missing, "Ignore missing " + "segmentations in filter. If this is set true, then the " + "segmentations with missing key in filter are written " + "without any modification."); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + secondary_segmentation_in_fn = po.GetArg(2), + segmentation_out_fn = po.GetArg(3); + + // all these "fn"'s are either rspecifiers or filenames. + + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + filter_is_rspecifier = + (ClassifyRspecifier(secondary_segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier || + in_is_rspecifier != filter_is_rspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + Segmentation secondary_segmentation; + { + bool binary_in; + Input ki(secondary_segmentation_in_fn, &binary_in); + secondary_segmentation.Read(ki.Stream(), binary_in); + } + + Segmentation new_segmentation; + SubSegmentUsingNonOverlappingSegments( + segmentation, secondary_segmentation, filter_label, subsegment_label, + unmatched_label, &new_segmentation); 
+ Output ko(segmentation_out_fn, binary); + new_segmentation.Write(ko.Stream(), binary); + + KALDI_LOG << "Created subsegments of " << segmentation_in_fn + << " based on " << secondary_segmentation_in_fn + << " and wrote to " << segmentation_out_fn; + return 0; + } else { + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + RandomAccessSegmentationReader filter_reader( + secondary_segmentation_in_fn); + + for (; !reader.Done(); reader.Next(), num_done++) { + const Segmentation &segmentation = reader.Value(); + const std::string &key = reader.Key(); + + if (!filter_reader.HasKey(key)) { + KALDI_WARN << "Could not find filter segmentation for utterance " + << key; + if (!ignore_missing) + num_err++; + else + writer.Write(key, segmentation); + continue; + } + const Segmentation &secondary_segmentation = filter_reader.Value(key); + + Segmentation new_segmentation; + SubSegmentUsingNonOverlappingSegments(segmentation, + secondary_segmentation, + filter_label, subsegment_label, + unmatched_label, + &new_segmentation); + + writer.Write(key, new_segmentation); + } + + KALDI_LOG << "Created subsegments for " << num_done << " segmentations; " + << "failed with " << num_err << " segmentations"; + + return ((num_done != 0 && num_err < num_done) ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-get-stats.cc b/src/segmenterbin/segmentation-get-stats.cc new file mode 100644 index 00000000000..1e39bafec44 --- /dev/null +++ b/src/segmenterbin/segmentation-get-stats.cc @@ -0,0 +1,139 @@ +// segmenterbin/segmentation-get-per-frame-stats.cc + +// Copyright 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include +#include "base/kaldi-common.h" +#include "hmm/posterior.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Get per-frame stats from segmentation. \n" + "Currently supported stats are \n" + " num-overlaps: Number of overlapping segments common to this frame\n" + " num-classes: Number of distinct classes common to this frame\n" + "\n" + "Usage: segmentation-get-stats [options] " + " " + "\n" + " e.g.: segmentation-get-stats ark:1.seg ark:/dev/null " + "ark:num_classes.ark ark:/dev/null\n"; + + ParseOptions po(usage); + + std::string lengths_rspecifier; + int32 length_tolerance = 2; + + po.Register("lengths-rspecifier", &lengths_rspecifier, + "Archive of frame lengths of the utterances. 
" + "Fills up any extra length with zero stats."); + po.Register("length-tolerance", &length_tolerance, + "Tolerate shortage of this many frames in the specified " + "lengths file"); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_rspecifier = po.GetArg(1), + num_overlaps_wspecifier = po.GetArg(2), + num_classes_wspecifier = po.GetArg(3), + class_counts_per_frame_wspecifier = po.GetArg(4); + + int64 num_done = 0, num_err = 0; + + SequentialSegmentationReader reader(segmentation_rspecifier); + Int32VectorWriter num_overlaps_writer(num_overlaps_wspecifier); + Int32VectorWriter num_classes_writer(num_classes_wspecifier); + PosteriorWriter class_counts_per_frame_writer( + class_counts_per_frame_wspecifier); + + RandomAccessInt32Reader lengths_reader(lengths_rspecifier); + + for (; !reader.Done(); reader.Next(), num_done++) { + const Segmentation &segmentation = reader.Value(); + const std::string &key = reader.Key(); + + int32 length = -1; + if (!lengths_rspecifier.empty()) { + if (!lengths_reader.HasKey(key)) { + KALDI_WARN << "Could not find length for key " << key; + num_err++; + continue; + } + length = lengths_reader.Value(key); + } + + std::vector > class_counts_map_per_frame; + if (!GetClassCountsPerFrame(segmentation, length, + length_tolerance, + &class_counts_map_per_frame)) { + KALDI_WARN << "Failed getting stats for key " << key; + num_err++; + continue; + } + + if (length == -1) + length = class_counts_map_per_frame.size(); + + std::vector num_classes_per_frame(length, 0); + std::vector num_overlaps_per_frame(length, 0); + Posterior class_counts_per_frame(length, + std::vector >()); + + for (int32 i = 0; i < class_counts_map_per_frame.size(); i++) { + std::map &class_counts = class_counts_map_per_frame[i]; + + for (std::map::const_iterator it = class_counts.begin(); + it != class_counts.end(); ++it) { + if (it->second > 0) { + num_classes_per_frame[i]++; + 
class_counts_per_frame[i].push_back( + std::make_pair(it->first, it->second)); + } + num_overlaps_per_frame[i] += it->second; + } + std::sort(class_counts_per_frame[i].begin(), + class_counts_per_frame[i].end()); + } + + num_classes_writer.Write(key, num_classes_per_frame); + num_overlaps_writer.Write(key, num_overlaps_per_frame); + class_counts_per_frame_writer.Write(key, class_counts_per_frame); + + num_done++; + } + + KALDI_LOG << "Got stats for " << num_done << " segmentations; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-init-from-additive-signals-info.cc b/src/segmenterbin/segmentation-init-from-additive-signals-info.cc new file mode 100644 index 00000000000..abf5aed219b --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-additive-signals-info.cc @@ -0,0 +1,151 @@ +// segmenterbin/segmentation-init-from-overlap-info.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Convert overlapping segments information into segmentation\n" + "\n" + "Usage: segmentation-init-from-additive-signals-info [options] " + " \n" + " e.g.: segmentation-init-from-additive-signals-info --additive-signals-segmentation-rspecifier=ark:utt_segmentation.ark " + "ark,t:overlapped_segments_info.txt ark:-\n"; + + BaseFloat frame_shift = 0.01; + int32 junk_label = -2; + std::string lengths_rspecifier; + std::string additive_signals_segmentation_rspecifier; + + ParseOptions po(usage); + + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); + po.Register("lengths-rspecifier", &lengths_rspecifier, + "Archive of lengths for recordings; if provided, will be " + "used to truncate the output segmentation."); + po.Register("additive-signals-segmentation-rspecifier", + &additive_signals_segmentation_rspecifier, + "Archive of segmentation of the additive signal which will used " + "instead of an all 1 segmentation"); + po.Register("junk-label", &junk_label, + "The unreliable regions are labeled with this label"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string additive_signals_info_rspecifier = po.GetArg(1), + segmentation_wspecifier = po.GetArg(2); + + SequentialTokenVectorReader additive_signals_info_reader( + additive_signals_info_rspecifier); + SegmentationWriter writer(segmentation_wspecifier); + + RandomAccessSegmentationReader additive_signals_segmentation_reader( + additive_signals_segmentation_rspecifier); + RandomAccessInt32Reader lengths_reader(lengths_rspecifier); + + int32 num_done = 0, num_err = 0; + + for (; !additive_signals_info_reader.Done(); + additive_signals_info_reader.Next()) { + const std::string &key = additive_signals_info_reader.Key(); + 
const std::vector &additive_signals_info = + additive_signals_info_reader.Value(); + + Segmentation segmentation; + + for (size_t i = 0; i < additive_signals_info.size(); i++) { + std::vector parts; + SplitStringToVector(additive_signals_info[i], ",:", false, &parts); + + if (parts.size() != 3) { + KALDI_ERR << "Invalid format of overlap info " + << additive_signals_info[i] + << "for key " << key << " in " + << additive_signals_info_rspecifier; + } + const std::string &utt_id = parts[0]; + double start_time; + double duration; + ConvertStringToReal(parts[1], &start_time); + ConvertStringToReal(parts[2], &duration); + + int32 start_frame = round(start_time / frame_shift); + + if (!additive_signals_segmentation_reader.HasKey(utt_id)) { + KALDI_WARN << "Could not find utterance " << utt_id << " in " + << "segmentation " + << additive_signals_segmentation_rspecifier + << ". Assiginng the segment --junk-label."; + if (duration < 0) { + KALDI_ERR << "duration < 0 for utt_id " << utt_id << " in " + << "additive_signals_info " + << additive_signals_info_rspecifier + << "; additive-signals-segmentation must be provided " + << "in such a case"; + } + num_err++; + int32 length = round(duration / frame_shift); + segmentation.EmplaceBack(start_frame, start_frame + length - 1, + junk_label); + continue; // Treated as non-overlapping even though there + // is overlap + } + + InsertFromSegmentation( + additive_signals_segmentation_reader.Value(utt_id), + start_frame, false, &segmentation); + } + + Sort(&segmentation); + if (!lengths_rspecifier.empty()) { + if (!lengths_reader.HasKey(key)) { + KALDI_WARN << "Could not find length for the recording " << key + << "in " << lengths_rspecifier; + continue; + } + TruncateToLength(lengths_reader.Value(key), &segmentation); + } + writer.Write(key, segmentation); + + num_done++; + } + + KALDI_LOG << "Successfully processed " << num_done << " recordings " + << " in additive signals info" + << "; could not get segmentation for " << num_err + 
<< "additive signals."; + + return (num_done > num_err / 2 ? 0 : 1); + + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/segmenterbin/segmentation-init-from-ali.cc b/src/segmenterbin/segmentation-init-from-ali.cc new file mode 100644 index 00000000000..452ff56c2d8 --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-ali.cc @@ -0,0 +1,95 @@ +// segmenterbin/segmentation-init-from-ali.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Initialize utterance-level segmentations from alignments file. \n" + "The user can pass this to segmentation-combine-segments to " + "create recording-level segmentations." 
+ "\n" + "Usage: segmentation-init-from-ali [options] " + " \n" + " e.g.: segmentation-init-from-ali ark:1.ali ark:-\n" + "See also: segmentation-init-from-segments, " + "segmentation-combine-segments\n"; + + ParseOptions po(usage); + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string ali_rspecifier = po.GetArg(1), + segmentation_wspecifier = po.GetArg(2); + + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_done = 0, num_segmentations = 0; + int64 num_segments = 0; + int64 num_err = 0; + + std::map frame_counts_per_class; + + SequentialInt32VectorReader alignment_reader(ali_rspecifier); + + for (; !alignment_reader.Done(); alignment_reader.Next()) { + const std::string &key = alignment_reader.Key(); + const std::vector &alignment = alignment_reader.Value(); + + Segmentation segmentation; + + num_segments += InsertFromAlignment(alignment, 0, alignment.size(), + 0, &segmentation, + &frame_counts_per_class); + + Sort(&segmentation); + segmentation_writer.Write(key, segmentation); + + num_done++; + num_segmentations++; + } + + KALDI_LOG << "Processed " << num_done << " utterances; failed with " + << num_err << " utterances; " + << "wrote " << num_segmentations << " segmentations " + << "with a total of " << num_segments << " segments."; + KALDI_LOG << "Number of frames for the different classes are : "; + + std::map::const_iterator it = frame_counts_per_class.begin(); + for (; it != frame_counts_per_class.end(); ++it) { + KALDI_LOG << it->first << " " << it->second << " ; "; + } + + return ((num_done > 0 && num_err < num_done) ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-init-from-lengths.cc b/src/segmenterbin/segmentation-init-from-lengths.cc new file mode 100644 index 00000000000..28c998c220b --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-lengths.cc @@ -0,0 +1,82 @@ +// segmenterbin/segmentation-init-from-lengths.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Initialize segmentations from frame lengths file\n" + "\n" + "Usage: segmentation-init-from-lengths [options] " + " \n" + " e.g.: segmentation-init-from-lengths " + "\"ark:feat-to-len scp:feats.scp ark:- |\" ark:-\n" + "\n" + "See also: segmentation-init-from-ali, " + "segmentation-init-from-segments\n"; + + int32 label = 1; + + ParseOptions po(usage); + + po.Register("label", &label, "Label to assign to the created segments"); + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string lengths_rspecifier = po.GetArg(1), + segmentation_wspecifier = po.GetArg(2); + + SequentialInt32Reader lengths_reader(lengths_rspecifier); + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_done = 0; + + for (; !lengths_reader.Done(); lengths_reader.Next()) { + const std::string &key = lengths_reader.Key(); + const int32 &length = lengths_reader.Value(); + + Segmentation segmentation; + + if (length > 0) { + segmentation.EmplaceBack(0, length - 1, label); + } + + segmentation_writer.Write(key, segmentation); + num_done++; + } + + KALDI_LOG << "Created " << num_done << " segmentations."; + + return (num_done > 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-init-from-segments.cc b/src/segmenterbin/segmentation-init-from-segments.cc new file mode 100644 index 00000000000..980ec697602 --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-segments.cc @@ -0,0 +1,180 @@ +// segmenterbin/segmentation-init-from-segments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation.h" + +// If segments file contains +// Alpha-001 Alpha 0.00 0.16 +// Alpha-002 Alpha 1.50 4.10 +// Beta-001 Beta 0.50 2.66 +// Beta-002 Beta 3.50 5.20 +// the output segmentation will contain +// Alpha-001 [ 0 15 1 ] +// Alpha-002 [ 0 359 1 ] +// Beta-001 [ 0 215 1 ] +// Beta-002 [ 0 169 1 ] +// If --shift-to-zero=false is provided, then the output will contain +// Alpha-001 [ 0 15 1 ] +// Alpha-002 [ 150 409 1 ] +// Beta-001 [ 50 265 1 ] +// Beta-002 [ 350 519 1 ] +// +// If the following utt2label-rspecifier was provided: +// Alpha-001 2 +// Alpha-002 2 +// Beta-001 4 +// Beta-002 4 +// then the output segmentation will contain +// Alpha-001 [ 0 15 2 ] +// Alpha-002 [ 0 359 2 ] +// Beta-001 [ 0 215 4 ] +// Beta-002 [ 0 169 4 ] + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Convert segments from segments file into utterance-level " + "segmentation format. \n" + "The user can convert the segmenation to recording-level using " + "the binary segmentation-combine-segments-to-recording.\n" + "\n" + "Usage: segmentation-init-from-segments [options] " + " \n" + " e.g.: segmentation-init-from-segments segments ark:-\n"; + + int32 segment_label = 1; + BaseFloat frame_shift = 0.01, frame_overlap = 0.015; + std::string utt2label_rspecifier; + bool shift_to_zero = true; + + ParseOptions po(usage); + + po.Register("label", &segment_label, + "Label for all the segments in the segmentations"); + po.Register("utt2label-rspecifier", &utt2label_rspecifier, + "Mapping for each utterance to an integer label. 
" + "If supplied, these labels will be used as the segment " + "labels"); + po.Register("shift-to-zero", &shift_to_zero, + "Shift all segments to 0th frame"); + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); + po.Register("frame-overlap", &frame_overlap, "Frame overlap in seconds"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segments_rxfilename = po.GetArg(1), + segmentation_wspecifier = po.GetArg(2); + + SegmentationWriter writer(segmentation_wspecifier); + RandomAccessInt32Reader utt2label_reader(utt2label_rspecifier); + + Input ki(segments_rxfilename); + + int64 num_lines = 0, num_done = 0; + + std::string line; + + while (std::getline(ki.Stream(), line)) { + num_lines++; + + std::vector split_line; + // Split the line by space or tab and check the number of fields in each + // line. There must be 4 fields--segment name , reacording wav file name, + // start time, end time; 5th field (channel info) is optional. + SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 4 && split_line.size() != 5) { + KALDI_WARN << "Invalid line in segments file: " << line; + continue; + } + std::string utt = split_line[0], + reco = split_line[1], + start_str = split_line[2], + end_str = split_line[3]; + + // Convert the start time and endtime to real from string. Segment is + // ignored if start or end time cannot be converted to real. 
+ double start, end; + if (!ConvertStringToReal(start_str, &start)) { + KALDI_WARN << "Invalid line in segments file [bad start]: " << line; + continue; + } + if (!ConvertStringToReal(end_str, &end)) { + KALDI_WARN << "Invalid line in segments file [bad end]: " << line; + continue; + } + + // start time must not be negative; start time must not be greater than + // end time, except if end time is -1 + if (start < 0 || (end != -1.0 && end <= 0) || + ((start >= end) && (end > 0))) { + KALDI_WARN << "Invalid line in segments file " + << "[empty or invalid segment]: " << line; + continue; + } + + if (split_line.size() >= 5) + KALDI_ERR << "Not supporting channel in segments file"; + + Segmentation segmentation; + + if (!utt2label_rspecifier.empty()) { + if (!utt2label_reader.HasKey(utt)) { + KALDI_WARN << "Could not find utterance " << utt << " in " + << utt2label_rspecifier; + continue; + } + + segment_label = utt2label_reader.Value(utt); + } + + if (shift_to_zero) { + int32 last_frame = (end-frame_overlap) / frame_shift + - start / frame_shift - 1; + segmentation.EmplaceBack(0, last_frame, segment_label); + } else { + segmentation.EmplaceBack( + static_cast(start / frame_shift + 0.5), + static_cast((end-frame_overlap) / frame_shift - 0.5), + segment_label); + } + + writer.Write(utt, segmentation); + num_done++; + } + + KALDI_LOG << "Successfully processed " << num_done << " lines out of " + << num_lines << " in the segments file"; + + return (num_done > num_lines / 2 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-intersect-ali.cc b/src/segmenterbin/segmentation-intersect-ali.cc new file mode 100644 index 00000000000..a551eee02ce --- /dev/null +++ b/src/segmenterbin/segmentation-intersect-ali.cc @@ -0,0 +1,99 @@ +// segmenterbin/segmentation-intersect-ali.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Intersect (like sets) segmentation with an alignment and retain \n" + "only segments where the alignment is the specified label. 
\n" + "\n" + "Usage: segmentation-intersect-alignment [options] " + " " + "\n" + " e.g.: segmentation-intersect-alignment --binary=false ark:foo.seg " + "ark:filter.ali ark,t:-\n" + "See also: segmentation-combine-segments, " + "segmentation-intersect-segments, segmentation-create-subsegments\n"; + + ParseOptions po(usage); + + int32 ali_label = 0, min_alignment_chunk_length = 0; + + po.Register("ali-label", &ali_label, + "Intersect only at this label of alignments"); + po.Register("min-alignment-chunk-length", &min_alignment_chunk_length, + "The minimmum number of consecutive frames of ali_label in " + "alignment at which the segments can be intersected."); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_rspecifier = po.GetArg(1), + ali_rspecifier = po.GetArg(2), + segmentation_wspecifier = po.GetArg(3); + + int32 num_done = 0, num_err = 0; + + SegmentationWriter writer(segmentation_wspecifier); + SequentialSegmentationReader segmentation_reader(segmentation_rspecifier); + RandomAccessInt32VectorReader alignment_reader(ali_rspecifier); + + for (; !segmentation_reader.Done(); segmentation_reader.Next()) { + const Segmentation &segmentation = segmentation_reader.Value(); + const std::string &key = segmentation_reader.Key(); + + if (!alignment_reader.HasKey(key)) { + KALDI_WARN << "Could not find segmentation for key " << key + << " in " << ali_rspecifier; + num_err++; + continue; + } + const std::vector &ali = alignment_reader.Value(key); + + Segmentation out_segmentation; + IntersectSegmentationAndAlignment(segmentation, ali, ali_label, + min_alignment_chunk_length, + &out_segmentation); + out_segmentation.Sort(); + + writer.Write(key, out_segmentation); + num_done++; + } + + KALDI_LOG << "Intersected " << num_done + << " segmentations with alignments; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-intersect-segments.cc b/src/segmenterbin/segmentation-intersect-segments.cc new file mode 100644 index 00000000000..1c9861ba453 --- /dev/null +++ b/src/segmenterbin/segmentation-intersect-segments.cc @@ -0,0 +1,145 @@ +// segmenterbin/segmentation-intersect-segments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +namespace kaldi { +namespace segmenter { + +void IntersectSegmentationsNonOverlapping( + const Segmentation &in_segmentation, + const Segmentation &secondary_segmentation, + int32 mismatch_label, + Segmentation *out_segmentation) { + KALDI_ASSERT(out_segmentation); + KALDI_ASSERT(secondary_segmentation.Dim() > 0); + + std::vector alignment; + ConvertToAlignment(secondary_segmentation, -1, -1, 0, &alignment); + + for (SegmentList::const_iterator it = in_segmentation.Begin(); + it != in_segmentation.End(); ++it) { + if (it->end_frame >= alignment.size()) { + alignment.resize(it->end_frame + 1, -1); + } + Segmentation filter_segmentation; + InsertFromAlignment(alignment, it->start_frame, it->end_frame + 1, + 0, &filter_segmentation, NULL); + + for (SegmentList::const_iterator f_it = filter_segmentation.Begin(); + f_it != filter_segmentation.End(); ++f_it) { + int32 label = it->Label(); + if (f_it->Label() != it->Label()) { + if (mismatch_label == -1) continue; + label = mismatch_label; + } + + out_segmentation->EmplaceBack(f_it->start_frame, f_it->end_frame, + label); + } + } +} + +} +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Intersect segments from two archives by retaining only regions .\n" + "where the primary and secondary segments match on label\n" + "\n" + "Usage: segmentation-intersect-segments [options] " + " " + "\n" + " e.g.: segmentation-intersect-segments ark:foo.seg ark:bar.seg " + "ark,t:-\n" + "See also: segmentation-create-subsegments, " + "segmentation-intersect-ali\n"; + + int32 mismatch_label = -1; + bool assume_non_overlapping_secondary = true; + + ParseOptions po(usage); + + po.Register("mismatch-label", &mismatch_label, + "Intersect only where secondary segment has this label"); + po.Register("assume-non-overlapping-secondary", & + 
assume_non_overlapping_secondary, + "Assume secondary segments are non-overlapping"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string primary_rspecifier = po.GetArg(1), + secondary_rspecifier = po.GetArg(2), + segmentation_writer = po.GetArg(3); + + if (!assume_non_overlapping_secondary) { + KALDI_ERR << "Secondary segment must be non-overlapping for now"; + } + + int64 num_done = 0, num_err = 0; + + SegmentationWriter writer(segmentation_writer); + SequentialSegmentationReader primary_reader(primary_rspecifier); + RandomAccessSegmentationReader secondary_reader(secondary_rspecifier); + + for (; !primary_reader.Done(); primary_reader.Next()) { + const Segmentation &segmentation = primary_reader.Value(); + const std::string &key = primary_reader.Key(); + + if (!secondary_reader.HasKey(key)) { + KALDI_WARN << "Could not find segmentation for key " << key + << " in " << secondary_rspecifier; + num_err++; + continue; + } + const Segmentation &secondary_segmentation = secondary_reader.Value(key); + + Segmentation out_segmentation; + IntersectSegmentationsNonOverlapping(segmentation, + secondary_segmentation, + mismatch_label, + &out_segmentation); + + Sort(&out_segmentation); + + writer.Write(key, out_segmentation); + num_done++; + } + + KALDI_LOG << "Intersected " << num_done << " segmentations; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-merge-recordings.cc b/src/segmenterbin/segmentation-merge-recordings.cc new file mode 100644 index 00000000000..69f6758c90d --- /dev/null +++ b/src/segmenterbin/segmentation-merge-recordings.cc @@ -0,0 +1,102 @@ +// segmenterbin/segmentation-merge-recordings.cc + +// Copyright 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Merge segmentations of different recordings into one segmentation " + "using a mapping from new to old recording name\n" + "\n" + "Usage: segmentation-merge-recordings [options] " + " \n" + " e.g.: segmentation-merge-recordings ark:sdm2ihm_reco.map " + "ark:ihm_seg.ark ark:sdm_seg.ark\n"; + + ParseOptions po(usage); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string new2old_list_rspecifier = po.GetArg(1); + std::string segmentation_rspecifier = po.GetArg(2), + segmentation_wspecifier = po.GetArg(3); + + SequentialTokenVectorReader new2old_reader(new2old_list_rspecifier); + RandomAccessSegmentationReader segmentation_reader( + segmentation_rspecifier); + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_new_segmentations = 0, num_old_segmentations = 0; + int64 num_segments = 0, num_err = 0; + + for (; !new2old_reader.Done(); new2old_reader.Next()) { + const std::vector &old_key_list = new2old_reader.Value(); + const std::string &new_key = new2old_reader.Key(); + + KALDI_ASSERT(old_key_list.size() > 0); + + Segmentation segmentation; + + for (std::vector::const_iterator it = old_key_list.begin(); + it != old_key_list.end(); ++it) { + num_old_segmentations++; + + if (!segmentation_reader.HasKey(*it)) { + KALDI_WARN << "Could not find key " << *it << " in " + << "old segmentation " << segmentation_rspecifier; + num_err++; + continue; + } + + const Segmentation &this_segmentation = segmentation_reader.Value(*it); + + num_segments += InsertFromSegmentation(this_segmentation, 0, NULL, + &segmentation); + } + Sort(&segmentation); + + segmentation_writer.Write(new_key, segmentation); + + num_new_segmentations++; + } + + KALDI_LOG << "Merged " << 
num_old_segmentations << " old segmentations " + << "into " << num_new_segmentations << " new segmentations; " + << "created overall " << num_segments << " segments; " + << "failed to merge " << num_err << " old segmentations"; + + return (num_segments > 0 && num_new_segmentations > 0 && + num_err < num_old_segmentations / 2 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-merge.cc b/src/segmenterbin/segmentation-merge.cc new file mode 100644 index 00000000000..21e9a410e15 --- /dev/null +++ b/src/segmenterbin/segmentation-merge.cc @@ -0,0 +1,146 @@ +// segmenterbin/segmentation-merge.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Merge corresponding segments from multiple archives or files.\n" + "i.e. for each utterance in the first segmentation, the segments " + "from all the supplied segmentations are merged and put in a single " + "segmentation." + "\n" + "Usage: segmentation-merge [options] " + " ... 
" + "\n" + " e.g.: segmentation-merge ark:foo.seg ark:bar.seg ark,t:-\n" + " or \n" + " segmentation-merge " + " ... " + "\n" + " e.g.: segmentation-merge --binary=false foo bar -\n" + "See also: segmentation-copy, segmentation-merge-recordings, " + "segmentation-post-process --merge-labels\n"; + + bool binary = true; + bool sort = true; + + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("sort", &sort, "Sort the segements after merging"); + + po.Read(argc, argv); + + if (po.NumArgs() <= 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(po.NumArgs()); + + // all these "fn"'s are either rspecifiers or filenames. + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + + for (int32 i = 2; i < po.NumArgs(); i++) { + bool binary_in; + Input ki(po.GetArg(i), &binary_in); + Segmentation other_segmentation; + other_segmentation.Read(ki.Stream(), binary_in); + ExtendSegmentation(other_segmentation, false, + &segmentation); + } + + Sort(&segmentation); + + Output ko(segmentation_out_fn, binary); + segmentation.Write(ko.Stream(), binary); + + KALDI_LOG << "Merged segmentations to " << segmentation_out_fn; + return 0; + } else { + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + std::vector other_readers( + po.NumArgs()-2, + static_cast(NULL)); + + for (size_t i = 0; i < po.NumArgs()-2; i++) { 
+ other_readers[i] = new RandomAccessSegmentationReader(po.GetArg(i+2)); + } + + for (; !reader.Done(); reader.Next()) { + Segmentation segmentation(reader.Value()); + std::string key = reader.Key(); + + for (size_t i = 0; i < po.NumArgs()-2; i++) { + if (!other_readers[i]->HasKey(key)) { + KALDI_WARN << "Could not find segmentation for key " << key + << " in " << po.GetArg(i+2); + num_err++; + } + const Segmentation &other_segmentation = + other_readers[i]->Value(key); + ExtendSegmentation(other_segmentation, false, + &segmentation); + } + + Sort(&segmentation); + + writer.Write(key, segmentation); + num_done++; + } + + KALDI_LOG << "Merged " << num_done << " segmentation; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-post-process.cc b/src/segmenterbin/segmentation-post-process.cc new file mode 100644 index 00000000000..921ee5dc5d8 --- /dev/null +++ b/src/segmenterbin/segmentation-post-process.cc @@ -0,0 +1,142 @@ +// segmenterbin/segmentation-post-process.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-post-processor.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Post processing of segmentation that does the following operations " + "in order: \n" + "1) Merge labels: Merge labels specified in --merge-labels into a " + "single label specified by --merge-dst-label. \n" + "2) Padding segments: Pad segments of label specified by --pad-label " + "by a few frames as specified by --pad-length. \n" + "3) Shrink segments: Shrink segments of label specified by " + "--shrink-label by a few frames as specified by --shrink-length. \n" + "4) Blend segments with neighbors: Blend short segments of class-id " + "specified by --blend-short-segments-class that are " + "shorter than --max-blend-length frames with their " + "respective neighbors if both the neighbors are within " + "a distance of --max-intersegment-length frames.\n" + "5) Remove segments: Remove segments of class-ids contained " + "in --remove-labels.\n" + "6) Merge adjacent segments: Merge adjacent segments of the same " + "label if they are within a distance of --max-intersegment-length " + "frames.\n" + "7) Split segments: Split segments that are longer than " + "--max-segment-length frames into overlapping segments " + "with an overlap of --overlap-length frames. 
\n" + "Usage: segmentation-post-process [options] " + "\n" + " or : segmentation-post-process [options] " + "\n" + " e.g.: segmentation-post-process --binary=false foo -\n" + " segmentation-post-process ark:foo.seg ark,t:-\n" + "See also: segmentation-merge, segmentation-copy, " + "segmentation-remove-segments\n"; + + bool binary = true; + + ParseOptions po(usage); + + SegmentationPostProcessingOptions opts; + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + + opts.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + SegmentationPostProcessor post_processor(opts); + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + if (post_processor.PostProcess(&segmentation)) { + Output ko(segmentation_out_fn, binary); + Sort(&segmentation); + segmentation.Write(ko.Stream(), binary); + KALDI_LOG << "Post-processed segmentation " << segmentation_in_fn + << " and wrote " << segmentation_out_fn; + return 0; + } + KALDI_LOG << "Failed post-processing segmentation " + << segmentation_in_fn; + return 1; + } + + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + for (; !reader.Done(); reader.Next()) { + Segmentation segmentation(reader.Value()); + const std::string &key = reader.Key(); + + if (!post_processor.PostProcess(&segmentation)) { + num_err++; + continue; + } + + 
Sort(&segmentation); + + writer.Write(key, segmentation); + num_done++; + } + + KALDI_LOG << "Successfully post-processed " << num_done + << " segmentations; " + << "failed with " << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-remove-segments.cc b/src/segmenterbin/segmentation-remove-segments.cc new file mode 100644 index 00000000000..27af1420e54 --- /dev/null +++ b/src/segmenterbin/segmentation-remove-segments.cc @@ -0,0 +1,161 @@ +// segmenterbin/segmentation-remove-segments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Remove segments of particular class_id (e.g silence or noise) " + "or a set of class_ids.\n" + "The labels to removed can be made utterance-specific by passing " + "--remove-labels-rspecifier option.\n" + "\n" + "Usage: segmentation-remove-segments [options] " + " \n" + " or : segmentation-remove-segments [options] " + " \n" + "\n" + " e.g.: segmentation-remove-segments --remove-label=0 ark:foo.ark " + "ark:foo.speech.ark\n" + "See also: segmentation-post-process --remove-labels, " + "segmentation-post-process --max-blend-length, segmentation-copy\n"; + + bool binary = true; + + int32 remove_label = -1; + int32 max_remove_length = -1; + std::string remove_labels_rspecifier = ""; + + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("remove-label", &remove_label, "Remove segments of this label"); + po.Register("remove-labels-rspecifier", &remove_labels_rspecifier, + "Specify colon separated list of labels for each key"); + po.Register("max-remove-length", &max_remove_length, + "If supplied, this specifies the maximum length of segments " + "will be removed. A value of -1 specifies a length of " + "+infinity i.e. segments will be removed based " + "on only their labels and irrespective of their lengths."); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + // all these "fn"'s are either rspecifiers or filenames. 
+ + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_missing = 0; + + if (!in_is_rspecifier) { + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + if (!remove_labels_rspecifier.empty()) { + KALDI_ERR << "It does not make sense to specify " + << "--remove-labels-rspecifier " + << "for single segmentation"; + } + + RemoveSegments(remove_label, &segmentation); + + { + Output ko(segmentation_out_fn, binary); + segmentation.Write(ko.Stream(), binary); + } + + KALDI_LOG << "Removed segments and wrote segmentation to " + << segmentation_out_fn; + + return 0; + } else { + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + + RandomAccessTokenReader remove_labels_reader(remove_labels_rspecifier); + + for (; !reader.Done(); reader.Next(), num_done++) { + Segmentation segmentation(reader.Value()); + std::string key = reader.Key(); + + if (!remove_labels_rspecifier.empty()) { + if (!remove_labels_reader.HasKey(key)) { + KALDI_WARN << "No remove-labels found for recording " << key; + num_missing++; + writer.Write(key, segmentation); + continue; + } + + std::vector remove_labels; + const std::string& remove_labels_str = + remove_labels_reader.Value(key); + + if (!SplitStringToIntegers(remove_labels_str, ":,", false, + &remove_labels)) { + KALDI_ERR << "Bad colon-separated list " + << remove_labels_str << " for key " << key + << " in " << remove_labels_rspecifier; + } + + remove_label = remove_labels[0]; + + RemoveSegments(remove_labels, max_remove_length, &segmentation); + } else { + RemoveSegments(remove_label, &segmentation); + } + writer.Write(key, 
segmentation); + } + + KALDI_LOG << "Removed segments " << "from " << num_done + << " segmentations; " + << "remove-labels list missing for " << num_missing; + return (num_done != 0 ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-split-segments.cc b/src/segmenterbin/segmentation-split-segments.cc new file mode 100644 index 00000000000..a45211b28ca --- /dev/null +++ b/src/segmenterbin/segmentation-split-segments.cc @@ -0,0 +1,194 @@ +// segmenterbin/segmentation-split-segments.cc + +// Copyright 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Split long segments optionally using alignment.\n" + "The splitting works in two possible ways:\n" + " 1) If alignment is not provided: The segments are split if they\n" + " are longer than --max-segment-length frames into overlapping\n" + " segments with an overlap of --overlap-length frames.\n" + " 2) If alignment is provided: The segments are split if they\n" + " are longer than --max-segment-length frames at the region \n" + " where there is a contiguous segment of --ali-label in the \n" + " alignment that is at least --min-alignment-chunk-length frames \n" + " long.\n" + "Usage: segmentation-split-segments [options] " + " \n" + " or : segmentation-split-segments [options] " + " \n" + " e.g.: segmentation-split-segments --binary=false foo -\n" + " segmentation-split-segments ark:foo.seg ark,t:-\n" + "See also: segmentation-post-process\n"; + + bool binary = true; + int32 max_segment_length = -1; + int32 min_remainder = -1; + int32 overlap_length = 0; + int32 split_label = -1; + int32 ali_label = 0; + int32 min_alignment_chunk_length = 2; + + std::string alignments_in_fn; + + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("max-segment-length", &max_segment_length, + "If segment is longer than this length, split it into " + "pieces with less than these many frames. " + "Refer to the SplitSegments() code for details. " + "Used in conjunction with the option --overlap-length."); + po.Register("min-remainder", &min_remainder, + "The minimum remainder left after splitting that will " + "prevent a splitting from begin done. " + "Set to max-segment-length / 2, if not specified. 
" + "Applicable only when alignments is not specified."); + po.Register("overlap-length", &overlap_length, + "When splitting segments longer than max-segment-length, " + "have the pieces overlap by these many frames. " + "Refer to the SplitSegments() code for details. " + "Used in conjunction with the option --max-segment-length."); + po.Register("split-label", &split_label, + "If supplied, split only segments of these labels. " + "Otherwise, split all segments."); + po.Register("alignments", &alignments_in_fn, + "A single alignment file or archive of alignment used " + "for splitting, " + "depending on whether the input segmentation is single file " + "or archive"); + po.Register("ali-label", &ali_label, + "Split at this label of alignments"); + po.Register("min-alignment-chunk-length", &min_alignment_chunk_length, + "The minimum number of frames of alignment with ali_label " + "at which to split the segments"); + + po.Read(argc, argv); + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + if (min_remainder == -1) { + min_remainder = max_segment_length / 2; + } + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + std::vector ali; + + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + + if (!alignments_in_fn.empty()) { + { + bool binary_in; + Input ki(alignments_in_fn, &binary_in); + ReadIntegerVector(ki.Stream(), binary_in, &ali); + } + SplitSegmentsUsingAlignment(max_segment_length, + split_label, ali, ali_label, + min_alignment_chunk_length, + 
&segmentation); + } else { + SplitSegments(max_segment_length, min_remainder, + overlap_length, split_label, &segmentation); + } + + Sort(&segmentation); + + { + Output ko(segmentation_out_fn, binary); + segmentation.Write(ko.Stream(), binary); + } + + KALDI_LOG << "Split segmentation " << segmentation_in_fn + << " and wrote " << segmentation_out_fn; + return 0; + } + + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + RandomAccessInt32VectorReader ali_reader(alignments_in_fn); + + for (; !reader.Done(); reader.Next()) { + Segmentation segmentation(reader.Value()); + const std::string &key = reader.Key(); + + if (!alignments_in_fn.empty()) { + if (!ali_reader.HasKey(key)) { + KALDI_WARN << "Could not find key " << key + << " in alignments " << alignments_in_fn; + num_err++; + continue; + } + SplitSegmentsUsingAlignment(max_segment_length, split_label, + ali_reader.Value(key), ali_label, + min_alignment_chunk_length, + &segmentation); + } else { + SplitSegments(max_segment_length, min_remainder, + overlap_length, split_label, + &segmentation); + } + + Sort(&segmentation); + + writer.Write(key, segmentation); + num_done++; + } + + KALDI_LOG << "Successfully split " << num_done + << " segmentations; " + << "failed with " << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-to-ali.cc b/src/segmenterbin/segmentation-to-ali.cc new file mode 100644 index 00000000000..9a618247a42 --- /dev/null +++ b/src/segmenterbin/segmentation-to-ali.cc @@ -0,0 +1,99 @@ +// segmenterbin/segmentation-to-ali.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Convert segmentation to alignment\n" + "\n" + "Usage: segmentation-to-ali [options] " + "\n" + " e.g.: segmentation-to-ali ark:1.seg ark:1.ali\n"; + + std::string lengths_rspecifier; + int32 default_label = 0, length_tolerance = 2; + + ParseOptions po(usage); + + po.Register("lengths-rspecifier", &lengths_rspecifier, + "Archive of frame lengths " + "of the utterances. 
Fills up any extra length with " + "the specified default-label"); + po.Register("default-label", &default_label, "Fill any extra length " + "with this label"); + po.Register("length-tolerance", &length_tolerance, "Tolerate shortage of " + "this many frames in the specified lengths file"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_rspecifier = po.GetArg(1); + std::string alignment_wspecifier = po.GetArg(2); + + RandomAccessInt32Reader lengths_reader(lengths_rspecifier); + + SequentialSegmentationReader segmentation_reader(segmentation_rspecifier); + Int32VectorWriter alignment_writer(alignment_wspecifier); + + int32 num_err = 0, num_done = 0; + for (; !segmentation_reader.Done(); segmentation_reader.Next()) { + const Segmentation &segmentation = segmentation_reader.Value(); + const std::string &key = segmentation_reader.Key(); + + int32 length = -1; + if (lengths_rspecifier != "") { + if (!lengths_reader.HasKey(key)) { + KALDI_WARN << "Could not find length for utterance " << key; + num_err++; + continue; + } + length = lengths_reader.Value(key); + } + + std::vector ali; + if (!ConvertToAlignment(segmentation, default_label, length, + length_tolerance, &ali)) { + KALDI_WARN << "Conversion failed for utterance " << key; + num_err++; + continue; + } + alignment_writer.Write(key, ali); + num_done++; + } + + KALDI_LOG << "Converted " << num_done << " segmentations into alignments; " + << "failed with " << num_err << " segmentations"; + return (num_done != 0 ? 
0 : 1);
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
+
diff --git a/src/segmenterbin/segmentation-to-rttm.cc b/src/segmenterbin/segmentation-to-rttm.cc
new file mode 100644
index 00000000000..8f22d78f3bc
--- /dev/null
+++ b/src/segmenterbin/segmentation-to-rttm.cc
@@ -0,0 +1,284 @@
+// segmenterbin/segmentation-to-rttm.cc
+
+// Copyright 2015-16 Vimal Manohar
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include 
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "segmenter/segmentation.h"
+
+namespace kaldi {
+namespace segmenter {
+
+/**
+ * Writes the given segmentation in RTTM format.  Each class is treated as
+ * a "SPEAKER".  If map_to_speech_and_sil is true, then class_id 0 is
+ * written as SILENCE and every other class_id as SPEECH.  The argument
+ * start_time sets the time corresponding to frame 0 of the segment.
+ * Each segment is converted into a line of the form:
+ *   SPEAKER <file-id> <channel> <start-time> <duration> ... <spkr> ...
+ * where
+ *  <file-id> is the file_id supplied as an argument,
+ *  <start-time> is the start time of the segment in seconds,
+ *  <duration> is the length of the segment in seconds, and
+ *  <spkr> is the class_id stored in the segment.  If map_to_speech_and_sil
+ *  is true then <spkr> is either SPEECH or SILENCE.
+ * The function returns nothing (void); it only writes to the stream.
+**/ + +void WriteRttm(const Segmentation &segmentation, + const std::string &file_id, + const std::string &channel, + BaseFloat frame_shift, BaseFloat start_time, + bool map_to_speech_and_sil, + int32 no_score_label, std::ostream &os) { + SegmentList::const_iterator it = segmentation.Begin(); + + unordered_map classes_map; + std::vector classes_vec; + + for (; it != segmentation.End(); ++it) { + if (no_score_label > 0 && it->Label() == no_score_label) { + os << "NOSCORE " << file_id << " " << channel << " " + << it->start_frame * frame_shift + start_time << " " + << (it->Length()) * frame_shift << " \n"; + continue; + } + os << "SPEAKER " << file_id << " " << channel << " " + << it->start_frame * frame_shift + start_time << " " + << (it->Length()) * frame_shift << " "; + if (map_to_speech_and_sil) { + switch (it->Label()) { + case 0: + os << "SILENCE "; + break; + default: + os << "SPEECH "; + break; + } + } else { + if (it->Label() >= 0) { + os << it->Label() << " "; + if (classes_map.count(it->Label()) == 0) { + classes_map[it->Label()] = true; + classes_vec.push_back(it->Label()); + } + } + } + os << "" << std::endl; + } + + if (!map_to_speech_and_sil) { + for (std::vector::const_iterator it = classes_vec.begin(); + it != classes_vec.end(); ++it) { + os << "SPKR-INFO " << file_id << " " << channel + << " unknown " << *it << " \n"; + } + } +} + +} // namespace segmenter +} // namespace kaldi + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Convert segmentation into RTTM\n" + "\n" + "Usage: segmentation-to-rttm [options] " + "\n" + " e.g.: segmentation-to-rttm ark:1.seg -\n"; + + bool map_to_speech_and_sil = true; + int32 no_score_label = -1; + + BaseFloat frame_shift = 0.01; + std::string segments_rxfilename; + std::string reco2file_and_channel_rxfilename; + ParseOptions po(usage); + + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); + po.Register("segments", 
&segments_rxfilename, "Segments file"); + po.Register("reco2file-and-channel", &reco2file_and_channel_rxfilename, + "reco2file_and_channel file"); + po.Register("map-to-speech-and-sil", &map_to_speech_and_sil, + "Map all classes other than 0 to SPEECH"); + po.Register("no-score-label", &no_score_label, + "If specified, then a NOSCORE region is added to RTTM " + "when this label occurs in the segmentation."); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + unordered_map utt2file; + unordered_map utt2start_time; + + if (!segments_rxfilename.empty()) { + Input ki(segments_rxfilename); // no binary argment: never binary. + int32 i = 0; + std::string line; + /* read each line from segments file */ + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + // Split the line by space or tab and check the number of fields in each + // line. There must be 4 fields--segment name , reacording wav file + // name, start time, end time; 5th field (channel info) is optional. + SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 4 && split_line.size() != 5) { + KALDI_WARN << "Invalid line in segments file: " << line; + continue; + } + std::string segment = split_line[0], + utterance = split_line[1], + start_str = split_line[2], + end_str = split_line[3]; + + // Convert the start time and endtime to real from string. Segment is + // ignored if start or end time cannot be converted to real. 
+ double start, end; + if (!ConvertStringToReal(start_str, &start)) { + KALDI_WARN << "Invalid line in segments file [bad start]: " << line; + continue; + } + if (!ConvertStringToReal(end_str, &end)) { + KALDI_WARN << "Invalid line in segments file [bad end]: " << line; + continue; + } + // start time must not be negative; start time must not be greater than + // end time, except if end time is -1 + if (start < 0 || end <= 0 || start >= end) { + KALDI_WARN << "Invalid line in segments file " + << "[empty or invalid segment]: " + << line; + continue; + } + int32 channel = -1; // means channel info is unspecified. + // if each line has 5 elements then 5th element must be channel + // identifier + if (split_line.size() == 5) { + if (!ConvertStringToInteger(split_line[4], &channel) || channel < 0) { + KALDI_WARN << "Invalid line in segments file " + << "[bad channel]: " << line; + continue; + } + } + + utt2file.insert(std::make_pair(segment, utterance)); + utt2start_time.insert(std::make_pair(segment, start)); + i++; + } + KALDI_LOG << "Read " << i << " lines from " << segments_rxfilename; + } + + unordered_map, + StringHasher> reco2file_and_channel; + + if (!reco2file_and_channel_rxfilename.empty()) { + // no binary argment: never binary. 
+ Input ki(reco2file_and_channel_rxfilename); + + int32 i = 0; + std::string line; + /* read each line from reco2file_and_channel file */ + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 3) { + KALDI_WARN << "Invalid line in reco2file_and_channel file: " << line; + continue; + } + + const std::string &reco_id = split_line[0]; + const std::string &file_id = split_line[1]; + const std::string &channel = split_line[2]; + + reco2file_and_channel.insert( + std::make_pair(reco_id, std::make_pair(file_id, channel))); + i++; + } + + KALDI_LOG << "Read " << i << " lines from " + << reco2file_and_channel_rxfilename; + } + + unordered_set seen_files; + + std::string segmentation_rspecifier = po.GetArg(1), + rttm_out_wxfilename = po.GetArg(2); + + int64 num_done = 0, num_err = 0; + + Output ko(rttm_out_wxfilename, false); + SequentialSegmentationReader reader(segmentation_rspecifier); + for (; !reader.Done(); reader.Next(), num_done++) { + Segmentation segmentation(reader.Value()); + const std::string &key = reader.Key(); + + std::string reco_id = key; + BaseFloat start_time = 0.0; + if (!segments_rxfilename.empty()) { + if (utt2file.count(key) == 0 || utt2start_time.count(key) == 0) + KALDI_ERR << "Could not find key " << key << " in segments " + << segments_rxfilename; + KALDI_ASSERT(utt2file.count(key) > 0 && utt2start_time.count(key) > 0); + reco_id = utt2file[key]; + start_time = utt2start_time[key]; + } + + std::string file_id, channel; + if (!reco2file_and_channel_rxfilename.empty()) { + if (reco2file_and_channel.count(reco_id) == 0) + KALDI_ERR << "Could not find recording " << reco_id + << " in " << reco2file_and_channel_rxfilename; + file_id = reco2file_and_channel[reco_id].first; + channel = reco2file_and_channel[reco_id].second; + } else { + file_id = reco_id; + channel = "1"; + } + + WriteRttm(segmentation, file_id, + channel, frame_shift, start_time, + 
map_to_speech_and_sil, no_score_label, ko.Stream()); + + if (map_to_speech_and_sil) { + if (seen_files.count(reco_id) == 0) { + ko.Stream() << "SPKR-INFO " << file_id << " " << channel + << " unknown SILENCE \n"; + ko.Stream() << "SPKR-INFO " << file_id << " " << channel + << " unknown SPEECH \n"; + seen_files.insert(reco_id); + } + } + } + + KALDI_LOG << "Copied " << num_done << " segmentation; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/segmenterbin/segmentation-to-segments.cc b/src/segmenterbin/segmentation-to-segments.cc new file mode 100644 index 00000000000..c57aa827ead --- /dev/null +++ b/src/segmenterbin/segmentation-to-segments.cc @@ -0,0 +1,133 @@ +// segmenterbin/segmentation-to-segments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Convert segmentation to a segments file and utt2spk file." 
+ "Assumes that the input segmentations are indexed by reco-id and " + "treats speakers from different recording as distinct speakers." + "\n" + "Usage: segmentation-to-segments [options] " + " \n" + " e.g.: segmentation-to-segments ark:foo.seg ark,t:utt2spk segments\n"; + + BaseFloat frame_shift = 0.01, frame_overlap = 0.015; + bool single_speaker = false, per_utt_speaker = false; + ParseOptions po(usage); + + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); + po.Register("frame-overlap", &frame_overlap, "Frame overlap in seconds"); + po.Register("single-speaker", &single_speaker, "If this is set true, " + "then all the utterances in a recording are mapped to the " + "same speaker"); + po.Register("per-utt-speaker", &per_utt_speaker, + "If this is set true, then each utterance is mapped to distint " + "speaker with spkr_id = utt_id"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + if (frame_shift < 0.001 || frame_shift > 1) { + KALDI_ERR << "Invalid frame-shift " << frame_shift << "; must be in " + << "the range [0.001,1]"; + } + + if (frame_overlap < 0 || frame_overlap > 1) { + KALDI_ERR << "Invalid frame-overlap " << frame_overlap << "; must be in " + << "the range [0,1]"; + } + + std::string segmentation_rspecifier = po.GetArg(1), + utt2spk_wspecifier = po.GetArg(2), + segments_wxfilename = po.GetArg(3); + + SequentialSegmentationReader reader(segmentation_rspecifier); + TokenWriter utt2spk_writer(utt2spk_wspecifier); + + Output ko(segments_wxfilename, false); + + int32 num_done = 0; + int64 num_segments = 0; + + for (; !reader.Done(); reader.Next(), num_done++) { + const Segmentation &segmentation = reader.Value(); + const std::string &key = reader.Key(); + + for (SegmentList::const_iterator it = segmentation.Begin(); + it != segmentation.End(); ++it) { + BaseFloat start_time = it->start_frame * frame_shift; + BaseFloat end_time = (it->end_frame + 1) * frame_shift + frame_overlap; + + 
std::ostringstream oss; + + if (!single_speaker) { + oss << key << "-" << it->Label(); + } else { + oss << key; + } + + std::string spk = oss.str(); + + oss << "-"; + oss << std::setw(6) << std::setfill('0') << it->start_frame; + oss << std::setw(1) << "-"; + oss << std::setw(6) << std::setfill('0') + << it->end_frame + 1 + + static_cast(frame_overlap / frame_shift); + + std::string utt = oss.str(); + + if (per_utt_speaker) + utt2spk_writer.Write(utt, utt); + else + utt2spk_writer.Write(utt, spk); + + ko.Stream() << utt << " " << key << " "; + ko.Stream() << std::fixed << std::setprecision(3) << start_time << " "; + ko.Stream() << std::setprecision(3) << end_time << "\n"; + + num_segments++; + } + } + + KALDI_LOG << "Converted " << num_done << " segmentations to segments; " + << "wrote " << num_segments << " segments"; + + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/simplehmm/Makefile b/src/simplehmm/Makefile new file mode 100644 index 00000000000..89c9f70a8c3 --- /dev/null +++ b/src/simplehmm/Makefile @@ -0,0 +1,16 @@ +all: + + +include ../kaldi.mk + +TESTFILES = simple-hmm-test + +OBJFILES = simple-hmm.o simple-hmm-utils.o simple-hmm-graph-compiler.o + +LIBNAME = kaldi-simplehmm +ADDLIBS = ../hmm/kaldi-hmm.a ../decoder/kaldi-decoder.a \ + ../util/kaldi-util.a ../thread/kaldi-thread.a \ + ../matrix/kaldi-matrix.a ../base/kaldi-base.a + +include ../makefiles/default_rules.mk + diff --git a/src/simplehmm/decodable-simple-hmm.h b/src/simplehmm/decodable-simple-hmm.h new file mode 100644 index 00000000000..6f224ee6176 --- /dev/null +++ b/src/simplehmm/decodable-simple-hmm.h @@ -0,0 +1,88 @@ +// simplehmm/decodable-simple-hmm.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_SIMPLEHMM_DECODABLE_SIMPLE_HMM_H_
+#define KALDI_SIMPLEHMM_DECODABLE_SIMPLE_HMM_H_
+
+#include <vector>
+
+#include "base/kaldi-common.h"
+#include "simplehmm/simple-hmm.h"
+#include "itf/decodable-itf.h"
+
+namespace kaldi {
+namespace simple_hmm {
+
+// Decodable object over a precomputed matrix of per-frame pdf log-likelihoods.
+class DecodableMatrixSimpleHmm: public DecodableInterface {
+ public:
+  // This constructor creates an object that will not delete "likes"
+  // when done; the caller retains ownership of the matrix.
+  DecodableMatrixSimpleHmm(const SimpleHmm &model,
+                           const Matrix<BaseFloat> &likes,
+                           BaseFloat scale):
+    model_(model), likes_(&likes), scale_(scale), delete_likes_(false)
+  {
+    if (likes.NumCols() != model.NumPdfs())
+      KALDI_ERR << "DecodableMatrixSimpleHmm: mismatch, matrix has "
+                << likes.NumCols() << " columns but the model has "
+                << model.NumPdfs() << " pdf-ids.";
+  }
+
+  // This constructor creates an object that will delete "likes"
+  // when done (takes ownership of the heap-allocated matrix).
+  DecodableMatrixSimpleHmm(const SimpleHmm &model,
+                           BaseFloat scale,
+                           const Matrix<BaseFloat> *likes):
+    model_(model), likes_(likes), scale_(scale), delete_likes_(true) {
+    if (likes->NumCols() != model.NumPdfs())
+      KALDI_ERR << "DecodableMatrixSimpleHmm: mismatch, matrix has "
+                << likes->NumCols() << " columns but the model has "
+                << model.NumPdfs() << " pdf-ids.";
+  }
+
+  virtual int32 NumFramesReady() const { return likes_->NumRows(); }
+
+  virtual bool IsLastFrame(int32 frame) const {
+    KALDI_ASSERT(frame < NumFramesReady());
+    return (frame == NumFramesReady() - 1);
+  }
+
+  // Note, frames are numbered from zero; 'tid' is a transition-id.
+  virtual BaseFloat LogLikelihood(int32 frame, int32 tid) {
+    return scale_ * (*likes_)(frame, model_.TransitionIdToPdfClass(tid));
+  }
+
+  // Indices are one-based!  This is for compatibility with OpenFst.
+  virtual int32 NumIndices() const { return model_.NumTransitionIds(); }
+
+  virtual ~DecodableMatrixSimpleHmm() {
+    if (delete_likes_) delete likes_;
+  }
+ private:
+  const SimpleHmm &model_;  // for tid to pdf mapping
+  const Matrix<BaseFloat> *likes_;  // per-frame pdf log-likelihoods
+  BaseFloat scale_;  // acoustic scale applied in LogLikelihood()
+  bool delete_likes_;  // true if this object owns likes_
+  KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixSimpleHmm);
+};
+
+}  // namespace simple_hmm
+}  // namespace kaldi
+
+#endif  // KALDI_SIMPLEHMM_DECODABLE_SIMPLE_HMM_H_
diff --git a/src/simplehmm/simple-hmm-graph-compiler.cc b/src/simplehmm/simple-hmm-graph-compiler.cc
new file mode 100644
index 00000000000..9626e08ae5f
--- /dev/null
+++ b/src/simplehmm/simple-hmm-graph-compiler.cc
@@ -0,0 +1,128 @@
+// decoder/simple-hmm-graph-compiler.cc
+
+// Copyright 2009-2011 Microsoft Corporation
+//                2016 Vimal Manohar
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+//  http://www.apache.org/licenses/LICENSE-2.0
+
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+ +#include "simplehmm/simple-hmm-graph-compiler.h" +#include "simplehmm/simple-hmm-utils.h" // for GetHTransducer + +namespace kaldi { + +bool SimpleHmmGraphCompiler::CompileGraphFromAlignment( + const std::vector &alignment, + fst::VectorFst *out_fst) { + using namespace fst; + VectorFst class_fst; + MakeLinearAcceptor(alignment, &class_fst); + return CompileGraph(class_fst, out_fst); +} + +bool SimpleHmmGraphCompiler::CompileGraph( + const fst::VectorFst &class_fst, + fst::VectorFst *out_fst) { + using namespace fst; + KALDI_ASSERT(out_fst); + KALDI_ASSERT(class_fst.Start() != kNoStateId); + + if (GetVerboseLevel() >= 4) { + KALDI_VLOG(4) << "Classes FST: "; + WriteFstKaldi(KALDI_LOG, false, class_fst); + } + + VectorFst *H = GetHTransducer(model_, opts_.transition_scale, + opts_.self_loop_scale); + + if (GetVerboseLevel() >= 4) { + KALDI_VLOG(4) << "HTransducer:"; + WriteFstKaldi(KALDI_LOG, false, *H); + } + + // Epsilon-removal and determinization combined. + // This will fail if not determinizable. + DeterminizeStarInLog(H); + + if (GetVerboseLevel() >= 4) { + KALDI_VLOG(4) << "HTransducer determinized:"; + WriteFstKaldi(KALDI_LOG, false, *H); + } + + VectorFst &trans2class_fst = *out_fst; // transition-id to class. + TableCompose(*H, class_fst, &trans2class_fst); + + KALDI_ASSERT(trans2class_fst.Start() != kNoStateId); + + if (GetVerboseLevel() >= 4) { + KALDI_VLOG(4) << "trans2class_fst:"; + WriteFstKaldi(KALDI_LOG, false, trans2class_fst); + } + + // Epsilon-removal and determinization combined. + // This will fail if not determinizable. + DeterminizeStarInLog(&trans2class_fst); + + // we elect not to remove epsilons after this phase, as it is + // a little slow. + if (opts_.rm_eps) + RemoveEpsLocal(&trans2class_fst); + + // Encoded minimization. 
+ MinimizeEncoded(&trans2class_fst); + + delete H; + return true; +} + +bool SimpleHmmGraphCompiler::CompileGraphsFromAlignments( + const std::vector > &alignments, + std::vector*> *out_fsts) { + using namespace fst; + std::vector* > class_fsts(alignments.size()); + for (size_t i = 0; i < alignments.size(); i++) { + VectorFst *class_fst = new VectorFst(); + MakeLinearAcceptor(alignments[i], class_fst); + class_fsts[i] = class_fst; + } + bool ans = CompileGraphs(class_fsts, out_fsts); + for (size_t i = 0; i < alignments.size(); i++) + delete class_fsts[i]; + return ans; +} + +bool SimpleHmmGraphCompiler::CompileGraphs( + const std::vector* > &class_fsts, + std::vector* > *out_fsts) { + + using namespace fst; + KALDI_ASSERT(out_fsts && out_fsts->empty()); + out_fsts->resize(class_fsts.size(), NULL); + if (class_fsts.empty()) return true; + + for (size_t i = 0; i < class_fsts.size(); i++) { + const VectorFst *class_fst = class_fsts[i]; + VectorFst out_fst; + + CompileGraph(*class_fst, &out_fst); + + (*out_fsts)[i] = out_fst.Copy(); + } + + return true; +} + + +} // end namespace kaldi diff --git a/src/simplehmm/simple-hmm-graph-compiler.h b/src/simplehmm/simple-hmm-graph-compiler.h new file mode 100644 index 00000000000..dcc8f8fd2ba --- /dev/null +++ b/src/simplehmm/simple-hmm-graph-compiler.h @@ -0,0 +1,100 @@ +// decoder/simple-hmm-graph-compiler.h + +// Copyright 2009-2011 Microsoft Corporation +// 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_SIMPLE_HMM_GRAPH_COMPILER_H_ +#define KALDI_DECODER_SIMPLE_HMM_GRAPH_COMPILER_H_ + +#include "base/kaldi-common.h" +#include "simplehmm/simple-hmm.h" +#include "fst/fstlib.h" +#include "fstext/fstext-lib.h" + + +// This header provides functionality to compile a graph directly from the +// alignment where the alignment is of classes that are simple mappings +// of 'pdf-ids' (same as pdf classes for SimpleHmm). + +namespace kaldi { + +struct SimpleHmmGraphCompilerOptions { + BaseFloat transition_scale; + BaseFloat self_loop_scale; + bool rm_eps; + + explicit SimpleHmmGraphCompilerOptions(BaseFloat transition_scale = 1.0, + BaseFloat self_loop_scale = 1.0): + transition_scale(transition_scale), + self_loop_scale(self_loop_scale), + rm_eps(true) { } + + void Register(OptionsItf *opts) { + opts->Register("transition-scale", &transition_scale, "Scale of transition " + "probabilities (excluding self-loops)"); + opts->Register("self-loop-scale", &self_loop_scale, "Scale of self-loop vs. " + "non-self-loop probability mass "); + opts->Register("rm-eps", &rm_eps, "Remove [most] epsilons before minimization (only applicable " + "if disambig symbols present)"); + } +}; + + +class SimpleHmmGraphCompiler { + public: + SimpleHmmGraphCompiler(const SimpleHmm &model, // Maintains reference to this object. 
+ const SimpleHmmGraphCompilerOptions &opts): + model_(model), opts_(opts) { } + + + /// CompileGraph compiles a single training graph its input is a + /// weighted acceptor (G) at the class level, its output is HCLG-type graph. + /// Note: G could actually be an acceptor, it would also work. + /// This function is not const for technical reasons involving the cache. + /// if not for "table_compose" we could make it const. + bool CompileGraph(const fst::VectorFst &class_fst, + fst::VectorFst *out_fst); + + // CompileGraphs allows you to compile a number of graphs at the same + // time. This consumes more memory but is faster. + bool CompileGraphs( + const std::vector *> &class_fsts, + std::vector *> *out_fsts); + + // This version creates an FST from the per-frame alignment and calls + // CompileGraph. + bool CompileGraphFromAlignment(const std::vector &alignment, + fst::VectorFst *out_fst); + + // This function creates FSTs from the per-frame alignment and calls + // CompileGraphs. + bool CompileGraphsFromAlignments( + const std::vector > &alignments, + std::vector *> *out_fsts); + + ~SimpleHmmGraphCompiler() { } + private: + const SimpleHmm &model_; + + SimpleHmmGraphCompilerOptions opts_; +}; + + +} // end namespace kaldi. + +#endif // KALDI_DECODER_SIMPLE_HMM_GRAPH_COMPILER_H_ diff --git a/src/simplehmm/simple-hmm-test.cc b/src/simplehmm/simple-hmm-test.cc new file mode 100644 index 00000000000..b2de0e05a08 --- /dev/null +++ b/src/simplehmm/simple-hmm-test.cc @@ -0,0 +1,76 @@ +// hmm/simple-hmm-test.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "simplehmm/simple-hmm.h" +#include "hmm/hmm-test-utils.h" + +namespace kaldi { +namespace simple_hmm { + + +SimpleHmm *GenRandSimpleHmm() { + std::vector phones; + phones.push_back(1); + + std::vector num_pdf_classes; + num_pdf_classes.push_back(rand() + 1); + + HmmTopology topo = GenRandTopology(phones, num_pdf_classes); + + SimpleHmm *model = new SimpleHmm(topo); + + return model; +} + + +void TestSimpleHmm() { + + SimpleHmm *model = GenRandSimpleHmm(); + + bool binary = (rand() % 2 == 0); + + std::ostringstream os; + model->Write(os, binary); + + SimpleHmm model2; + std::istringstream is2(os.str()); + model2.Read(is2, binary); + + { + std::ostringstream os1, os2; + model->Write(os1, false); + model2.Write(os2, false); + KALDI_ASSERT(os1.str() == os2.str()); + KALDI_ASSERT(model->Compatible(model2)); + } + delete model; +} + + +} // end namespace simple_hmm +} // end namespace kaldi + + +int main() { + for (int i = 0; i < 2; i++) + kaldi::TestSimpleHmm(); + KALDI_LOG << "Test OK.\n"; +} + + diff --git a/src/simplehmm/simple-hmm-utils.cc b/src/simplehmm/simple-hmm-utils.cc new file mode 100644 index 00000000000..fc0c7e4ca3c --- /dev/null +++ b/src/simplehmm/simple-hmm-utils.cc @@ -0,0 +1,146 @@ +// hmm/simple-hmm-utils.cc + +// Copyright 2009-2011 Microsoft Corporation +// 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "simplehmm/simple-hmm-utils.h" +#include "fst/fstlib.h" +#include "fstext/fstext-lib.h" + +namespace kaldi { + +fst::VectorFst* GetHTransducer( + const SimpleHmm &model, + BaseFloat transition_scale, BaseFloat self_loop_scale) { + using namespace fst; + typedef StdArc Arc; + typedef Arc::Weight Weight; + typedef Arc::StateId StateId; + typedef Arc::Label Label; + + VectorFst *fst = GetSimpleHmmAsFst(model, transition_scale, + self_loop_scale); + + for (StateIterator > siter(*fst); + !siter.Done(); siter.Next()) { + Arc::StateId s = siter.Value(); + for (MutableArcIterator > aiter(fst, s); + !aiter.Done(); aiter.Next()) { + Arc arc = aiter.Value(); + if (arc.ilabel == 0) { + KALDI_ASSERT(arc.olabel == 0); + continue; + } + + KALDI_ASSERT(arc.ilabel == arc.olabel && + arc.ilabel <= model.NumTransitionIds()); + + arc.olabel = model.TransitionIdToPdf(arc.ilabel) + 1; + aiter.SetValue(arc); + } + } + + return fst; +} + +fst::VectorFst *GetSimpleHmmAsFst( + const SimpleHmm &model, + BaseFloat transition_scale, BaseFloat self_loop_scale) { + using namespace fst; + typedef StdArc Arc; + typedef Arc::Weight Weight; + typedef Arc::StateId StateId; + typedef Arc::Label Label; + + KALDI_ASSERT(model.NumPdfs() > 0); + const HmmTopology &topo = model.GetTopo(); + // This special Hmm has only one phone + const HmmTopology::TopologyEntry &entry = topo.TopologyForPhone(1); + + VectorFst *ans = new VectorFst; + + // Create a mini-FST with a superfinal 
state [in case we have emitting + // final-states, which we usually will.] + + std::vector state_ids; + for (size_t i = 0; i < entry.size(); i++) + state_ids.push_back(ans->AddState()); + KALDI_ASSERT(state_ids.size() > 1); // Or invalid topology entry. + ans->SetStart(state_ids[0]); + StateId final_state = state_ids.back(); + ans->SetFinal(final_state, Weight::One()); + + for (int32 hmm_state = 0; + hmm_state < static_cast(entry.size()); + hmm_state++) { + int32 pdf_class = entry[hmm_state].forward_pdf_class; + int32 self_loop_pdf_class = entry[hmm_state].self_loop_pdf_class; + KALDI_ASSERT(self_loop_pdf_class == pdf_class); + + if (pdf_class != kNoPdf) { + KALDI_ASSERT(pdf_class < model.NumPdfs()); + } + + int32 trans_idx; + for (trans_idx = 0; + trans_idx < static_cast(entry[hmm_state].transitions.size()); + trans_idx++) { + BaseFloat log_prob; + Label label; + int32 dest_state = entry[hmm_state].transitions[trans_idx].first; + + if (pdf_class == kNoPdf) { + // no pdf, hence non-estimated probability. very unusual case. [would + // not happen with normal topology] . There is no transition-state + // involved in this case. + KALDI_ASSERT(hmm_state != dest_state); + log_prob = transition_scale + * Log(entry[hmm_state].transitions[trans_idx].second); + label = 0; + } else { // normal probability. + int32 trans_state = + model.TupleToTransitionState(1, hmm_state, pdf_class, pdf_class); + int32 trans_id = + model.PairToTransitionId(trans_state, trans_idx); + + log_prob = model.GetTransitionLogProb(trans_id); + + if (hmm_state == dest_state) + log_prob *= self_loop_scale; + else + log_prob *= transition_scale; + // log_prob is a negative number (or zero)... + label = trans_id; + } + ans->AddArc(state_ids[hmm_state], + Arc(label, label, Weight(-log_prob), + state_ids[dest_state])); + } + } + + fst::RemoveEpsLocal(ans); // this is safe and will not blow up. + // Now apply probability scale. 
+ // We waited till after the possible weight-pushing steps, + // because weight-pushing needs "real" weights in order to work. + // ApplyProbabilityScale(config.transition_scale, ans); + return ans; +} + +} // end namespace kaldi diff --git a/src/simplehmm/simple-hmm-utils.h b/src/simplehmm/simple-hmm-utils.h new file mode 100644 index 00000000000..5bdf185214a --- /dev/null +++ b/src/simplehmm/simple-hmm-utils.h @@ -0,0 +1,51 @@ +// hmm/simple-hmm-utils.h + +// Copyright 2009-2011 Microsoft Corporation +// 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_SIMPLE_HMM_UTILS_H_ +#define KALDI_HMM_SIMPLE_HMM_UTILS_H_ + +#include "hmm/hmm-utils.h" +#include "simplehmm/simple-hmm.h" +#include "fst/fstlib.h" + +namespace kaldi { + +fst::VectorFst* GetHTransducer( + const SimpleHmm &model, + BaseFloat transition_scale = 1.0, BaseFloat self_loop_scale = 1.0); + +/** + * Converts the SimpleHmm into H tranducer; result owned by caller. + * Caution: our version of + * the H transducer does not include self-loops; you have to add those later. + * See \ref hmm_graph_get_h_transducer. The H transducer has on the + * input transition-ids. + * The output side contains the one-indexed mappings of pdf_ids, typically + * just pdf_id + 1. 
+ */ +fst::VectorFst* +GetSimpleHmmAsFst (const SimpleHmm &model, + BaseFloat transition_scale = 1.0, + BaseFloat self_loop_scale = 1.0); + + +} // end namespace kaldi + +#endif diff --git a/src/simplehmm/simple-hmm.cc b/src/simplehmm/simple-hmm.cc new file mode 100644 index 00000000000..e0e7442ead3 --- /dev/null +++ b/src/simplehmm/simple-hmm.cc @@ -0,0 +1,79 @@ +// hmm/simple-hmm.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "simplehmm/simple-hmm.h" + +namespace kaldi { + +void SimpleHmm::FakeContextDependency::GetPdfInfo( + const std::vector &phones, // list of phones + const std::vector &num_pdf_classes, // indexed by phone, + std::vector > > *pdf_info) const { + KALDI_ASSERT(phones.size() == 1 && phones[0] == 1); + KALDI_ASSERT(num_pdf_classes.size() == 2 && + num_pdf_classes[1] == NumPdfs()); + KALDI_ASSERT(pdf_info); + pdf_info->resize(NumPdfs(), + std::vector >()); + + for (int32 pdf = 0; pdf < NumPdfs(); pdf++) { + (*pdf_info)[pdf].push_back(std::make_pair(1, pdf)); + } +} + +void SimpleHmm::FakeContextDependency::GetPdfInfo( + const std::vector &phones, + const std::vector > > &pdf_class_pairs, + std::vector > > > *pdf_info) const { + KALDI_ASSERT(pdf_info); + KALDI_ASSERT(phones.size() == 1 && phones[0] == 1); + KALDI_ASSERT(pdf_class_pairs.size() == 2); + + pdf_info->resize(2); + (*pdf_info)[1].resize(pdf_class_pairs[1].size()); + + for (size_t j = 0; j < pdf_class_pairs[1].size(); j++) { + int32 pdf_class = pdf_class_pairs[1][j].first, + self_loop_pdf_class = pdf_class_pairs[1][j].second; + KALDI_ASSERT(pdf_class == self_loop_pdf_class && + pdf_class < NumPdfs()); + + (*pdf_info)[1][j].push_back(std::make_pair(pdf_class, pdf_class)); + } +} + +void SimpleHmm::Read(std::istream &is, bool binary) { + TransitionModel::Read(is, binary); + ctx_dep_.Init(NumPdfs()); + CheckSimpleHmm(); +} + +void SimpleHmm::CheckSimpleHmm() const { + KALDI_ASSERT(NumPhones() == 1); + KALDI_ASSERT(GetPhones()[0] == 1); + const HmmTopology::TopologyEntry &entry = GetTopo().TopologyForPhone(1); + for (int32 j = 0; j < static_cast(entry.size()); j++) { // for each state... 
+ int32 forward_pdf_class = entry[j].forward_pdf_class, + self_loop_pdf_class = entry[j].self_loop_pdf_class; + KALDI_ASSERT(forward_pdf_class == self_loop_pdf_class && + forward_pdf_class < NumPdfs()); + } +} + +} // end namespace kaldi diff --git a/src/simplehmm/simple-hmm.h b/src/simplehmm/simple-hmm.h new file mode 100644 index 00000000000..6fa9b1db6d2 --- /dev/null +++ b/src/simplehmm/simple-hmm.h @@ -0,0 +1,95 @@ +// hmm/simple-hmm.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_HMM_SIMPLE_HMM_H_ +#define KALDI_HMM_SIMPLE_HMM_H_ + +#include "base/kaldi-common.h" +#include "hmm/transition-model.h" +#include "itf/context-dep-itf.h" + +namespace kaldi { + +class SimpleHmm: public TransitionModel { + public: + SimpleHmm(const HmmTopology &hmm_topo): + ctx_dep_(hmm_topo) { + Init(ctx_dep_, hmm_topo); + CheckSimpleHmm(); + } + + SimpleHmm(): TransitionModel() { } + + void Read(std::istream &is, bool binary); // note, no symbol table: topo object always read/written w/o symbols. 
+ + private: + void CheckSimpleHmm() const; + + class FakeContextDependency: public ContextDependencyInterface { + public: + int ContextWidth() const { return 1; } + int CentralPosition() const { return 0; } + + bool Compute(const std::vector &phoneseq, int32 pdf_class, + int32 *pdf_id) const { + if (phoneseq.size() == 1 && phoneseq[0] == 1) { + *pdf_id = pdf_class; + return true; + } + return false; + } + + void GetPdfInfo( + const std::vector &phones, // list of phones + const std::vector &num_pdf_classes, // indexed by phone, + std::vector > > *pdf_info) const; + + void GetPdfInfo( + const std::vector &phones, + const std::vector > > &pdf_class_pairs, + std::vector > > > *pdf_info) + const; + + void Init(int32 num_pdfs) { num_pdfs_ = num_pdfs; } + + int32 NumPdfs() const { return num_pdfs_; } + + FakeContextDependency(const HmmTopology &topo) { + KALDI_ASSERT(topo.GetPhones().size() == 1); + num_pdfs_ = topo.NumPdfClasses(1); + } + + FakeContextDependency(): num_pdfs_(0) { } + + ContextDependencyInterface* Copy() const { + FakeContextDependency *copy = new FakeContextDependency(); + copy->Init(num_pdfs_); + return copy; + } + + private: + int32 num_pdfs_; + } ctx_dep_; + + KALDI_DISALLOW_COPY_AND_ASSIGN(SimpleHmm); +}; + +} // end namespace kaldi + +#endif // KALDI_HMM_SIMPLE_HMM_H_ diff --git a/src/simplehmmbin/Makefile b/src/simplehmmbin/Makefile new file mode 100644 index 00000000000..3546ebae7c2 --- /dev/null +++ b/src/simplehmmbin/Makefile @@ -0,0 +1,23 @@ + +all: +EXTRA_CXXFLAGS = -Wno-sign-compare +include ../kaldi.mk + +BINFILES = simple-hmm-init \ + compile-train-simple-hmm-graphs simple-hmm-align-compiled \ + simple-hmm-acc-stats-ali simple-hmm-est make-simple-hmm-graph + + +OBJFILES = + +ADDLIBS = ../decoder/kaldi-decoder.a \ + ../simplehmm/kaldi-simplehmm.a ../lat/kaldi-lat.a \ + ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ + ../util/kaldi-util.a ../thread/kaldi-thread.a \ + ../matrix/kaldi-matrix.a ../base/kaldi-base.a + + +TESTFILES = + +include 
../makefiles/default_rules.mk + diff --git a/src/simplehmmbin/compile-train-simple-hmm-graphs.cc b/src/simplehmmbin/compile-train-simple-hmm-graphs.cc new file mode 100644 index 00000000000..3797e24f2a4 --- /dev/null +++ b/src/simplehmmbin/compile-train-simple-hmm-graphs.cc @@ -0,0 +1,151 @@ +// bin/compile-train-simple-hmm-graphs.cc + +// Copyright 2009-2012 Microsoft Corporation +// 2012-2015 Johns Hopkins University (Author: Daniel Povey) +// 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/context-dep.h" +#include "simplehmm/simple-hmm.h" +#include "fstext/fstext-lib.h" +#include "simplehmm/simple-hmm-graph-compiler.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::VectorFst; + using fst::StdArc; + + const char *usage = + "Creates training graphs (without transition-probabilities, by default)\n" + "for training SimpleHmm models using alignments of pdf-ids.\n" + "Usage: compile-train-simple-hmm-graphs [options] " + " \n" + "e.g.: \n" + " compile-train-simple-hmm-graphs 1.mdl ark:train.tra ark:graphs.fsts\n"; + ParseOptions po(usage); + + SimpleHmmGraphCompilerOptions gopts; + int32 batch_size = 250; + gopts.transition_scale = 0.0; // Change the default to 0.0 since we will generally add the + // transition probs in the alignment phase (since they change eacm time) + gopts.self_loop_scale = 0.0; // Ditto for self-loop probs. + std::string disambig_rxfilename; + gopts.Register(&po); + + po.Register("batch-size", &batch_size, + "Number of FSTs to compile at a time (more -> faster but uses " + "more memory. E.g. 500"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string model_rxfilename = po.GetArg(1); + std::string alignment_rspecifier = po.GetArg(2); + std::string fsts_wspecifier = po.GetArg(3); + + SimpleHmm model; + ReadKaldiObject(model_rxfilename, &model); + + SimpleHmmGraphCompiler gc(model, gopts); + + SequentialInt32VectorReader alignment_reader(alignment_rspecifier); + TableWriter fst_writer(fsts_wspecifier); + + int32 num_succeed = 0, num_fail = 0; + + if (batch_size == 1) { // We treat batch_size of 1 as a special case in order + // to test more parts of the code. 
+ for (; !alignment_reader.Done(); alignment_reader.Next()) { + const std::string &key = alignment_reader.Key(); + std::vector alignment = alignment_reader.Value(); + + for (std::vector::iterator it = alignment.begin(); + it != alignment.end(); ++it) { + KALDI_ASSERT(*it < model.NumPdfs()); + ++(*it); + } + + VectorFst decode_fst; + + if (!gc.CompileGraphFromAlignment(alignment, &decode_fst)) { + decode_fst.DeleteStates(); // Just make it empty. + } + if (decode_fst.Start() != fst::kNoStateId) { + num_succeed++; + fst_writer.Write(key, decode_fst); + } else { + KALDI_WARN << "Empty decoding graph for utterance " + << key; + num_fail++; + } + } + } else { + std::vector keys; + std::vector > alignments; + while (!alignment_reader.Done()) { + keys.clear(); + alignments.clear(); + for (; !alignment_reader.Done() && + static_cast(alignments.size()) < batch_size; + alignment_reader.Next()) { + keys.push_back(alignment_reader.Key()); + alignments.push_back(alignment_reader.Value()); + + for (std::vector::iterator it = alignments.back().begin(); + it != alignments.back().end(); ++it) { + KALDI_ASSERT(*it < model.NumPdfs()); + ++(*it); + } + } + std::vector* > fsts; + if (!gc.CompileGraphsFromAlignments(alignments, &fsts)) { + KALDI_ERR << "Not expecting CompileGraphs to fail."; + } + KALDI_ASSERT(fsts.size() == keys.size()); + for (size_t i = 0; i < fsts.size(); i++) { + if (fsts[i]->Start() != fst::kNoStateId) { + num_succeed++; + fst_writer.Write(keys[i], *(fsts[i])); + } else { + KALDI_WARN << "Empty decoding graph for utterance " + << keys[i]; + num_fail++; + } + } + DeletePointers(&fsts); + } + } + KALDI_LOG << "compile-train--simple-hmm-graphs: succeeded for " + << num_succeed << " graphs, failed for " << num_fail; + return (num_succeed != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/simplehmmbin/make-simple-hmm-graph.cc b/src/simplehmmbin/make-simple-hmm-graph.cc new file mode 100644 index 00000000000..088a73e7c50 --- /dev/null +++ b/src/simplehmmbin/make-simple-hmm-graph.cc @@ -0,0 +1,87 @@ +// simplehmmbin/make-simple-hmm-graph.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "simplehmm/simple-hmm.h" +#include "simplehmm/simple-hmm-utils.h" +#include "util/common-utils.h" +#include "fst/fstlib.h" +#include "fstext/table-matcher.h" +#include "fstext/fstext-utils.h" +#include "fstext/context-fst.h" +#include "decoder/simple-hmm-graph-compiler.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::VectorFst; + using fst::StdArc; + + const char *usage = + "Make graph to decode with simple HMM. 
It is an FST from " + "transition-ids to pdf-ids + 1, \n" + "Usage: make-simple-hmm-graph []\n" + "e.g.: \n" + " make-simple-hmm-graph 1.mdl > HCLG.fst\n"; + ParseOptions po(usage); + + SimpleHmmGraphCompilerOptions gopts; + gopts.Register(&po); + + po.Read(argc, argv); + + if (po.NumArgs() < 1 || po.NumArgs() > 2) { + po.PrintUsage(); + exit(1); + } + + std::string model_filename = po.GetArg(1); + std::string fst_out_filename; + if (po.NumArgs() >= 2) fst_out_filename = po.GetArg(2); + if (fst_out_filename == "-") fst_out_filename = ""; + + SimpleHmm trans_model; + ReadKaldiObject(model_filename, &trans_model); + + // The work gets done here. + fst::VectorFst *H = GetHTransducer (trans_model, + gopts.transition_scale, + gopts.self_loop_scale); + +#if _MSC_VER + if (fst_out_filename == "") + _setmode(_fileno(stdout), _O_BINARY); +#endif + + if (! H->Write(fst_out_filename) ) + KALDI_ERR << "make-simple-hmm-graph: error writing FST to " + << (fst_out_filename == "" ? + "standard output" : fst_out_filename); + + delete H; + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/simplehmmbin/simple-hmm-acc-stats-ali.cc b/src/simplehmmbin/simple-hmm-acc-stats-ali.cc new file mode 100644 index 00000000000..5bcf8239311 --- /dev/null +++ b/src/simplehmmbin/simple-hmm-acc-stats-ali.cc @@ -0,0 +1,88 @@ +// simplehmmbin/simple-hmm-acc-stats-ali.cc + +// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) +// 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "simplehmm/simple-hmm.h" + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + try { + const char *usage = + "Accumulate stats for simple HMM training.\n" + "Usage: simple-hmm-acc-stats-ali [options] " + " \n" + "e.g.:\n simple-hmm-acc-stats-ali 1.mdl ark:1.ali 1.acc\n"; + + ParseOptions po(usage); + bool binary = true; + po.Register("binary", &binary, "Write output in binary mode"); + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string model_filename = po.GetArg(1), + alignments_rspecifier = po.GetArg(2), + accs_wxfilename = po.GetArg(3); + + SimpleHmm model; + ReadKaldiObject(model_filename, &model); + + Vector transition_accs; + model.InitStats(&transition_accs); + + SequentialInt32VectorReader alignments_reader(alignments_rspecifier); + + int32 num_done = 0, num_err = 0; + for (; !alignments_reader.Done(); alignments_reader.Next()) { + const std::string &key = alignments_reader.Key(); + const std::vector &alignment = alignments_reader.Value(); + + for (size_t i = 0; i < alignment.size(); i++) { + int32 tid = alignment[i]; // transition identifier. 
+ model.Accumulate(1.0, tid, &transition_accs); + } + + num_done++; + } + KALDI_LOG << "Done " << num_done << " files, " << num_err + << " with errors."; + + { + Output ko(accs_wxfilename, binary); + transition_accs.Write(ko.Stream(), binary); + } + KALDI_LOG << "Written accs."; + if (num_done != 0) + return 0; + else + return 1; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/simplehmmbin/simple-hmm-align-compiled.cc b/src/simplehmmbin/simple-hmm-align-compiled.cc new file mode 100644 index 00000000000..4a2bc286b24 --- /dev/null +++ b/src/simplehmmbin/simple-hmm-align-compiled.cc @@ -0,0 +1,131 @@ +// simplehmmbin/simple-hmm-align-compiled.cc + +// Copyright 2009-2013 Microsoft Corporation +// Johns Hopkins University (author: Daniel Povey) +// 2016 Vimal Manohar + + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "simplehmm/simple-hmm.h" +#include "simplehmm/simple-hmm-utils.h" +#include "fstext/fstext-lib.h" +#include "decoder/decoder-wrappers.h" +#include "decoder/decodable-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::VectorFst; + using fst::StdArc; + + const char *usage = + "Align matrix of log-likelihoods given simple HMM model.\n" + "Usage: simple-hmm-align-compiled [options] " + " []\n" + "e.g.: \n" + " simple-hmm-align-compiled 1.mdl ark:graphs.fsts ark:log_likes.1.ark ark:1.ali\n"; + + ParseOptions po(usage); + AlignConfig align_config; + BaseFloat acoustic_scale = 1.0; + BaseFloat transition_scale = 1.0; + BaseFloat self_loop_scale = 1.0; + + align_config.Register(&po); + po.Register("transition-scale", &transition_scale, + "Transition-probability scale [relative to acoustics]"); + po.Register("acoustic-scale", &acoustic_scale, + "Scaling factor for acoustic likelihoods"); + po.Register("self-loop-scale", &self_loop_scale, + "Scale of self-loop versus non-self-loop log probs [relative to acoustics]"); + po.Read(argc, argv); + + if (po.NumArgs() < 4 || po.NumArgs() > 5) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), + fst_rspecifier = po.GetArg(2), + loglikes_rspecifier = po.GetArg(3), + alignment_wspecifier = po.GetArg(4), + scores_wspecifier = po.GetOptArg(5); + + SimpleHmm model; + ReadKaldiObject(model_in_filename, &model); + + SequentialTableReader fst_reader(fst_rspecifier); + RandomAccessBaseFloatMatrixReader loglikes_reader(loglikes_rspecifier); + Int32VectorWriter alignment_writer(alignment_wspecifier); + BaseFloatWriter scores_writer(scores_wspecifier); + + int32 num_done = 0, num_err = 0, num_retry = 0; + double tot_like = 0.0; + kaldi::int64 frame_count = 0; + + for (; !fst_reader.Done(); fst_reader.Next()) { + const std::string &utt = 
fst_reader.Key(); + if (!loglikes_reader.HasKey(utt)) { + num_err++; + KALDI_WARN << "No loglikes for utterance " << utt; + } else { + const Matrix &loglikes = loglikes_reader.Value(utt); + VectorFst decode_fst(fst_reader.Value()); + fst_reader.FreeCurrent(); // this stops copy-on-write of the fst + // by deleting the fst inside the reader, since we're about to mutate + // the fst by adding transition probs. + + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_err++; + continue; + } + + { // Add transition-probs to the FST. + std::vector disambig_syms; // empty + AddTransitionProbs(model, disambig_syms, transition_scale, + self_loop_scale, &decode_fst); + } + + DecodableMatrixScaledMapped decodable(model, loglikes, acoustic_scale); + + AlignUtteranceWrapper(align_config, utt, + acoustic_scale, &decode_fst, + &decodable, + &alignment_writer, &scores_writer, + &num_done, &num_err, &num_retry, + &tot_like, &frame_count); + } + } + KALDI_LOG << "Overall log-likelihood per frame is " + << (tot_like/frame_count) + << " over " << frame_count<< " frames."; + KALDI_LOG << "Retried " << num_retry << " out of " + << (num_done + num_err) << " utterances."; + KALDI_LOG << "Done " << num_done << ", errors on " << num_err; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/simplehmmbin/simple-hmm-est.cc b/src/simplehmmbin/simple-hmm-est.cc new file mode 100644 index 00000000000..b121bad44b0 --- /dev/null +++ b/src/simplehmmbin/simple-hmm-est.cc @@ -0,0 +1,86 @@ +// simplehmmbin/simple-hmm-est.cc + +// Copyright 2009-2011 Microsoft Corporation +// 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "simplehmm/simple-hmm.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + + const char *usage = + "Do Maximum Likelihood re-estimation of simple HMM " + "transition parameters\n" + "Usage: simple-hmm-est [options] \n" + "e.g.: simple-hmm-est 1.mdl 1.acc 2.mdl\n"; + + bool binary_write = true; + MleTransitionUpdateConfig tcfg; + std::string occs_out_filename; + + ParseOptions po(usage); + po.Register("binary", &binary_write, "Write output in binary mode"); + tcfg.Register(&po); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), + stats_filename = po.GetArg(2), + model_out_filename = po.GetArg(3); + + SimpleHmm model; + ReadKaldiObject(model_in_filename, &model); + + Vector transition_accs; + ReadKaldiObject(stats_filename, &transition_accs); + + { + BaseFloat objf_impr, count; + model.MleUpdate(transition_accs, tcfg, &objf_impr, &count); + KALDI_LOG << "Transition model update: Overall " << (objf_impr/count) + << " log-like improvement per frame over " << (count) + << " frames."; + } + + WriteKaldiObject(model, model_out_filename, binary_write); + + if (GetVerboseLevel() >= 2) { + std::vector phone_names; + phone_names.push_back("0"); + phone_names.push_back("1"); + model.Print(KALDI_LOG, phone_names); + } + + KALDI_LOG << "Written model to " << model_out_filename; + return 
0; + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} + + diff --git a/src/simplehmmbin/simple-hmm-init.cc b/src/simplehmmbin/simple-hmm-init.cc new file mode 100644 index 00000000000..ddee0893b7c --- /dev/null +++ b/src/simplehmmbin/simple-hmm-init.cc @@ -0,0 +1,70 @@ +// bin/simple-hmm-init.cc + +// Copyright 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "hmm/hmm-topology.h" +#include "simplehmm/simple-hmm.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using kaldi::int32; + + const char *usage = + "Initialize simple HMM from topology.\n" + "Usage: simple-hmm-init \n" + "e.g.: \n" + " simple-hmm-init topo init.mdl\n"; + + bool binary = true; + + ParseOptions po(usage); + po.Register("binary", &binary, "Write output in binary mode"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string topo_filename = po.GetArg(1); + std::string model_filename = po.GetArg(2); + + HmmTopology topo; + { + bool binary_in; + Input ki(topo_filename, &binary_in); + topo.Read(ki.Stream(), binary_in); + } + + SimpleHmm model(topo); + { + Output ko(model_filename, binary); + model.Write(ko.Stream(), binary); + } + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/tree/cluster-utils.cc b/src/tree/cluster-utils.cc index 53de0825e08..aa9ae46bc01 100644 --- a/src/tree/cluster-utils.cc +++ b/src/tree/cluster-utils.cc @@ -190,62 +190,6 @@ void AddToClustersOptimized(const std::vector &stats, // Bottom-up clustering routines // ============================================================================ -class BottomUpClusterer { - public: - BottomUpClusterer(const std::vector &points, - BaseFloat max_merge_thresh, - int32 min_clust, - std::vector *clusters_out, - std::vector *assignments_out) - : ans_(0.0), points_(points), max_merge_thresh_(max_merge_thresh), - min_clust_(min_clust), clusters_(clusters_out != NULL? clusters_out - : &tmp_clusters_), assignments_(assignments_out != NULL ? 
- assignments_out : &tmp_assignments_) { - nclusters_ = npoints_ = points.size(); - dist_vec_.resize((npoints_ * (npoints_ - 1)) / 2); - } - - BaseFloat Cluster(); - ~BottomUpClusterer() { DeletePointers(&tmp_clusters_); } - - private: - void Renumber(); - void InitializeAssignments(); - void SetInitialDistances(); ///< Sets up distances and queue. - /// CanMerge returns true if i and j are existing clusters, and the distance - /// (negated objf-change) "dist" is accurate (i.e. not outdated). - bool CanMerge(int32 i, int32 j, BaseFloat dist); - /// Merge j into i and delete j. - void MergeClusters(int32 i, int32 j); - /// Reconstructs the priority queue from the distances. - void ReconstructQueue(); - - void SetDistance(int32 i, int32 j); - BaseFloat& Distance(int32 i, int32 j) { - KALDI_ASSERT(i < npoints_ && j < i); - return dist_vec_[(i * (i - 1)) / 2 + j]; - } - - BaseFloat ans_; - const std::vector &points_; - BaseFloat max_merge_thresh_; - int32 min_clust_; - std::vector *clusters_; - std::vector *assignments_; - - std::vector tmp_clusters_; - std::vector tmp_assignments_; - - std::vector dist_vec_; - int32 nclusters_; - int32 npoints_; - typedef std::pair > QueueElement; - // Priority queue using greater (lowest distances are highest priority). 
- typedef std::priority_queue, - std::greater > QueueType; - QueueType queue_; -}; - BaseFloat BottomUpClusterer::Cluster() { KALDI_VLOG(2) << "Initializing cluster assignments."; InitializeAssignments(); @@ -253,12 +197,15 @@ BaseFloat BottomUpClusterer::Cluster() { SetInitialDistances(); KALDI_VLOG(2) << "Clustering..."; - while (nclusters_ > min_clust_ && !queue_.empty()) { + while (!StoppingCriterion()) { std::pair > pr = queue_.top(); BaseFloat dist = pr.first; int32 i = (int32) pr.second.first, j = (int32) pr.second.second; queue_.pop(); - if (CanMerge(i, j, dist)) MergeClusters(i, j); + if (CanMerge(i, j, dist)) { + UpdateClustererStats(i, j); + MergeClusters(i, j); + } } KALDI_VLOG(2) << "Renumbering clusters to contiguous numbers."; Renumber(); @@ -325,11 +272,12 @@ void BottomUpClusterer::InitializeAssignments() { void BottomUpClusterer::SetInitialDistances() { for (int32 i = 0; i < npoints_; i++) { for (int32 j = 0; j < i; j++) { - BaseFloat dist = (*clusters_)[i]->Distance(*((*clusters_)[j])); - dist_vec_[(i * (i - 1)) / 2 + j] = dist; - if (dist <= max_merge_thresh_) + BaseFloat dist = ComputeDistance(i, j); + if (dist <= MergeThreshold(i, j)) queue_.push(std::make_pair(dist, std::make_pair(static_cast(i), static_cast(j)))); + if (j == i - 1) + KALDI_VLOG(2) << "Distance(" << i << ", " << j << ") = " << dist; } } } @@ -344,6 +292,7 @@ bool BottomUpClusterer::CanMerge(int32 i, int32 j, BaseFloat dist) { void BottomUpClusterer::MergeClusters(int32 i, int32 j) { KALDI_ASSERT(i != j && i < npoints_ && j < npoints_); + (*clusters_)[i]->Add(*((*clusters_)[j])); delete (*clusters_)[j]; (*clusters_)[j] = NULL; @@ -376,7 +325,7 @@ void BottomUpClusterer::ReconstructQueue() { for (int32 j = 0; j < i; j++) { if ((*clusters_)[j] != NULL) { BaseFloat dist = dist_vec_[(i * (i - 1)) / 2 + j]; - if (dist <= max_merge_thresh_) { + if (dist <= MergeThreshold(i, j)) { queue_.push(std::make_pair(dist, std::make_pair( static_cast(i), static_cast(j)))); } @@ -389,9 +338,8 
@@ void BottomUpClusterer::ReconstructQueue() { void BottomUpClusterer::SetDistance(int32 i, int32 j) { KALDI_ASSERT(i < npoints_ && j < i && (*clusters_)[i] != NULL && (*clusters_)[j] != NULL); - BaseFloat dist = (*clusters_)[i]->Distance(*((*clusters_)[j])); - dist_vec_[(i * (i - 1)) / 2 + j] = dist; // set the distance in the array. - if (dist < max_merge_thresh_) { + BaseFloat dist = ComputeDistance(i, j); + if (dist < MergeThreshold(i, j)) { queue_.push(std::make_pair(dist, std::make_pair(static_cast(i), static_cast(j)))); } @@ -403,7 +351,6 @@ void BottomUpClusterer::SetDistance(int32 i, int32 j) { } - BaseFloat ClusterBottomUp(const std::vector &points, BaseFloat max_merge_thresh, int32 min_clust, diff --git a/src/tree/cluster-utils.h b/src/tree/cluster-utils.h index 55583a237bf..2658cb8dfd0 100644 --- a/src/tree/cluster-utils.h +++ b/src/tree/cluster-utils.h @@ -21,10 +21,14 @@ #ifndef KALDI_TREE_CLUSTER_UTILS_H_ #define KALDI_TREE_CLUSTER_UTILS_H_ +#include #include +using std::vector; #include "matrix/matrix-lib.h" +#include "util/stl-utils.h" #include "itf/clusterable-itf.h" + namespace kaldi { /// \addtogroup clustering_group_simple @@ -103,9 +107,100 @@ void AddToClustersOptimized(const std::vector &stats, * @param assignments_out [out] If non-NULL, will be resized to the number of * points, and each element is the index of the cluster that point * was assigned to. + */ + +class BottomUpClusterer { + public: + typedef uint16 uint_smaller; + typedef int16 int_smaller; + + BottomUpClusterer(const std::vector &points, + BaseFloat max_merge_thresh, + int32 min_clust, + std::vector *clusters_out, + std::vector *assignments_out) + : points_(points), max_merge_thresh_(max_merge_thresh), + min_clust_(min_clust), clusters_(clusters_out != NULL? clusters_out + : &tmp_clusters_), ans_(0.0), + assignments_(assignments_out != NULL ? 
+ assignments_out : &tmp_assignments_) { + nclusters_ = npoints_ = points.size(); + dist_vec_.resize((npoints_ * (npoints_ - 1)) / 2); + } + + BaseFloat Cluster(); + ~BottomUpClusterer() { DeletePointers(&tmp_clusters_); } + + /// Public accessors + BaseFloat& Distance(int32 i, int32 j) { + KALDI_ASSERT(i < npoints_ && j < i); + return dist_vec_[(i * (i - 1)) / 2 + j]; + } + /// CanMerge returns true if i and j are existing clusters, and the distance + /// (negated objf-change) "dist" is accurate (i.e. not outdated). + virtual bool CanMerge(int32 i, int32 j, BaseFloat dist); + + /// Merge j into i and delete j. + virtual void MergeClusters(int32 i, int32 j); + + typedef std::pair > + QueueElement; + // Priority queue using greater (lowest distances are highest priority). + typedef std::priority_queue, + std::greater > QueueType; + + int32 NumClusters() const { return nclusters_; } + int32 NumPoints() const { return npoints_; } + int32 MinClusters() const { return min_clust_; } + bool IsQueueEmpty() const { return queue_.empty(); } + + protected: + const std::vector &points_; + BaseFloat max_merge_thresh_; + int32 min_clust_; + std::vector *clusters_; + + std::vector dist_vec_; + int32 nclusters_; + int32 npoints_; + QueueType queue_; + + private: + void Renumber(); + void InitializeAssignments(); + void SetInitialDistances(); ///< Sets up distances and queue. + /// Reconstructs the priority queue from the distances. 
+ void ReconstructQueue(); + + /// Update some stats to reflect merging clusters i and j + virtual void UpdateClustererStats(int32, int32 j) { }; + + virtual bool StoppingCriterion() const { + return nclusters_ <= min_clust_ || queue_.empty(); + } + + virtual BaseFloat MergeThreshold(int32 i, int32 j) { + return max_merge_thresh_; + } + + void SetDistance(int32 i, int32 j); + virtual BaseFloat ComputeDistance(int32 i, int32 j) { + BaseFloat dist = (*clusters_)[i]->Distance(*((*clusters_)[j])); + dist_vec_[(i * (i - 1)) / 2 + j] = dist; // set the distance in the array. + return dist; + } + + BaseFloat ans_; + std::vector *assignments_; + + std::vector tmp_clusters_; + std::vector tmp_assignments_; +}; + +/** This is a wrapper function to the BottomUpClusterer class. * @return Returns the total objf change relative to all clusters being separate, which is * a negative. Note that this is not the same as what the other clustering algorithms return. - */ + **/ BaseFloat ClusterBottomUp(const std::vector &points, BaseFloat thresh, int32 min_clust, diff --git a/src/util/kaldi-holder.cc b/src/util/kaldi-holder.cc index a26bdf2ce29..a86f09a2030 100644 --- a/src/util/kaldi-holder.cc +++ b/src/util/kaldi-holder.cc @@ -34,7 +34,7 @@ bool ExtractObjectRange(const Matrix &input, const std::string &range, SplitStringToVector(range, ",", false, &splits); if (!((splits.size() == 1 && !splits[0].empty()) || (splits.size() == 2 && !splits[0].empty() && !splits[1].empty()))) { - KALDI_ERR << "Invalid range specifier: " << range; + KALDI_ERR << "Invalid range specifier for matrix: " << range; return false; } std::vector row_range, col_range; @@ -75,6 +75,48 @@ template bool ExtractObjectRange(const Matrix &, const std::string &, template bool ExtractObjectRange(const Matrix &, const std::string &, Matrix *); +template +bool ExtractObjectRange(const Vector &input, const std::string &range, + Vector *output) { + if (range.empty()) { + KALDI_ERR << "Empty range specifier."; + return 
false; + } + std::vector<std::string> splits; + SplitStringToVector(range, ",", false, &splits); + if (!((splits.size() == 1 && !splits[0].empty()))) { + KALDI_ERR << "Invalid range specifier for vector: " << range; + return false; + } + std::vector<int32> index_range; + bool status = true; + if (splits[0] != ":") + status = SplitStringToIntegers(splits[0], ":", false, &index_range); + + if (index_range.size() == 0) { + index_range.push_back(0); + index_range.push_back(input.Dim() - 1); + } + + if (!(status && index_range.size() == 2 && + index_range[0] >= 0 && index_range[0] <= index_range[1] && + index_range[1] < input.Dim())) { + KALDI_ERR << "Invalid range specifier: " << range + << " for vector of size " << input.Dim(); + return false; + } + int32 size = index_range[1] - index_range[0] + 1; + output->Resize(size, kUndefined); + output->CopyFromVec(input.Range(index_range[0], size)); + return true; +} + +// template instantiation +template bool ExtractObjectRange(const Vector<double> &, const std::string &, + Vector<double> *); +template bool ExtractObjectRange(const Vector<float> &, const std::string &, + Vector<float> *); + bool ExtractRangeSpecifier(const std::string &rxfilename_with_range, std::string *data_rxfilename, std::string *range) { diff --git a/src/util/kaldi-holder.h b/src/util/kaldi-holder.h index 06d7ec8e745..9ab148387ee 100644 --- a/src/util/kaldi-holder.h +++ b/src/util/kaldi-holder.h @@ -242,6 +242,11 @@ template <class T> bool ExtractObjectRange(const Matrix<T> &input, const std::string &range, Matrix<T> *output); +/// The template is specialized for types Vector<float> and Vector<double>.
+template <class T> +bool ExtractObjectRange(const Vector<T> &input, const std::string &range, + Vector<T> *output); + // In SequentialTableReaderScriptImpl and RandomAccessTableReaderScriptImpl, for // cases where the scp contained 'range specifiers' (things in square brackets diff --git a/tools/config/common_path.sh b/tools/config/common_path.sh index fbc4b674474..2ec95b8de6c 100644 --- a/tools/config/common_path.sh +++ b/tools/config/common_path.sh @@ -20,4 +20,6 @@ ${KALDI_ROOT}/src/online2bin:\ ${KALDI_ROOT}/src/onlinebin:\ ${KALDI_ROOT}/src/sgmm2bin:\ ${KALDI_ROOT}/src/sgmmbin:\ +${KALDI_ROOT}/src/segmenterbin:\ +${KALDI_ROOT}/src/simplehmmbin:\ $PATH