diff --git a/examples/asr/experimental/k2/align_speech_parallel.py b/examples/asr/experimental/k2/align_speech_parallel.py index 3e57e3cb74d5..dcccee48ee27 100644 --- a/examples/asr/experimental/k2/align_speech_parallel.py +++ b/examples/asr/experimental/k2/align_speech_parallel.py @@ -65,6 +65,7 @@ aligner_args.decode_batch_size=8 \ aligner_args.ctc_cfg.prob_suppress_index=-1 \ aligner_args.ctc_cfg.prob_suppress_value=0.5 \ + aligner_args.rnnt_cfg.predictor_window_size=10 \ aligner_args.decoder_module_cfg.intersect_pruned=true \ aligner_args.decoder_module_cfg.intersect_conf.search_beam=40 \ ... diff --git a/examples/asr/experimental/k2/conf/citrinet/citrinet_mmi_1024.yaml b/examples/asr/experimental/k2/conf/citrinet/citrinet_mmi_1024.yaml index 1c1be351ca35..b254b0720694 100644 --- a/examples/asr/experimental/k2/conf/citrinet/citrinet_mmi_1024.yaml +++ b/examples/asr/experimental/k2/conf/citrinet/citrinet_mmi_1024.yaml @@ -1,4 +1,4 @@ -# This config contains the default values for training a Citrinet model with CTC loss and BPE-based vocabulary. +# This config contains the default values for training a Citrinet model with CTC-MMI loss and BPE-based vocabulary. # Default learning parameters in this config are set for effective batch size of 1k on 32 GPUs. # To train it with smaller batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. # If training for a short time, you can also reduce weight decay to 0. @@ -447,11 +447,11 @@ model: graph_module_cfg: criterion_type: map + loss_type: mmi transcribe_training: false split_batch_size: 0 backend_cfg: token_lm: ??? - loss_type: mmi topo_type: default topo_with_self_loops: true intersect_pruned: false diff --git a/examples/asr/experimental/k2/conf/conformer/conformer_ctc_bpe.yaml b/examples/asr/experimental/k2/conf/conformer/conformer_ctc_bpe.yaml new file mode 100644 index 000000000000..fe418b8bdf42 --- /dev/null +++ b/examples/asr/experimental/k2/conf/conformer/conformer_ctc_bpe.yaml @@ -0,0 +1,216 @@ +# It contains the default values for training a Conformer-MMI (CTC) ASR model, large size (~120M) with CTC loss and sub-word encoding. + +# Architecture and training config: +# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective +# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. +# Here are the recommended configs for different variants of Conformer-CTC, other parameters are the same as in this config file. +# One extra layer (compared to original paper) is added to the medium and large variants to compensate for replacing the LSTM decoder with a linear one. +# +# +-------------+---------+---------+----------+------------+-----+ +# | Model | d_model | n_heads | n_layers | time_masks | lr | +# +=============+=========+========+===========+============+=====+ +# | Small (13M)| 176 | 4 | 16 | 5 | 5.0 | +# +-------------+---------+--------+-----------+------------+-----+ +# | Medium (30M)| 256 | 4 | 18 | 5 | 5.0 | +# +-------------+---------+--------+-----------+------------+-----+ +# | Large (121M)| 512 | 8 | 18 | 10 | 2.0 | +# +---------------------------------------------------------------+ +# +# If you do not want to train with AMP, you may use weight decay of 0.0 or reduce the number of time maskings to 2 +# with time_width=100. It may help when you want to train for fewer epochs and need faster convergence. +# With weight_decay=0.0, learning rate may need to get reduced to 2.0. 
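+#
+# A rough effective-batch-size check (the GPU count and accumulation below are illustrative
+# assumptions, not part of the original recipe):
+#   effective batch = num_gpus * num_nodes * train_ds.batch_size * accumulate_grad_batches
+#   e.g. 32 GPUs * 1 node * 16 * 4 = 2048, i.e. the ~2K this config's learning parameters assume.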
+ +# You may find more info about Conformer-CTC here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-ctc + +name: "Conformer-MMI-BPE" + +model: + sample_rate: 16000 + log_prediction: true # enables logging sample predictions in the output during training + ctc_reduction: 'mean_batch' + skip_nan_grad: false + + train_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 16 # you may increase batch_size if your memory allows + shuffle: true + num_workers: 8 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "synced_randomized" + bucketing_batch_size: null + + validation_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 16 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + + test_ds: + manifest_filepath: null + sample_rate: ${model.sample_rate} + batch_size: 16 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + + # recommend small vocab size of 128 or 256 when using 4x sub-sampling + # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py + tokenizer: + dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe) + type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + pad_value: 0.0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + # you may use lower time_masks for smaller models to have a faster convergence + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 18 + d_model: 512 + + # Sub-sampling params + subsampling: striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 4 # must be power of 2 for striding and vggnet + subsampling_conv_channels: -1 # -1 sets it to d_model + causal_downsampling: false + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 31 + conv_norm_type: 'batch_norm' # batch_norm or 
layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: null + num_classes: -1 + vocabulary: [] + + optim: + name: adamw + lr: 2.0 + # optimizer arguments + betas: [0.9, 0.98] + # less necessity for weight_decay as we already have large augmentations with SpecAug + # you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used + # weight decay of 0.0 with lr of 2.0 also works fine + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 10000 + warmup_ratio: null + min_lr: 1e-6 + + graph_module_cfg: + criterion_type: map + loss_type: mmi + transcribe_training: false + split_batch_size: 0 + backend_cfg: + token_lm: ??? + topo_type: default + topo_with_self_loops: true + intersect_pruned: false + boost_coeff: 0.0 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: 1000 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 10 # Interval of logging. + enable_progress_bar: True + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. 
+ num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_wer" + mode: "min" + save_top_k: 5 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + + # you need to set these two to True to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/asr/experimental/k2/conf/conformer/conformer_transducer_bpe.yaml b/examples/asr/experimental/k2/conf/conformer/conformer_transducer_bpe.yaml new file mode 100644 index 000000000000..9486cbf2d58f --- /dev/null +++ b/examples/asr/experimental/k2/conf/conformer/conformer_transducer_bpe.yaml @@ -0,0 +1,268 @@ +# It contains the default values for training a Conformer-Transducer ASR model, large size (~120M) with Transducer loss and sub-word encoding. + +# Architecture and training config: +# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective +# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. +# Here are the recommended configs for different variants of Conformer-Transducer, other parameters are the same as in this config file. +# +# +-------------+---------+---------+----------+--------------+--------------------------+ +# | Model | d_model | n_heads | n_layers | weight_decay | pred_hidden/joint_hidden | +# +=============+=========+========+===========+==============+==========================+ +# | Small (14M)| 176 | 4 | 16 | 0.0 | 320 | +# +-------------+---------+--------+-----------+--------------+--------------------------+ +# | Medium (32M)| 256 | 4 | 16 | 1e-3 | 640 | +# +-------------+---------+--------+-----------+--------------+--------------------------+ +# | Large (120M)| 512 | 8 | 17 | 1e-3 | 640 | +# +-----------------------------------------------------------+--------------------------+ +# + +# You may find more info about Conformer-Transducer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#conformer-transducer +# Pre-trained models of Conformer-Transducer can be found here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/results.html +# The checkpoint of the large model trained on NeMo ASRSET with this recipe can be found here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_large + +name: "Conformer-Transducer-BPE" + +model: + sample_rate: 16000 + compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag. 
+ log_prediction: true # enables logging sample predictions in the output during training + skip_nan_grad: false + + model_defaults: + enc_hidden: ${model.encoder.d_model} + pred_hidden: 640 + joint_hidden: 640 + + train_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 16 # you may increase batch_size if your memory allows + shuffle: true + num_workers: 8 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "synced_randomized" + bucketing_batch_size: null + + validation_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + + test_ds: + manifest_filepath: null + sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + + # You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py + tokenizer: + dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe) + type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 17 + d_model: 512 + + # Sub-sampling parameters + subsampling: striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 4 # must be power of 2 for striding and vggnet + subsampling_conv_channels: -1 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + # Having a 2x reduction will speedup the training and inference speech while keeping similar WER. + # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup. 
+ reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 31 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + decoder: + _target_: nemo.collections.asr.modules.RNNTDecoder + normalization_mode: null # Currently only null is supported for export. + random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf + blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference. + + prednet: + pred_hidden: ${model.model_defaults.pred_hidden} + pred_rnn_layers: 1 + t_max: null + dropout: 0.2 + + joint: + _target_: nemo.collections.asr.modules.RNNTJoint + log_softmax: true # should be always true for k2-based lattice loss + preserve_memory: false # dramatically slows down training, but might preserve some memory + + # Fuses the computation of prediction net + joint net + loss + WER calculation + # to be run on sub-batches of size `fused_batch_size`. + # When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size. + # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss. + # Using small values here will preserve a lot of memory during training, but will make training slower as well. + # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1. + # However, to preserve memory, this ratio can be 1:8 or even 1:16. + # Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow. + fuse_loss_wer: false + fused_batch_size: 16 + + jointnet: + joint_hidden: ${model.model_defaults.joint_hidden} + activation: "relu" + dropout: 0.2 + + decoding: + strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd. 
+ + # greedy strategy config + greedy: + max_symbols: 10 + + # beam strategy config + beam: + beam_size: 2 + return_best_hypothesis: False + score_norm: true + tsd_max_sym_exp: 50 # for Time Synchronous Decoding + alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding + + loss: + loss_name: "default" + + warprnnt_numba_kwargs: + # FastEmit regularization: https://arxiv.org/abs/2010.11148 + # You may enable FastEmit to reduce the latency of the model for streaming + fastemit_lambda: 0.0 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. + + graph_module_cfg: + criterion_type: ml + loss_type: rnnt + split_batch_size: 0 + backend_cfg: + topo_type: minimal + intersect_pruned: false + # Adds Gaussian noise to the gradients of the decoder to avoid overfitting + variational_noise: + start_step: 0 + std: 0.0 + + optim: + name: adamw + lr: 5.0 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 10000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: 500 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 10 # Interval of logging. + enable_progress_bar: True + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. 
+ num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_wer" + mode: "min" + save_top_k: 5 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + resume_if_exists: false + resume_ignore_no_checkpoint: false + + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/asr/experimental/k2/speech_to_text_bpe.py b/examples/asr/experimental/k2/speech_to_text_bpe.py index 0c24ed65fbde..5eefdfaf1fe3 100644 --- a/examples/asr/experimental/k2/speech_to_text_bpe.py +++ b/examples/asr/experimental/k2/speech_to_text_bpe.py @@ -61,7 +61,8 @@ exp_manager.create_wandb_logger=True \ exp_manager.wandb_logger_kwargs.name="" \ exp_manager.wandb_logger_kwargs.project="" \ - model.graph_module_cfg.criterion_type= \ + model.graph_module_cfg.criterion_type= \ + model.graph_module_cfg.loss_type= \ model.graph_module_cfg.transcribe_training=False \ model.graph_module_cfg.split_batch_size=0 \ model.graph_module_cfg.background_cfg.topo_type=<`default` or `compact` or `shared_blank` or `minimal`> \ @@ -70,7 +71,6 @@ # If graph_module_cfg.criterion_type=`map`, you can set the following parameters: model.graph_module_cfg.background_cfg.token_lm= \ - model.graph_module_cfg.background_cfg.loss_type=mmi \ model.graph_module_cfg.background_cfg.intersect_pruned=False \ model.graph_module_cfg.background_cfg.boost_coeff=0.0 """ diff --git a/examples/asr/experimental/k2/speech_to_text_rnnt_bpe.py b/examples/asr/experimental/k2/speech_to_text_rnnt_bpe.py new file mode 100644 index 000000000000..a0031fba082d --- /dev/null +++ b/examples/asr/experimental/k2/speech_to_text_rnnt_bpe.py @@ -0,0 +1,95 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +# Preparing the Tokenizer for the dataset +Use the `process_asr_text_tokenizer.py` script under /scripts/tokenizers/ in order to prepare the tokenizer. 
+ +```sh +python /scripts/tokenizers/process_asr_text_tokenizer.py \ + --manifest= + OR + --data_file= \ + --data_root="" \ + --vocab_size= \ + --tokenizer=<"spe" or "wpe"> \ + --no_lower_case \ + --spe_type=<"unigram", "bpe", "char" or "word"> \ + --spe_character_coverage=1.0 \ + --log +``` + +# Training the model +```sh +python speech_to_text_rnnt_bpe.py \ + # (Optional: --config-path= --config-name=) \ + model.train_ds.manifest_filepath= \ + model.validation_ds.manifest_filepath= \ + model.tokenizer.dir= \ + model.tokenizer.type= \ + trainer.devices=-1 \ + trainer.accelerator="gpu" \ + trainer.strategy="ddp" \ + trainer.max_epochs=100 \ + model.optim.name="adamw" \ + model.optim.lr=0.001 \ + model.optim.betas=[0.9,0.999] \ + model.optim.weight_decay=0.0001 \ + model.optim.sched.warmup_steps=2000 + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.name="" \ + exp_manager.wandb_logger_kwargs.project="" \ + model.graph_module_cfg.criterion_type=ml \ + model.graph_module_cfg.loss_type=rnnt \ + model.graph_module_cfg.split_batch_size=0 \ + model.graph_module_cfg.background_cfg.topo_type=minimal +``` + +# Fine-tune a model + +For documentation on fine-tuning this model, please visit - +https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations + +""" + +import pytorch_lightning as pl +from omegaconf import OmegaConf + +from nemo.collections.asr.models import EncDecK2RnntSeqModelBPE +from nemo.collections.asr.models.configs.k2_sequence_models_config import EncDecK2SeqModelConfig +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="experimental/k2/conf/conformer", config_name="conformer_transducer_bpe.yaml") +def main(cfg: EncDecK2SeqModelConfig): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + asr_model = EncDecK2RnntSeqModelBPE(cfg=cfg.model, trainer=trainer) + + # Initialize the weights of the model from another model, if provided via config + asr_model.maybe_init_from_pretrained_checkpoint(cfg) + + trainer.fit(asr_model) + + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + if asr_model.prepare_test(trainer): + trainer.test(asr_model) + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/nemo/collections/asr/losses/lattice_losses.py b/nemo/collections/asr/losses/lattice_losses.py index 4605c61ccc86..7dae44bfd1d1 100644 --- a/nemo/collections/asr/losses/lattice_losses.py +++ b/nemo/collections/asr/losses/lattice_losses.py @@ -26,7 +26,7 @@ class LatticeLoss(Loss): """Family of loss functions based on various lattice scores. Note: - Requires k2 v1.11 or later to be installed to use this loss function. + Requires k2 v1.14 or later to be installed to use this loss function. Losses can be selected via the config, and optionally be passed keyword arguments as follows. @@ -37,9 +37,9 @@ class LatticeLoss(Loss): ... graph_module_cfg: # Config for graph modules, e.g. LatticeLoss criterion_type: "map" + loss_type: "mmi" split_batch_size: 0 backend_cfg: - loss_type: "mmi" topo_type: "default" # other options: "compact", "shared_blank", "minimal" topo_with_self_loops: true token_lm: # must be provided for criterion_type: "map" @@ -56,6 +56,8 @@ class LatticeLoss(Loss): criterion_type: Type of criterion to use. 
Choices: `ml` and `map`, with `ml` standing for Maximum Likelihood and `map` for Maximum A Posteriori Probability. + loss_type: Type of the loss function to use. Choices: `ctc` and `rnnt` for `ml`, and `mmi` for `map`. + split_batch_size: Local batch size. Used for memory consumption reduction at the cost of speed performance. Effective if complies 0 < split_batch_size < batch_size. @@ -67,7 +69,7 @@ def input_types(self): """Input types definitions for LatticeLoss. """ return { - "log_probs": NeuralType(("B", "T", "D"), LogprobsType()), + "log_probs": NeuralType(("B", "T", "D") if self._3d_input else ("B", "T", "T", "D"), LogprobsType()), "targets": NeuralType(("B", "T"), LabelsType()), "input_lengths": NeuralType(tuple("B"), LengthsType()), "target_lengths": NeuralType(tuple("B"), LengthsType()), @@ -87,30 +89,40 @@ def __init__( reduction: str = "mean_batch", backend: str = "k2", criterion_type: str = "ml", + loss_type: str = "ctc", split_batch_size: int = 0, graph_module_cfg: Optional[DictConfig] = None, ): super().__init__() self._blank = num_classes self.split_batch_size = split_batch_size + inner_reduction = None if reduction == "mean_batch": - ctc_reduction = "none" + inner_reduction = "none" self._apply_batch_mean = True elif reduction in ["sum", "mean", "none"]: - ctc_reduction = reduction + inner_reduction = reduction self._apply_batch_mean = False # we assume that self._blank + 1 == num_classes if backend == "k2": if criterion_type == "ml": - from nemo.collections.asr.parts.k2.ml_loss import MLLoss as K2Loss + if loss_type == "ctc": + from nemo.collections.asr.parts.k2.ml_loss import CtcLoss as K2Loss + elif loss_type == "rnnt": + from nemo.collections.asr.parts.k2.ml_loss import RnntLoss as K2Loss + else: + raise ValueError(f"Unsupported `loss_type`: {loss_type}.") elif criterion_type == "map": - from nemo.collections.asr.parts.k2.map_loss import MAPLoss as K2Loss + if loss_type == "ctc": + from nemo.collections.asr.parts.k2.map_loss import CtcMmiLoss as K2Loss + else: + raise ValueError(f"Unsupported `loss_type`: {loss_type}.") else: - raise ValueError(f"Invalid value of `criterion_type`: {criterion_type}.") + raise ValueError(f"Unsupported `criterion_type`: {criterion_type}.") self._loss = K2Loss( - num_classes=self._blank + 1, blank=self._blank, reduction=ctc_reduction, cfg=graph_module_cfg, + num_classes=self._blank + 1, blank=self._blank, reduction=inner_reduction, cfg=graph_module_cfg, ) elif backend == "gtn": raise NotImplementedError(f"Backend {backend} is not supported.") @@ -118,6 +130,8 @@ def __init__( raise ValueError(f"Invalid value of `backend`: {backend}.") self.criterion_type = criterion_type + self.loss_type = loss_type + self._3d_input = self.loss_type != "rnnt" if self.split_batch_size > 0: # don't need to guard grad_utils @@ -143,20 +157,21 @@ def forward(self, log_probs, targets, input_lengths, target_lengths): target_lengths = target_lengths.long() targets = targets.long() batch_size = log_probs.shape[0] - if self.split_batch_size > 0 and self.split_batch_size < batch_size: + if self.split_batch_size > 0 and self.split_batch_size <= batch_size: loss_list = [] for batch_idx in range(0, batch_size, self.split_batch_size): begin = batch_idx end = min(begin + self.split_batch_size, batch_size) - log_probs_part = log_probs[begin:end] - targets_part = targets[begin:end] input_lengths_part = input_lengths[begin:end] + log_probs_part = log_probs[begin:end, : input_lengths_part.max()] target_lengths_part = target_lengths[begin:end] + targets_part = 
targets[begin:end, : target_lengths_part.max()] loss_part, _ = ( - self._partial_loss(log_probs_part, targets_part, input_lengths_part, target_lengths_part,) + self._partial_loss(log_probs_part, targets_part, input_lengths_part, target_lengths_part) if log_probs_part.requires_grad - else self._loss(log_probs_part, targets_part, input_lengths_part, target_lengths_part,) + else self._loss(log_probs_part, targets_part, input_lengths_part, target_lengths_part) ) + del log_probs_part, targets_part, input_lengths_part, target_lengths_part loss_list.append(loss_part) loss = torch.cat(loss_list, 0) else: diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index a2af88130a90..0b77d42488a8 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -21,7 +21,12 @@ from nemo.collections.asr.models.enhancement_models import EncMaskDecAudioToAudioModel from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel -from nemo.collections.asr.models.k2_sequence_models import EncDecK2SeqModel, EncDecK2SeqModelBPE +from nemo.collections.asr.models.k2_sequence_models import ( + EncDecK2RnntSeqModel, + EncDecK2RnntSeqModelBPE, + EncDecK2SeqModel, + EncDecK2SeqModelBPE, +) from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel from nemo.collections.asr.models.msdd_models import EncDecDiarLabelModel from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel diff --git a/nemo/collections/asr/models/configs/aligner_config.py b/nemo/collections/asr/models/configs/aligner_config.py index ff2a99208161..06b41b5c115b 100644 --- a/nemo/collections/asr/models/configs/aligner_config.py +++ b/nemo/collections/asr/models/configs/aligner_config.py @@ -25,8 +25,8 @@ class AlignerCTCConfig: @dataclass class AlignerRNNTConfig: - # Arguments will appear with RNNT support - pass + predictor_window_size: int = 0 + predictor_step_size: int = 1 @dataclass diff --git a/nemo/collections/asr/models/configs/k2_sequence_models_config.py b/nemo/collections/asr/models/configs/k2_sequence_models_config.py index b0a0d6b81107..5a112f626f46 100644 --- a/nemo/collections/asr/models/configs/k2_sequence_models_config.py +++ b/nemo/collections/asr/models/configs/k2_sequence_models_config.py @@ -22,6 +22,7 @@ @dataclass class GraphModuleConfig: criterion_type: str = "ml" + loss_type: str = "ctc" split_batch_size: int = 0 dec_type: str = "topo" transcribe_training: bool = True diff --git a/nemo/collections/asr/models/k2_aligner_model.py b/nemo/collections/asr/models/k2_aligner_model.py index 7591a9b51e58..402cf68ff234 100644 --- a/nemo/collections/asr/models/k2_aligner_model.py +++ b/nemo/collections/asr/models/k2_aligner_model.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf, open_dict from nemo.collections.asr.data.audio_to_ctm_dataset import FrameCtmUnit from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs from nemo.collections.asr.models.asr_model import ASRModel +from nemo.utils import logging class AlignerWrapperModel(ASRModel): @@ -37,7 +39,6 @@ def __init__(self, model: ASRModel, cfg: DictConfig): self.alignment_type = cfg.get("alignment_type", "forced") self.word_output = cfg.get("word_output", True) self.cpu_decoding = cfg.get("cpu_decoding", False) - self.blank_id = self._model.decoder.num_classes_with_blank - 1 self.decode_batch_size = cfg.get("decode_batch_size", 0) # list possible alignment types here for future work @@ -85,7 +86,9 @@ def _init_ctc_alignment_specific(self, cfg: DictConfig): if decoder_module_cfg is not None: self.graph_decoder._decoder.intersect_pruned = decoder_module_cfg.get("intersect_pruned") self.graph_decoder._decoder.intersect_conf = decoder_module_cfg.get("intersect_conf") - elif self.alignment_type == "argmax": + return + + if self.alignment_type == "argmax": # we use transcribe_decoder to get topology-independent output if not self._model.use_graph_lm: self._model.transcribe_decoder = ViterbiDecoderWithGraph( @@ -96,6 +99,57 @@ def _init_ctc_alignment_specific(self, cfg: DictConfig): self._model.transcribe_decoder.output_aligned = True self._model.transcribe_decoder.split_batch_size = self.decode_batch_size self._model.use_graph_lm = False + return + + def _init_rnnt_alignment_specific(self, cfg: DictConfig): + """Part of __init__ intended to initialize attributes specific to the alignment type for RNNT models. + + This method is not supposed to be called outside of __init__. 
+ """ + if self.alignment_type == "argmax": + return + + from nemo.collections.asr.modules.graph_decoder import ViterbiDecoderWithGraph + + if self.alignment_type == "forced": + self.predictor_window_size = cfg.rnnt_cfg.get("predictor_window_size", 0) + self.predictor_step_size = cfg.rnnt_cfg.get("predictor_step_size", 0) + + from nemo.collections.asr.parts.k2.utils import apply_rnnt_prune_ranges, get_uniform_rnnt_prune_ranges + + self.prepare_pruned_outputs = lambda encoder_outputs, encoded_len, decoder_outputs, transcript_len: apply_rnnt_prune_ranges( + encoder_outputs, + decoder_outputs, + get_uniform_rnnt_prune_ranges( + encoded_len, + transcript_len, + self.predictor_window_size + 1, + self.predictor_step_size, + encoder_outputs.size(1), + ).to(device=encoder_outputs.device), + ) + + from nemo.collections.asr.parts.k2.classes import GraphModuleConfig + + self.graph_decoder = ViterbiDecoderWithGraph( + num_classes=self.blank_id, + backend="k2", + dec_type="topo_rnnt_ali", + split_batch_size=self.decode_batch_size, + graph_module_cfg=OmegaConf.structured( + GraphModuleConfig( + topo_type="minimal", + predictor_window_size=self.predictor_window_size, + predictor_step_size=self.predictor_step_size, + ) + ), + ) + # override decoder args if a config is provided + decoder_module_cfg = cfg.get("decoder_module_cfg", None) + if decoder_module_cfg is not None: + self.graph_decoder._decoder.intersect_pruned = decoder_module_cfg.get("intersect_pruned") + self.graph_decoder._decoder.intersect_conf = decoder_module_cfg.get("intersect_conf") + return def _init_model_specific(self, cfg: DictConfig): """Part of __init__ intended to initialize attributes specific to the model type. @@ -106,6 +160,8 @@ def _init_model_specific(self, cfg: DictConfig): if isinstance(self._model, EncDecCTCModel): self.model_type = "ctc" + self.blank_id = self._model.decoder.num_classes_with_blank - 1 + self._predict_impl = self._predict_impl_ctc prob_suppress_index = cfg.ctc_cfg.get("prob_suppress_index", -1) prob_suppress_value = cfg.ctc_cfg.get("prob_suppress_value", 1.0) @@ -129,10 +185,47 @@ def _init_model_specific(self, cfg: DictConfig): if isinstance(self._model, EncDecRNNTModel): self.model_type = "rnnt" - raise NotImplementedError("RNNT models are not supported at the moment.") + self.blank_id = self._model.joint.num_classes_with_blank - 1 + self.log_softmax = None if self._model.joint.log_softmax is None else not self._model.joint.log_softmax + self._predict_impl = self._predict_impl_rnnt + + decoding_config = copy.deepcopy(self._model.cfg.decoding) + decoding_config.strategy = "greedy_batch" + with open_dict(decoding_config): + decoding_config.preserve_alignments = True + decoding_config.fused_batch_size = -1 + self._model.change_decoding_strategy(decoding_config) + self._init_rnnt_alignment_specific(cfg) + return raise RuntimeError(f"Unsupported model type: {type(self._model)}") + def _rnnt_joint_pruned( + self, + encoder_outputs: torch.Tensor, + encoded_len: torch.Tensor, + decoder_outputs: torch.Tensor, + transcript_len: torch.Tensor, + ) -> torch.Tensor: + """A variant of the RNNT Joiner tensor calculation with pruned Encoder and Predictor sum. + Only the uniform pruning is supported at the moment. 
+ """ + encoder_outputs = self._model.joint.enc(encoder_outputs.transpose(1, 2)) # (B, T, H) + decoder_outputs = self._model.joint.pred(decoder_outputs.transpose(1, 2)) # (B, U, H) + + encoder_outputs_pruned, decoder_outputs_pruned = self.prepare_pruned_outputs( + encoder_outputs, encoded_len, decoder_outputs, transcript_len + ) + res = self._model.joint.joint_net(encoder_outputs_pruned + decoder_outputs_pruned) + # copied from model.joint.joint(...) + if self._model.joint.log_softmax is None: + if not res.is_cuda: + res = res.log_softmax(dim=-1) + else: + if self._model.joint.log_softmax: + res = res.log_softmax(dim=-1) + return res + def _apply_prob_suppress(self, log_probs: torch.Tensor) -> torch.Tensor: """Multiplies probability of an element with index self.prob_suppress_index by self.prob_suppress_value times with stochasticity preservation of the log_probs tensor. @@ -185,6 +278,49 @@ def _prepare_ctc_argmax_predictions( predictions.append(pred_candidate.to(device=greedy_predictions.device)) return predictions, probs + def _predict_impl_rnnt_argmax( + self, + encoded: torch.Tensor, + encoded_len: torch.Tensor, + transcript: torch.Tensor, + transcript_len: torch.Tensor, + sample_id: torch.Tensor, + ) -> List[Tuple[int, 'FrameCtmUnit']]: + """Builds time alignment of an encoded sequence. + This method assumes that the RNNT model is used and the alignment type is `argmax`. + + It produces a list of sample ids and fours: (label, start_frame, length, probability), called FrameCtmUnit. + """ + hypotheses = self._model.decoding.rnnt_decoder_predictions_tensor( + encoded, encoded_len, return_hypotheses=True + )[0] + results = [] + for s_id, hypothesis in zip(sample_id, hypotheses): + pred_ids = hypothesis.y_sequence.tolist() + tokens = self._model.decoding.decode_ids_to_tokens(pred_ids) + token_begin = hypothesis.timestep + token_len = [j - i for i, j in zip(token_begin, token_begin[1:] + [len(hypothesis.alignments)])] + # we have no token probabilities for the argmax rnnt setup + token_prob = [1.0] * len(tokens) + if self.word_output: + words = [w for w in self._model.decoding.decode_tokens_to_str(pred_ids).split(" ") if w != ""] + words, word_begin, word_len, word_prob = ( + self._process_tokens_to_words(tokens, token_begin, token_len, token_prob, words) + if hasattr(self._model, "tokenizer") + else self._process_char_with_space_to_words(tokens, token_begin, token_len, token_prob, words) + ) + results.append( + (s_id, [FrameCtmUnit(t, b, l, p) for t, b, l, p in zip(words, word_begin, word_len, word_prob)]) + ) + else: + results.append( + ( + s_id, + [FrameCtmUnit(t, b, l, p) for t, b, l, p in zip(tokens, token_begin, token_len, token_prob)], + ) + ) + return results + def _process_tokens_to_words( self, tokens: List[str], @@ -202,12 +338,27 @@ def _process_tokens_to_words( self._model.tokenizer.text_to_tokens(words[0] + " ") ) word_begin, word_len, word_prob = [], [], [] + token_len_nonzero = [(t_l if t_l > 0 else 1) for t_l in token_len] i = 0 for word in words: - j = i + len(self._model.tokenizer.text_to_tokens(word)) + loc_tokens = self._model.tokenizer.text_to_tokens(word) + step = len(loc_tokens) + # we assume that an empty word consists of only one token + # drop current token + if step == 0: + token_begin[i + 1] = token_begin[i] + token_len[i + 1] += token_len[i] + token_len_nonzero[i + 1] += token_len_nonzero[i] + del tokens[i], token_begin[i], token_len[i], token_len_nonzero[i], token_prob[i] + continue + # fix tokenization + if step == 2 and loc_tokens[-1] == "??": + step -= 1 + 
j = i + step word_begin.append(token_begin[i]) word_len.append(sum(token_len[i:j])) - word_prob.append(sum(token_prob[k] * token_len[k] for k in range(i, j)) / word_len[-1]) + denominator = sum(token_len_nonzero[i:j]) + word_prob.append(sum(token_prob[k] * token_len_nonzero[k] for k in range(i, j)) / denominator) i = j return words, word_begin, word_len, word_prob @@ -227,15 +378,18 @@ def _process_char_with_space_to_words( # suppose that there are no whitespaces anywhere except between words space_idx = (np.array(tokens) == " ").nonzero()[0].tolist() assert len(words) == len(space_idx) + 1 + token_len_nonzero = [(t_l if t_l > 0 else 1) for t_l in token_len] if len(space_idx) == 0: word_begin = [token_begin[0]] word_len = [sum(token_len)] - word_prob = [sum(t_p * t_l for t_p, t_l in zip(token_prob, token_len)) / word_len[0]] + denominator = sum(token_len_nonzero) + word_prob = [sum(t_p * t_l for t_p, t_l in zip(token_prob, token_len_nonzero)) / denominator] else: space_word = "[SEP]" word_begin = [token_begin[0]] word_len = [sum(token_len[: space_idx[0]])] - word_prob = [sum(token_prob[k] * token_len[k] for k in range(space_idx[0])) / word_len[-1]] + denominator = sum(token_len_nonzero[: space_idx[0]]) + word_prob = [sum(token_prob[k] * token_len_nonzero[k] for k in range(space_idx[0])) / denominator] words_with_space = [words[0]] for word, i, j in zip(words[1:], space_idx, space_idx[1:] + [len(tokens)]): # append space @@ -246,7 +400,8 @@ def _process_char_with_space_to_words( # append next word word_begin.append(token_begin[i + 1]) word_len.append(sum(token_len[i + 1 : j])) - word_prob.append(sum(token_prob[k] * token_len[k] for k in range(i + 1, j)) / word_len[-1]) + denominator = sum(token_len_nonzero[i + 1 : j]) + word_prob.append(sum(token_prob[k] * token_len_nonzero[k] for k in range(i + 1, j)) / denominator) words_with_space.append(word) words = words_with_space return words, word_begin, word_len, word_prob @@ -262,18 +417,28 @@ def _results_to_ctmUnits( if len(pred) == 0: return (s_id, []) - non_blank_idx = (pred != self.blank_id).nonzero(as_tuple=True)[0].tolist() + non_blank_idx = (pred != self.blank_id).nonzero(as_tuple=True)[0].cpu() pred_ids = pred[non_blank_idx].tolist() - tokens = self._model._wer.decode_ids_to_tokens(pred_ids) prob_list = prob.tolist() - token_begin = non_blank_idx - token_len, token_prob = [], [] - for i, j in zip(token_begin, token_begin[1:] + [len(pred)]): - t_l = j - i - token_len.append(t_l) - token_prob.append(sum(prob_list[i:j]) / (t_l)) + if self.model_type == "rnnt": + wer_module = self._model.decoding + # for rnnt forced alignment we always have num_blanks == num_frames, + # thus len(pred) == num_frames + num_non_blanks + token_begin = non_blank_idx - torch.arange(len(non_blank_idx)) + token_end = torch.cat((token_begin[1:], torch.tensor([len(pred) - len(non_blank_idx)]))) + else: + wer_module = self._model._wer + token_begin = non_blank_idx + token_end = torch.cat((token_begin[1:], torch.tensor([len(pred)]))) + tokens = wer_module.decode_ids_to_tokens(pred_ids) + token_len = (token_end - token_begin).tolist() + token_begin = token_begin.tolist() + token_prob = [ + sum(prob_list[i:j]) / (j - i) + for i, j in zip(non_blank_idx.tolist(), non_blank_idx[1:].tolist() + [len(pred)]) + ] if self.word_output: - words = [w for w in self._model._wer.decode_tokens_to_str(pred_ids).split(" ") if w != ""] + words = wer_module.decode_tokens_to_str(pred_ids).split(" ") words, word_begin, word_len, word_prob = ( self._process_tokens_to_words(tokens, token_begin, 
token_len, token_prob, words) if hasattr(self._model, "tokenizer") @@ -282,21 +447,20 @@ def _results_to_ctmUnits( return s_id, [FrameCtmUnit(t, b, l, p) for t, b, l, p in zip(words, word_begin, word_len, word_prob)] return s_id, [FrameCtmUnit(t, b, l, p) for t, b, l, p in zip(tokens, token_begin, token_len, token_prob)] - @torch.no_grad() - def predict_step(self, batch, batch_idx, dataloader_idx=0) -> List[Tuple[int, 'FrameCtmUnit']]: - signal, signal_len, transcript, transcript_len, sample_id = batch - - if self.model_type == "ctc": - if isinstance(batch, DALIOutputs) and batch.has_processed_signal: - log_probs, encoded_len, _ = self._model.forward( - processed_signal=signal, processed_signal_length=signal_len - ) - else: - log_probs, encoded_len, _ = self._model.forward(input_signal=signal, input_signal_length=signal_len) - elif self.model_type == "rnnt": - raise NotImplementedError("RNNT models are not supported at the moment.") - else: - raise RuntimeError(f"Unsupported model type: {type(self._model)}") + def _predict_impl_ctc( + self, + encoded: torch.Tensor, + encoded_len: torch.Tensor, + transcript: torch.Tensor, + transcript_len: torch.Tensor, + sample_id: torch.Tensor, + ) -> List[Tuple[int, 'FrameCtmUnit']]: + """Builds time alignment of an encoded sequence. + This method assumes that the CTC model is used. + + It produces a list of sample ids and fours: (label, start_frame, length, probability), called FrameCtmUnit. + """ + log_probs = encoded if self.prob_suppress_value != 1.0: log_probs = self._apply_prob_suppress(log_probs) @@ -320,16 +484,122 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0) -> List[Tuple[int, 'F for s_id, pred, prob in zip(sample_id.tolist(), predictions, probs) ] - @torch.no_grad() - def transcribe( + def _predict_impl_rnnt( self, - paths2audio_files: List[str], - batch_size: int = 4, - logprobs: bool = False, - return_hypotheses: bool = False, - num_workers: int = None, - ) -> List[str]: - raise NotImplementedError() + encoded: torch.Tensor, + encoded_len: torch.Tensor, + transcript: torch.Tensor, + transcript_len: torch.Tensor, + sample_id: torch.Tensor, + ) -> List[Tuple[int, 'FrameCtmUnit']]: + """Builds time alignment of an encoded sequence. + This method assumes that the RNNT model is used. + + It produces a list of sample ids and fours: (label, start_frame, length, probability), called FrameCtmUnit. 
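+
+        For example, a single returned item might look like the following
+        (hypothetical values, shown only to illustrate the structure):
+
+            (0, [FrameCtmUnit('hello', 12, 7, 0.93), FrameCtmUnit('world', 19, 9, 0.88)])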
+ """ + if self.alignment_type == "argmax": + return self._predict_impl_rnnt_argmax(encoded, encoded_len, transcript, transcript_len, sample_id) + elif self.alignment_type == "forced": + decoded = self._model.decoder(targets=transcript, target_length=transcript_len)[0] + log_probs = ( + self._rnnt_joint_pruned(encoded, encoded_len, decoded, transcript_len) + if self.predictor_window_size > 0 and self.predictor_window_size < transcript_len.max() + else self._model.joint(encoder_outputs=encoded, decoder_outputs=decoded) + ) + apply_log_softmax = True if self.log_softmax is None and encoded.is_cuda else self.log_softmax + if apply_log_softmax: + log_probs = log_probs.log_softmax(dim=-1) + if self.cpu_decoding: + log_probs, encoded_len, transcript, transcript_len = ( + log_probs.cpu(), + encoded_len.cpu(), + transcript.cpu(), + transcript_len.cpu(), + ) + predictions, probs = self.graph_decoder.align(log_probs, encoded_len, transcript, transcript_len) + return [ + self._results_to_ctmUnits(s_id, pred, prob) + for s_id, pred, prob in zip(sample_id.tolist(), predictions, probs) + ] + else: + raise NotImplementedError() + + @torch.no_grad() + def predict_step(self, batch, batch_idx, dataloader_idx=0) -> List[Tuple[int, 'FrameCtmUnit']]: + signal, signal_len, transcript, transcript_len, sample_id = batch + + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + encoded, encoded_len = self._model.forward(processed_signal=signal, processed_signal_length=signal_len)[:2] + else: + encoded, encoded_len = self._model.forward(input_signal=signal, input_signal_length=signal_len)[:2] + + return self._predict_impl(encoded, encoded_len, transcript, transcript_len, sample_id) + + @torch.no_grad() + def transcribe(self, manifest: List[str], batch_size: int = 4, num_workers: int = None,) -> List['FrameCtmUnit']: + """ + Does alignment. Use this method for debugging and prototyping. + + Args: + + manifest: path to dataset JSON manifest file (in NeMo format). \ + Recommended length per audio file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + num_workers: (int) number of workers for DataLoader + + Returns: + A list of four: (label, start_frame, length, probability), called FrameCtmUnit, \ + in the same order as in the manifest. 
+ """ + hypotheses = [] + # Model's mode and device + mode = self._model.training + device = next(self._model.parameters()).device + dither_value = self._model.preprocessor.featurizer.dither + pad_to_value = self._model.preprocessor.featurizer.pad_to + + if num_workers is None: + num_workers = min(batch_size, os.cpu_count() - 1) + + try: + self._model.preprocessor.featurizer.dither = 0.0 + self._model.preprocessor.featurizer.pad_to = 0 + + # Switch model to evaluation mode + self._model.eval() + # Freeze the encoder and decoder modules + self._model.encoder.freeze() + self._model.decoder.freeze() + if hasattr(self._model, "joint"): + self._model.joint.freeze() + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + + config = { + 'manifest_filepath': manifest, + 'batch_size': batch_size, + 'num_workers': num_workers, + } + temporary_datalayer = self._model._setup_transcribe_dataloader(config) + for test_batch in tqdm(temporary_datalayer, desc="Aligning"): + test_batch[0] = test_batch[0].to(device) + test_batch[1] = test_batch[1].to(device) + hypotheses += [unit for i, unit in self.predict_step(test_batch, 0)] + del test_batch + finally: + # set mode back to its original value + self._model.train(mode=mode) + self._model.preprocessor.featurizer.dither = dither_value + self._model.preprocessor.featurizer.pad_to = pad_to_value + + logging.set_verbosity(logging_level) + if mode is True: + self._model.encoder.unfreeze() + self._model.decoder.unfreeze() + if hasattr(self._model, "joint"): + self._model.joint.unfreeze() + return hypotheses def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): raise RuntimeError("This module cannot be used in training.") diff --git a/nemo/collections/asr/models/k2_sequence_models.py b/nemo/collections/asr/models/k2_sequence_models.py index 95eb6b4208aa..087e9e41b85d 100644 --- a/nemo/collections/asr/models/k2_sequence_models.py +++ b/nemo/collections/asr/models/k2_sequence_models.py @@ -19,6 +19,8 @@ from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel +from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.parts.k2.classes import ASRK2Mixin from nemo.core.classes.common import PretrainedModelInfo, typecheck from nemo.utils import logging @@ -28,6 +30,9 @@ class EncDecK2SeqModel(EncDecCTCModel, ASRK2Mixin): """Encoder decoder models with various lattice losses.""" def __init__(self, cfg: DictConfig, trainer: Trainer = None): + loss_type = cfg.graph_module_cfg.get("loss_type", "ctc") + if loss_type != "ctc" and loss_type != "mmi": + raise ValueError(f"Class {self.__class__.__name__} does not support `loss_type`={loss_type}") super().__init__(cfg=cfg, trainer=trainer) self._init_k2() @@ -108,6 +113,9 @@ class EncDecK2SeqModelBPE(EncDecCTCModelBPE, ASRK2Mixin): """Encoder decoder models with Byte Pair Encoding and various lattice losses.""" def __init__(self, cfg: DictConfig, trainer: Trainer = None): + loss_type = cfg.graph_module_cfg.get("loss_type", "ctc") + if loss_type != "ctc" and loss_type != "mmi": + raise ValueError(f"Class {self.__class__.__name__} does not support `loss_type`={loss_type}") super().__init__(cfg=cfg, trainer=trainer) self._init_k2() @@ -182,3 +190,109 @@ def forward( return self._forward_k2_post_processing( log_probs=log_probs, encoded_length=encoded_len, 
greedy_predictions=greedy_predictions ) + + +class EncDecK2RnntSeqModel(EncDecRNNTModel, ASRK2Mixin): + """Encoder decoder models with various lattice losses.""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + loss_type = cfg.graph_module_cfg.get("loss_type", "rnnt") + criterion_type = cfg.graph_module_cfg.get("criterion_type", "ml") + if loss_type != "rnnt" or criterion_type != "ml": + raise ValueError( + f"""Class {self.__class__.__name__} does not support + `criterion_type`={criterion_type} with `loss_type`={loss_type}""" + ) + super().__init__(cfg=cfg, trainer=trainer) + self._init_k2() + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + pass + + def change_vocabulary(self, new_vocabulary: List[str]): + """ + Changes vocabulary used during CTC decoding process. Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + If new_vocabulary == self.decoder.vocabulary then nothing will be changed. + + Args: + new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ + this is target alphabet. + + Returns: None + + """ + super().change_vocabulary(new_vocabulary) + + if self.use_graph_lm: + self.token_lm = None + logging.warning( + f"""With .change_vocabulary() call for a model with criterion_type=`{self.loss.criterion_type}`, + a new token_lm has to be set manually: call .update_k2_modules(new_cfg) + or update .graph_module_cfg.backend_cfg.token_lm before calling this method.""" + ) + + self.update_k2_modules(self.graph_module_cfg) + + +class EncDecK2RnntSeqModelBPE(EncDecRNNTBPEModel, ASRK2Mixin): + """Encoder decoder models with Byte Pair Encoding and various lattice losses.""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + loss_type = cfg.graph_module_cfg.get("loss_type", "rnnt") + criterion_type = cfg.graph_module_cfg.get("criterion_type", "ml") + if loss_type != "rnnt" or criterion_type != "ml": + raise ValueError( + f"""Class {self.__class__.__name__} does not support + `criterion_type`={criterion_type} with `loss_type`={loss_type}""" + ) + super().__init__(cfg=cfg, trainer=trainer) + self._init_k2() + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + pass + + def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str): + """ + Changes vocabulary of the tokenizer used during CTC decoding process. + Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + Args: + new_tokenizer_dir: Path to the new tokenizer directory. + new_tokenizer_type: Either `bpe` or `wpe`. 
`bpe` is used for SentencePiece tokenizers, + whereas `wpe` is used for `BertTokenizer`. + + Returns: None + + """ + super().change_vocabulary(new_tokenizer_dir, new_tokenizer_type) + + if self.use_graph_lm: + self.token_lm = None + logging.warning( + f"""With .change_vocabulary() call for a model with criterion_type=`{self.loss.criterion_type}`, + a new token_lm has to be set manually: call .update_k2_modules(new_cfg) + or update .graph_module_cfg.backend_cfg.token_lm before calling this method.""" + ) + + self.update_k2_modules(self.graph_module_cfg) diff --git a/nemo/collections/asr/modules/graph_decoder.py b/nemo/collections/asr/modules/graph_decoder.py index 6f36f28cb620..d66dbd734a69 100644 --- a/nemo/collections/asr/modules/graph_decoder.py +++ b/nemo/collections/asr/modules/graph_decoder.py @@ -25,7 +25,7 @@ class ViterbiDecoderWithGraph(NeuralModule): """Viterbi Decoder with WFSA (Weighted Finite State Automaton) graphs. Note: - Requires k2 v1.11 or later to be installed to use this module. + Requires k2 v1.14 or later to be installed to use this module. Decoder can be set up via the config, and optionally be passed keyword arguments as follows. @@ -74,8 +74,8 @@ def input_types(self): """Returns definitions of module input ports. """ return { - "log_probs": NeuralType(("B", "T", "D"), LogprobsType()), - "log_probs_length": NeuralType(tuple("B"), LengthsType()), + "log_probs": NeuralType(("B", "T", "D") if self._3d_input else ("B", "T", "T", "D"), LogprobsType()), + "input_lengths": NeuralType(tuple("B"), LengthsType()), } @property @@ -113,20 +113,23 @@ def __init__( # we assume that self._blank + 1 == num_classes if backend == "k2": if self.dec_type == "topo": - from nemo.collections.asr.parts.k2.graph_decoders import BaseDecoder as Decoder + from nemo.collections.asr.parts.k2.graph_decoders import CtcDecoder as Decoder + elif self.dec_type == "topo_rnnt_ali": + from nemo.collections.asr.parts.k2.graph_decoders import RnntAligner as Decoder elif self.dec_type == "token_lm": from nemo.collections.asr.parts.k2.graph_decoders import TokenLMDecoder as Decoder - elif self.dec_type == "looseali": + elif self.dec_type == "loose_ali": raise NotImplementedError() elif self.dec_type == "tlg": - raise NotImplementedError(f"dec_type {dec_type} is not supported at the moment") + raise NotImplementedError(f"dec_type {self.dec_type} is not supported at the moment") else: - raise ValueError(f"Unsupported dec_type: {dec_type}") + raise ValueError(f"Unsupported dec_type: {self.dec_type}") self._decoder = Decoder(num_classes=self._blank + 1, blank=self._blank, cfg=graph_module_cfg) elif backend == "gtn": raise NotImplementedError("gtn-backed decoding is not implemented") + self._3d_input = self.dec_type != "topo_rnnt" super().__init__() def update_graph(self, graph): @@ -151,17 +154,17 @@ def _forward_impl(self, log_probs, log_probs_length, targets=None, target_length a, b, c, d, return_lattices=False, return_ilabels=False, output_aligned=True ) batch_size = log_probs.shape[0] - if self.split_batch_size > 0 and self.split_batch_size < batch_size: + if self.split_batch_size > 0 and self.split_batch_size <= batch_size: predictions = [] probs = [] for batch_idx in range(0, batch_size, self.split_batch_size): begin = batch_idx end = min(begin + self.split_batch_size, batch_size) - log_probs_part = log_probs[begin:end] log_probs_length_part = log_probs_length[begin:end] + log_probs_part = log_probs[begin:end, : log_probs_length_part.max()] if align: - targets_part = targets[begin:end] 
target_length_part = target_length[begin:end] + targets_part = targets[begin:end, : target_length_part.max()] predictions_part, probs_part = decode_func( log_probs_part, log_probs_length_part, targets_part, target_length_part ) diff --git a/nemo/collections/asr/parts/k2/autograd.py b/nemo/collections/asr/parts/k2/autograd.py deleted file mode 100644 index 1e4474aa67c9..000000000000 --- a/nemo/collections/asr/parts/k2/autograd.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# See ../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script was copied from https://github.com/k2-fsa/k2/blob/master/k2/python/k2/sparse/autograd.py -# with minor changes fixing uncoalesced gradients. - -import torch - - -class _AbsFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, sparse_tensor: torch.Tensor) -> torch.Tensor: - """Compute the `abs` of a sparse tensor. - Args: - sparse_tensor: - A sparse tensor. It has to satisfy:: - assert sparse_tensor.is_coalesced() - Returns: - The absolute value of the sparse tensor. - The `abs` operation is applied element-wise. - """ - assert sparse_tensor.is_sparse - assert sparse_tensor.is_coalesced() - - indices = sparse_tensor.indices().clone() - values = sparse_tensor.values() - size = sparse_tensor.size() - - values_abs = values.abs() - - ans = torch.sparse_coo_tensor( - indices=indices, values=values_abs, size=size, dtype=sparse_tensor.dtype, device=sparse_tensor.device, - ) - - ctx.save_for_backward(sparse_tensor) - return ans - - @staticmethod - def backward(ctx, ans_grad: torch.Tensor) -> torch.Tensor: - (sparse_tensor,) = ctx.saved_tensors - - indices = sparse_tensor.indices().clone() - values = sparse_tensor.values() - size = sparse_tensor.size() - - sparse_tensor_grad_values = ans_grad.coalesce().values() * values.sign() - - sparse_tensor_grad = torch.sparse_coo_tensor( - indices=indices, - values=sparse_tensor_grad_values, - size=size, - dtype=sparse_tensor.dtype, - device=sparse_tensor.device, - ) - - return sparse_tensor_grad - - -def sparse_abs(sparse_tensor: torch.Tensor) -> torch.Tensor: - """Compute the `abs` of a sparse tensor. - It supports autograd. - Args: - sparse_tensor: - A sparse tensor. It has to satisfy:: - assert sparse_tensor.is_coalesced() - Returns: - The absolute value of the sparse tensor. 
- The `abs` operation is applied element-wise. - """ - return _AbsFunction.apply(sparse_tensor) diff --git a/nemo/collections/asr/parts/k2/classes.py b/nemo/collections/asr/parts/k2/classes.py index 59aed0c924c2..bb749e15d4c6 100644 --- a/nemo/collections/asr/parts/k2/classes.py +++ b/nemo/collections/asr/parts/k2/classes.py @@ -41,16 +41,16 @@ class GraphModuleConfig: topo_type: str = "default" topo_with_self_loops: bool = True - graph_type: str = "topo" - loss_type: str = "mmi" token_lm: Optional[Any] = None intersect_pruned: bool = False intersect_conf: GraphIntersectDenseConfig = GraphIntersectDenseConfig() boost_coeff: float = 0.0 + predictor_window_size: int = 0 + predictor_step_size: int = 1 class ASRK2Mixin(ABC): - """ k2 Mixin class that simplifies the construction of various models with k2-based losses. + """k2 Mixin class that simplifies the construction of various models with k2-based losses. It does the following: - Sets up the graph loss and decoder (methods _init_k2 and update_k2_modules). @@ -97,20 +97,28 @@ def update_k2_modules(self, input_cfg: DictConfig): if hasattr(self, "transcribe_decoder"): del self.transcribe_decoder + if hasattr(self, "joint"): + # RNNT + num_classes = self.joint.num_classes_with_blank - 1 + else: + # CTC, MMI, ... + num_classes = self.decoder.num_classes_with_blank - 1 + remove_consecutive = input_cfg.backend_cfg.get("topo_with_self_loops", True) and input_cfg.backend_cfg.get( + "topo_type", "default" + ) not in ["forced_blank", "identity",] + self._wer.remove_consecutive = remove_consecutive + from nemo.collections.asr.losses.lattice_losses import LatticeLoss self.loss = LatticeLoss( - num_classes=self.decoder.num_classes_with_blank - 1, + num_classes=num_classes, reduction=self._cfg.get("ctc_reduction", "mean_batch"), backend="k2", criterion_type=input_cfg.get("criterion_type", "ml"), + loss_type=input_cfg.get("loss_type", "ctc"), split_batch_size=input_cfg.get("split_batch_size", 0), graph_module_cfg=input_cfg.backend_cfg, ) - remove_consecutive = input_cfg.backend_cfg.get("topo_with_self_loops", True) and input_cfg.backend_cfg.get( - "topo_type", "default" - ) not in ["forced_blank", "identity",] - self._wer.remove_consecutive = remove_consecutive criterion_type = self.loss.criterion_type self.use_graph_lm = criterion_type == "map" @@ -126,7 +134,7 @@ def update_k2_modules(self, input_cfg: DictConfig): from nemo.collections.asr.modules.graph_decoder import ViterbiDecoderWithGraph self.transcribe_decoder = ViterbiDecoderWithGraph( - num_classes=self.decoder.num_classes_with_blank - 1, + num_classes=num_classes, backend="k2", dec_type="token_lm", return_type="1best", diff --git a/nemo/collections/asr/parts/k2/grad_utils.py b/nemo/collections/asr/parts/k2/grad_utils.py index b2f153a84a5b..6278fb9c86ca 100644 --- a/nemo/collections/asr/parts/k2/grad_utils.py +++ b/nemo/collections/asr/parts/k2/grad_utils.py @@ -14,14 +14,7 @@ import torch - -def make_non_pad_mask(input_lengths: torch.Tensor, seq_len: int): - batch_size = input_lengths.shape[0] - seq_range = torch.arange(0, seq_len, dtype=torch.int64) - seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, seq_len) - seq_length_expand = seq_range_expand.new(input_lengths.cpu()).unsqueeze(-1) - mask = seq_range_expand < seq_length_expand - return mask +from nemo.collections.asr.parts.k2.utils import make_non_pad_mask class GradExpNormalize(torch.autograd.Function): @@ -34,10 +27,9 @@ def forward( ctx, log_probs: torch.Tensor, input_lengths: torch.Tensor, reduction: str = "mean", ): mask = 
make_non_pad_mask(input_lengths, log_probs.shape[1]) - max_log_prob, _ = log_probs.max(-1) - probs = torch.exp(log_probs - max_log_prob.unsqueeze(-1)) + probs = log_probs.exp() norm_probs = torch.zeros_like(log_probs) - norm_probs[mask] += (probs / probs.sum(-1).unsqueeze(-1))[mask] + norm_probs[mask] += probs[mask] if reduction == "mean": norm_probs /= norm_probs.shape[0] ctx.save_for_backward(norm_probs) @@ -45,7 +37,7 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor): - return grad_output + ctx.saved_tensors[0], None, None + return grad_output - grad_output.sum(-1).unsqueeze(-1) * ctx.saved_tensors[0], None, None class GradInsert(torch.autograd.Function): diff --git a/nemo/collections/asr/parts/k2/graph_compilers.py b/nemo/collections/asr/parts/k2/graph_compilers.py index 736be912acdb..8b82dcf5cc0f 100644 --- a/nemo/collections/asr/parts/k2/graph_compilers.py +++ b/nemo/collections/asr/parts/k2/graph_compilers.py @@ -30,15 +30,14 @@ import torch -from nemo.collections.asr.parts.k2.topologies import build_topo -from nemo.collections.asr.parts.k2.utils import compose_with_self_loops, intersect_with_self_loops +from nemo.collections.asr.parts.k2.utils import add_self_loops, compose_with_self_loops, intersect_with_self_loops from nemo.core.utils.k2_guard import k2 # import k2 from guard module class CtcTopologyCompiler(object): """Default graph compiler. - It applies its topology to the input token sequence to compile the numerator graph. + It applies its topology to the input token sequence to compile the supervision graph. Based on https://github.com/k2-fsa/snowfall/blob/master/snowfall/training/ctc_graph.py """ @@ -46,13 +45,16 @@ class CtcTopologyCompiler(object): def __init__( self, num_classes: int, + blank: int, topo_type: str = "default", topo_with_self_loops: bool = True, device: torch.device = torch.device("cpu"), ): self.topo_type = topo_type self.device = device - self.base_graph = k2.arc_sort(build_topo(topo_type, list(range(num_classes)), topo_with_self_loops)).to( + from nemo.collections.asr.parts.k2.topologies import build_topo + + self.base_graph = k2.arc_sort(build_topo(topo_type, list(range(num_classes)), blank, topo_with_self_loops)).to( self.device ) self.ctc_topo_inv = k2.arc_sort(self.base_graph.invert()) @@ -65,33 +67,33 @@ def to(self, device: torch.device): def compile(self, targets: torch.Tensor, target_lengths: torch.Tensor) -> 'k2.Fsa': token_ids_list = [t[:l].tolist() for t, l in zip(targets, target_lengths)] - # see https://github.com/k2-fsa/k2/issues/835 label_graph = k2.linear_fsa(token_ids_list).to(self.device) label_graph.aux_labels = label_graph.labels.clone() - decoding_graphs = compose_with_self_loops(self.base_graph, label_graph) - decoding_graphs = k2.arc_sort(decoding_graphs).to(self.device) + supervision_graphs = compose_with_self_loops(self.base_graph, label_graph) + supervision_graphs = k2.arc_sort(supervision_graphs).to(self.device) # make sure the gradient is not accumulated - decoding_graphs.requires_grad_(False) - return decoding_graphs + supervision_graphs.requires_grad_(False) + return supervision_graphs class CtcNumGraphCompiler(CtcTopologyCompiler): """Graph compiler with auxiliary graph to compose with the topology. - The numerator graph contains the auxiliary graph information. + The supervision graph contains the auxiliary graph information. 
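    Illustrative sketch of the compile() behaviour shared with CtcTopologyCompiler above; the ids and lengths are placeholders and k2 is assumed to be installed:

        import torch
        from nemo.collections.asr.parts.k2.graph_compilers import CtcTopologyCompiler

        num_classes, blank = 5, 0                             # placeholder vocabulary size and blank id
        compiler = CtcTopologyCompiler(num_classes, blank, topo_type="default", topo_with_self_loops=True)
        targets = torch.tensor([[1, 2, 2, 3], [3, 1, 0, 0]])  # padded token ids, one row per utterance
        target_lengths = torch.tensor([4, 2])
        supervision_graphs = compiler.compile(targets, target_lengths)  # one supervision FSA per utterance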
""" def __init__( self, num_classes: int, + blank: int, topo_type: str = "default", topo_with_self_loops: bool = True, device: torch.device = torch.device("cpu"), aux_graph: Optional['k2.Fsa'] = None, ): - super().__init__(num_classes, topo_type, topo_with_self_loops, device) + super().__init__(num_classes, blank, topo_type, topo_with_self_loops, device) if aux_graph is None: - self.den_graph = k2.create_fsa_vec([self.ctc_topo_inv.invert()]).to(self.device) + self.decoding_graph = k2.create_fsa_vec([self.ctc_topo_inv.invert()]).to(self.device) else: self.base_graph = intersect_with_self_loops(self.ctc_topo_inv, aux_graph).invert_() self.base_graph = k2.arc_sort(self.base_graph).to(self.device) @@ -111,37 +113,79 @@ def compile( class MmiGraphCompiler(CtcNumGraphCompiler): """Graph compiler for MMI loss. - The denominator graph is a composition of the auxiliary graph and the topology. - It is returned along with the numerator graph on every compile() call. + The decoding graph is a composition of the auxiliary graph and the topology. + It is returned along with the supervision graph on every compile() call. """ def __init__( self, num_classes: int, + blank: int, topo_type: str = "default", topo_with_self_loops: bool = True, device: torch.device = torch.device("cpu"), aux_graph: Optional['k2.Fsa'] = None, ): - super().__init__(num_classes, topo_type, topo_with_self_loops, device, aux_graph) + super().__init__(num_classes, blank, topo_type, topo_with_self_loops, device, aux_graph) if aux_graph is None: - self.den_graph = k2.create_fsa_vec([self.ctc_topo_inv.invert()]).to(self.device) + self.decoding_graph = k2.create_fsa_vec([self.ctc_topo_inv.invert()]).to(self.device) else: - self.den_graph = k2.create_fsa_vec([self.base_graph.detach()]).to(self.device) + self.decoding_graph = k2.create_fsa_vec([self.base_graph.detach()]).to(self.device) def to(self, device: torch.device): - if self.den_graph is not None: - self.den_graph = self.den_graph.to(device) + if self.decoding_graph is not None: + self.decoding_graph = self.decoding_graph.to(device) super().to(device) def compile( self, targets: torch.Tensor, target_lengths: torch.Tensor, aux_graph: Optional['k2.Fsa'] = None, ) -> Tuple['k2.Fsa', 'k2.Fsa']: - num_graphs = super().compile(targets, target_lengths, aux_graph) - if aux_graph is None and self.den_graph is None: + supervision_graphs = super().compile(targets, target_lengths, aux_graph) + if aux_graph is None and self.decoding_graph is None: raise ValueError( - f"At least one of aux_graph and self.den_graph must be set: {aux_graph}, {self.den_graph}" + f"At least one of aux_graph and self.decoding_graph must be set: {aux_graph}, {self.decoding_graph}" ) elif aux_graph is not None: - self.den_graph = k2.create_fsa_vec([self.base_graph.detach()]).to(self.device) - return num_graphs, self.den_graph + self.decoding_graph = k2.create_fsa_vec([self.base_graph.detach()]).to(self.device) + return supervision_graphs, self.decoding_graph + + +class RnntTopologyCompiler(CtcTopologyCompiler): + """Default graph compiler for RNNT loss. + Each supervision graph is composed with the corresponding RNNT emission adapter. + + If max_adapter_length is provided, the maximum adapter length is limited. + + Note: + The actual number of classes is `num_classes` + 1 with as the class 0. + + Warning: + It is currently not recommended to use topologies other than "minimal". 
+ """ + + def __init__( + self, + num_classes: int, + blank: int, + topo_type: str = "minimal", + topo_with_self_loops: bool = True, + device: torch.device = torch.device("cpu"), + max_adapter_length: int = 0, + ): + if topo_type == "compact": + raise NotImplementedError(f"This compiler does not support topo_type==`compact`.") + super().__init__(num_classes, blank, topo_type, topo_with_self_loops, device) + from nemo.collections.asr.parts.k2.topologies import RnntEmissionAdapterBuilder + + self.max_adapter_length = max_adapter_length + self._builder = RnntEmissionAdapterBuilder(list(range(num_classes)), blank, num_classes) + + def compile(self, targets: torch.Tensor, target_lengths: torch.Tensor) -> 'k2.Fsa': + supervision_graphs = add_self_loops(super().compile(targets, target_lengths), self._builder.eps_num, "input") + + adapters = self._builder( + torch.where(target_lengths > self.max_adapter_length, self.max_adapter_length, target_lengths) + if self.max_adapter_length > 0 and self.max_adapter_length < target_lengths.max() + else target_lengths + ).to(device=self.device) + return k2.intersect(adapters, supervision_graphs, treat_epsilons_specially=False) diff --git a/nemo/collections/asr/parts/k2/graph_decoders.py b/nemo/collections/asr/parts/k2/graph_decoders.py index d43a0b4e2a1e..33218588b79f 100644 --- a/nemo/collections/asr/parts/k2/graph_decoders.py +++ b/nemo/collections/asr/parts/k2/graph_decoders.py @@ -12,34 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import abstractmethod from typing import List, Optional, Tuple, Union import torch from omegaconf import DictConfig from nemo.collections.asr.parts.k2.classes import GraphIntersectDenseConfig -from nemo.collections.asr.parts.k2.graph_compilers import CtcNumGraphCompiler, CtcTopologyCompiler -from nemo.collections.asr.parts.k2.utils import ( - create_supervision, - get_arc_weights, - invert_permutation, - load_graph, - make_blank_first, - prep_padded_densefsavec, - shift_labels_inpl, -) -from nemo.core.utils.k2_guard import k2 # import k2 from guard module +from nemo.collections.asr.parts.k2.loss_mixins import CtcK2Mixin, RnntK2Mixin +from nemo.collections.asr.parts.k2.utils import invert_permutation, load_graph from nemo.utils import logging class BaseDecoder(object): """Base graph decoder with topology for decoding graph. Typically uses the same parameters as for the corresponding loss function. - + + Can do decoding and forced alignment. + cfg takes precedence over all optional parameters We keep explicit parameter setting to be able to create an instance without the need of a config. 
""" + @abstractmethod def __init__( self, num_classes: int, @@ -66,12 +61,8 @@ def __init__( self.topo_with_self_loops = topo_with_self_loops self.pad_fsavec = self.topo_type == "ctc_compact" self.intersect_conf = intersect_conf - if not hasattr(self, "graph_compiler") or self.graph_compiler is None: - self.graph_compiler = CtcTopologyCompiler( - self.num_classes, self.topo_type, self.topo_with_self_loops, self.device - ) - if not hasattr(self, "base_graph") or self.base_graph is None: - self.base_graph = k2.create_fsa_vec([self.graph_compiler.ctc_topo_inv.invert()]).to(self.device) + self.graph_compiler = None # expected to be initialized in child classes + self.base_graph = None # expected to be initialized in child classes self.decoding_graph = None def to(self, device: torch.device): @@ -86,10 +77,10 @@ def to(self, device: torch.device): def update_graph(self, graph: 'k2.Fsa'): raise NotImplementedError - def decode( + def _decode_impl( self, log_probs: torch.Tensor, - log_probs_length: torch.Tensor, + supervisions: torch.Tensor, return_lattices: bool = False, return_ilabels: bool = False, output_aligned: bool = True, @@ -97,43 +88,30 @@ def decode( if self.decoding_graph is None: self.decoding_graph = self.base_graph - if self.blank != 0: - # rearrange log_probs to put blank at the first place - # and shift targets to emulate blank = 0 - log_probs, _ = make_blank_first(self.blank, log_probs, None) - supervisions, order = create_supervision(log_probs_length) - if self.decoding_graph.shape[0] > 1: - self.decoding_graph = k2.index_fsa(self.decoding_graph, order).to(device=log_probs.device) - if log_probs.device != self.device: self.to(log_probs.device) - dense_fsa_vec = ( - prep_padded_densefsavec(log_probs, supervisions) - if self.pad_fsavec - else k2.DenseFsaVec(log_probs, supervisions) - ) + emissions_graphs = self._prepare_emissions_graphs(log_probs, supervisions) if self.intersect_pruned: lats = k2.intersect_dense_pruned( a_fsas=self.decoding_graph, - b_fsas=dense_fsa_vec, + b_fsas=emissions_graphs, search_beam=self.intersect_conf.search_beam, output_beam=self.intersect_conf.output_beam, min_active_states=self.intersect_conf.min_active_states, max_active_states=self.intersect_conf.max_active_states, ) else: - indices = torch.zeros(dense_fsa_vec.dim0(), dtype=torch.int32, device=self.device) + indices = torch.zeros(emissions_graphs.dim0(), dtype=torch.int32, device=self.device) dec_graphs = ( k2.index_fsa(self.decoding_graph, indices) if self.decoding_graph.shape[0] == 1 else self.decoding_graph ) - lats = k2.intersect_dense(dec_graphs, dense_fsa_vec, self.intersect_conf.output_beam) - if self.pad_fsavec: - shift_labels_inpl([lats], -1) + lats = k2.intersect_dense(dec_graphs, emissions_graphs, self.intersect_conf.output_beam) self.decoding_graph = None + order = supervisions[:, 0] if return_lattices: lats = k2.index_fsa(lats, invert_permutation(order).to(device=log_probs.device)) if self.blank != 0: @@ -145,24 +123,24 @@ def decode( shortest_path_fsas = k2.index_fsa( k2.shortest_path(lats, True), invert_permutation(order).to(device=log_probs.device), ) - shortest_paths = [] - probs = [] - # direct iterating does not work as expected - for i in range(shortest_path_fsas.shape[0]): - shortest_path_fsa = shortest_path_fsas[i] - labels = ( - shortest_path_fsa.labels[:-1].to(dtype=torch.long) - if return_ilabels - else shortest_path_fsa.aux_labels[:-1].to(dtype=torch.long) - ) - if self.blank != 0: - # suppose self.blank == self.num_classes - 1 - labels = torch.where(labels == 0, 
self.blank, labels - 1) - if not return_ilabels and not output_aligned: - labels = labels[labels != self.blank] - shortest_paths.append(labels[::2] if self.pad_fsavec else labels) - probs.append(get_arc_weights(shortest_path_fsa)[:-1].to(device=log_probs.device).exp()) - return shortest_paths, probs + return self._extract_labels_and_probabilities(shortest_path_fsas, return_ilabels, output_aligned) + + def decode( + self, + log_probs: torch.Tensor, + log_probs_length: torch.Tensor, + return_lattices: bool = False, + return_ilabels: bool = False, + output_aligned: bool = True, + ) -> Union['k2.Fsa', Tuple[List[torch.Tensor], List[torch.Tensor]]]: + log_probs, supervisions, _, _ = self._prepare_log_probs_and_targets(log_probs, log_probs_length, None, None) + return self._decode_impl( + log_probs, + supervisions, + return_lattices=return_lattices, + return_ilabels=return_ilabels, + output_aligned=output_aligned, + ) def align( self, @@ -174,14 +152,125 @@ def align( return_ilabels: bool = False, output_aligned: bool = True, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - if self.blank != 0: - targets = targets + 1 - if self.pad_fsavec: - targets = targets + 1 - self.decoding_graph = self.graph_compiler.compile(targets, target_lengths) - return self.decode( + log_probs, supervisions, targets, target_lengths = self._prepare_log_probs_and_targets( + log_probs, log_probs_length, targets, target_lengths + ) + order = supervisions[:, 0].to(dtype=torch.long) + self.decoding_graph = self.graph_compiler.compile(targets[order], target_lengths[order]) + return self._decode_impl( + log_probs, + supervisions, + return_lattices=return_lattices, + return_ilabels=return_ilabels, + output_aligned=output_aligned, + ) + + +class CtcDecoder(BaseDecoder, CtcK2Mixin): + """Regular CTC graph decoder with custom topologies. + Available topologies: + - `default`, with or without self-loops + - `compact`, with or without self-loops + - `shared_blank`, with or without self-loops + - `minimal`, without self-loops + + Can do decoding and forced alignment. + """ + + def __init__( + self, + num_classes: int, + blank: int, + cfg: Optional[DictConfig] = None, + intersect_pruned: bool = False, + intersect_conf: GraphIntersectDenseConfig = GraphIntersectDenseConfig(), + topo_type: str = "default", + topo_with_self_loops: bool = True, + device: torch.device = torch.device("cpu"), + ): + super().__init__( + num_classes, blank, cfg, intersect_pruned, intersect_conf, topo_type, topo_with_self_loops, device + ) + from nemo.collections.asr.parts.k2.graph_compilers import CtcTopologyCompiler + + self.graph_compiler = CtcTopologyCompiler( + self.num_classes, self.blank, self.topo_type, self.topo_with_self_loops, self.device + ) + self.base_graph = k2.create_fsa_vec([self.graph_compiler.ctc_topo_inv.invert()]).to(self.device) + + +class RnntAligner(BaseDecoder, RnntK2Mixin): + """RNNT graph decoder with the `minimal` topology. + If predictor_window_size is not provided, this decoder works as a Viterbi over regular RNNT lattice. + With predictor_window_size provided, it applies uniform pruning when compiling Emission FSAs + to reduce memory and compute consumption. + + Can only do forced alignment. 
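    Hypothetical forced-alignment call; shapes and ids are placeholders, k2 is assumed to be installed, and log_probs is taken to be the (B, T, U+1, D) joint output with U + 1 <= predictor_window_size + 1 when pruning is used:

        import torch
        from nemo.collections.asr.parts.k2.graph_decoders import RnntAligner

        num_classes, blank = 129, 128                  # blank assumed to be the last id
        aligner = RnntAligner(num_classes=num_classes, blank=blank, topo_type="minimal", predictor_window_size=10)
        B, T, U_plus_1 = 2, 50, 11                     # U + 1 == predictor_window_size + 1
        log_probs = torch.randn(B, T, U_plus_1, num_classes).log_softmax(-1)
        log_probs_length = torch.tensor([50, 42])
        targets = torch.randint(0, blank, (B, 20))
        target_lengths = torch.tensor([20, 15])
        labels, probs = aligner.align(log_probs, log_probs_length, targets, target_lengths)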
+ """ + + def __init__( + self, + num_classes: int, + blank: int, + cfg: Optional[DictConfig] = None, + intersect_pruned: bool = False, + intersect_conf: GraphIntersectDenseConfig = GraphIntersectDenseConfig(), + topo_type: str = "default", + topo_with_self_loops: bool = True, + predictor_window_size: int = 0, + predictor_step_size: int = 1, + device: torch.device = torch.device("cpu"), + ): + if cfg is not None: + topo_type = cfg.get("topo_type", topo_type) + predictor_window_size = cfg.get("predictor_window_size", predictor_window_size) + predictor_step_size = cfg.get("predictor_step_size", predictor_step_size) + if topo_type != "minimal": + raise NotImplementedError(f"Only topo_type=`minimal` is supported at the moment.") + super().__init__( + num_classes, blank, cfg, intersect_pruned, intersect_conf, topo_type, topo_with_self_loops, device + ) + self.predictor_window_size = predictor_window_size + self.predictor_step_size = predictor_step_size + from nemo.collections.asr.parts.k2.graph_compilers import RnntTopologyCompiler + + self.graph_compiler = RnntTopologyCompiler( + self.num_classes, + self.blank, + self.topo_type, + self.topo_with_self_loops, + self.device, + max_adapter_length=self.predictor_window_size, + ) + self.base_graph = self.graph_compiler.base_graph + + def decode( + self, + log_probs: torch.Tensor, + log_probs_length: torch.Tensor, + return_lattices: bool = False, + return_ilabels: bool = False, + output_aligned: bool = True, + ) -> Union['k2.Fsa', Tuple[List[torch.Tensor], List[torch.Tensor]]]: + raise NotImplementedError("RNNT decoding is not implemented. Only .align(...) method is supported.") + + def align( + self, + log_probs: torch.Tensor, + log_probs_length: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + return_lattices: bool = False, + return_ilabels: bool = False, + output_aligned: bool = True, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + assert self.predictor_window_size == 0 or log_probs.size(2) <= self.predictor_window_size + 1 + + return super().align( log_probs, log_probs_length, + targets, + target_lengths, return_lattices=return_lattices, return_ilabels=return_ilabels, output_aligned=output_aligned, @@ -190,7 +279,14 @@ def align( class TokenLMDecoder(BaseDecoder): """Graph decoder with token_lm-based decoding graph. - + Available topologies: + - `default`, with or without self-loops + - `compact`, with or without self-loops + - `shared_blank`, with or without self-loops + - `minimal`, without self-loops + + Can do decoding and forced alignment. + cfg takes precedence over all optional parameters We keep explicit parameter setting to be able to create an instance without the need of a config. 
""" @@ -236,9 +332,7 @@ def update_graph(self, graph: 'k2.Fsa'): labels = token_lm.labels if labels.max() != self.num_classes - 1: raise ValueError(f"token_lm is not compatible with the num_classes: {labels.unique()}, {self.num_classes}") - if self.pad_fsavec: - shift_labels_inpl([token_lm], 1) self.graph_compiler = CtcNumGraphCompiler( - self.num_classes, self.topo_type, self.topo_with_self_loops, self.device, token_lm + self.num_classes, self.blank, self.topo_type, self.topo_with_self_loops, self.device, token_lm ) self.base_graph = k2.create_fsa_vec([self.graph_compiler.base_graph]).to(self.device) diff --git a/nemo/collections/asr/parts/k2/loss_mixins.py b/nemo/collections/asr/parts/k2/loss_mixins.py new file mode 100644 index 000000000000..ad8286e43e23 --- /dev/null +++ b/nemo/collections/asr/parts/k2/loss_mixins.py @@ -0,0 +1,233 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC +from typing import List, Optional, Tuple + +import torch + +from nemo.collections.asr.parts.k2.grad_utils import GradExpNormalize +from nemo.collections.asr.parts.k2.utils import ( + create_supervision, + get_arc_weights, + get_uniform_rnnt_prune_ranges, + make_non_pad_mask, + make_non_pad_mask_3d, + prep_padded_densefsavec, +) +from nemo.core.utils.k2_guard import k2 # import k2 from guard module + + +class CtcK2Mixin(ABC): + """k2 Mixin class that simplifies the construction of various k2-based CTC-like losses. + + It does the following: + - Prepares and adapts the input tensors (method _prepare_log_probs_and_targets). + - Creates Emissions graphs (method _prepare_emissions_graphs). + - Extracts the labels and probabilities of the best lattice path (method _extract_labels_and_probabilities). + """ + + def _prepare_log_probs_and_targets( + self, + log_probs: torch.Tensor, + input_lengths: torch.Tensor, + targets: Optional[torch.Tensor] = None, + target_lengths: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """Creates k2-style supervisions and shifts targets by one if the number is not zero. + """ + assert log_probs.size(-1) == self.num_classes + supervisions = create_supervision(input_lengths) + # shift targets to make output epsilon ID zero + return ( + log_probs, + supervisions, + torch.where(targets < self.blank, targets + 1, targets) if targets is not None else None, + target_lengths, + ) + + def _prepare_emissions_graphs(self, log_probs: torch.Tensor, supervisions: torch.Tensor) -> 'k2.DenseFsaVec': + """Creates DenseFsaVec, padding it with frames if the topology is `compact`. + In particular, every second frame of the DenseFsaVec is the frame. + + frame is a frame with log-probability zero and every other log-probability is -inf. 
+ """ + return ( + prep_padded_densefsavec(log_probs, supervisions) + if self.pad_fsavec + else k2.DenseFsaVec(log_probs, supervisions) + ) + + def _maybe_normalize_gradients(self, log_probs: torch.Tensor, input_lengths: torch.Tensor) -> torch.Tensor: + """PyTorch is doing the log-softmax normalization as part of the CTC computation. + More: https://github.com/k2-fsa/k2/issues/575 + """ + return GradExpNormalize.apply(log_probs, input_lengths, "mean" if self.reduction != "sum" else "none") + + def _extract_labels_and_probabilities( + self, shortest_path_fsas: 'k2.Fsa', return_ilabels: bool = False, output_aligned: bool = True + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Extracts the labels and probabilities of the best lattice path, + dropping arcs and restoring the targets shift, if needed. + """ + shortest_paths = [] + probs = [] + # direct iterating does not work as expected + for i in range(shortest_path_fsas.shape[0]): + shortest_path_fsa = shortest_path_fsas[i] + # suppose that artificial input epsilon numbers >= self.num_classes + non_eps_mask = (shortest_path_fsa.labels != -1) & (shortest_path_fsa.labels < self.num_classes) + if return_ilabels: + labels = shortest_path_fsa.labels[non_eps_mask] + else: + labels = shortest_path_fsa.aux_labels[non_eps_mask] + if self.blank != 0: + # suppose output epsilon number == 0 + # since the input epsilons were removed, we treat all remaining epsilons as blanks + labels[labels == 0] = self.blank + labels[(labels > 0) & (labels < self.blank)] -= 1 + labels = labels.to(dtype=torch.long) + if not return_ilabels and not output_aligned: + labels = labels[labels != self.blank] + shortest_paths.append(labels) + probs.append(get_arc_weights(shortest_path_fsa)[non_eps_mask].exp().to(device=shortest_path_fsas.device)) + return shortest_paths, probs + + +class RnntK2Mixin(CtcK2Mixin): + """k2 Mixin class that simplifies the construction of various k2-based RNNT-like losses. Inherits CtcK2Mixin. + + It does the following: + - Prepares and adapts the input tensors. + - Creates Emissions graphs. + - Extracts the labels and probabilities of the best lattice path (method _extract_labels_and_probabilities). + """ + + def _prepare_log_probs_and_targets( + self, + log_probs: torch.Tensor, + input_lengths: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Before calling super()._prepare_log_probs_and_targets, this method reshapes the log_probs tensor + from (B, T, U+1, D) to (B, T', D) where T' = T*(U+1), shifts paddings along T and U towards the end of T', + and recomputes input_lengths. + + It also calculates indices on which steps should be applied to the log_probs tensor to emulate + arcs shift of the Emissions graph for the pruned RNNT variant. 
+ """ + assert len(log_probs.size()) == 4 # B T U D + B, T, U, D = log_probs.size() + TU = T * U + + # save step indices if, as we assume, decoder output pruning has been applied + if self.predictor_window_size > 0 and self.predictor_window_size < target_lengths.max(): + window_size_with_blank = self.predictor_window_size + 1 + ranges_begin = get_uniform_rnnt_prune_ranges( + input_lengths, target_lengths, window_size_with_blank, self.predictor_step_size, T, True + ) + step_sizes = ranges_begin[:, 1:] - ranges_begin[:, :-1] + raw_step_indices = torch.where(step_sizes > 0) + if self.predictor_step_size > 1: + raw_step_indices = torch.repeat_interleave( + torch.stack(raw_step_indices).T, step_sizes[raw_step_indices], dim=0 + ).T + raw_step_indices = (raw_step_indices[0], raw_step_indices[1]) + unique, count = torch.unique(raw_step_indices[0], return_counts=True) + shift_mask = raw_step_indices[0].unsqueeze(0).repeat(len(unique), 1) == unique.unsqueeze(-1) + step_indices = ( + raw_step_indices[0], + ( + torch.arange(ranges_begin.size(1)).unsqueeze(0).repeat(ranges_begin.size(0), 1) + * window_size_with_blank + )[(raw_step_indices[0], raw_step_indices[1] + 1)] + + torch.cumsum(shift_mask, 1)[shift_mask] + - 1, + ) + max_count = count.max() + max_count_vec = torch.full((B,), max_count) + max_count_vec[unique] -= count + pad_indices_row = torch.repeat_interleave(torch.arange(B), max_count_vec) + pad_unique = torch.unique(pad_indices_row) + pad_shift_mask = pad_indices_row.unsqueeze(0).repeat(len(pad_unique), 1) == pad_unique.unsqueeze(-1) + pad_indices = ( + pad_indices_row, + T * window_size_with_blank + max_count - torch.cumsum(pad_shift_mask, 1)[pad_shift_mask], + ) + self.__step_indices = ( + torch.cat((step_indices[0], pad_indices[0])), + torch.cat((step_indices[1], pad_indices[1])), + ) + self.__supervisions_add = max_count - max_count_vec + else: + self.__step_indices = None + self.__supervisions_add = None + + # reshape 4D log_probs to 3D with respect to target_lengths + non_pad_mask_true = make_non_pad_mask_3d(input_lengths, target_lengths + 1, T, U).flatten(1) + input_lengths = non_pad_mask_true.sum(1) + non_pad_mask_fake = make_non_pad_mask(input_lengths, TU).flatten() + non_pad_mask_true = non_pad_mask_true.flatten() + rearranged_indices = torch.arange(TU * B, device=log_probs.device) + rearranged_indices_buffer = rearranged_indices.clone() + rearranged_indices[non_pad_mask_fake] = rearranged_indices_buffer[non_pad_mask_true] + rearranged_indices[~non_pad_mask_fake] = rearranged_indices_buffer[~non_pad_mask_true] + log_probs = log_probs.reshape(-1, D)[rearranged_indices].view(B, -1, D) + + return super()._prepare_log_probs_and_targets(log_probs, input_lengths, targets, target_lengths) + + def _prepare_emissions_graphs(self, log_probs: torch.Tensor, supervisions: torch.Tensor) -> 'k2.DenseFsaVec': + """Overrides super()._prepare_emissions_graphs. + Creates DenseFsaVec, adding outputs to the end of the D dimension. + + If pruning is used, this method also pads the DenseFsaVec with frames + according to the steps, calculated before. + + frame is a frame with log-probability zero and every other log-probability is -inf. 
+ """ + if self.__step_indices is None or self.__supervisions_add is None: + log_probs_eps = torch.cat( + (log_probs, torch.zeros((log_probs.size(0), log_probs.size(1), 1), device=log_probs.device)), dim=2 + ) + else: + mask = torch.zeros( + (log_probs.size(0), log_probs.size(1) + int(len(self.__step_indices[0]) / log_probs.size(0))), + dtype=torch.bool, + ) + mask[self.__step_indices] = True + log_probs_eps = torch.zeros((mask.size(0), mask.size(1), log_probs.size(2) + 1), device=log_probs.device) + log_probs_eps[mask] = torch.tensor( + [torch.finfo(torch.float32).min] * log_probs.size(2) + [0], device=log_probs.device + ) + log_probs_eps[~mask] = torch.cat( + (log_probs, torch.zeros((log_probs.size(0), log_probs.size(1), 1), device=log_probs.device)), dim=2 + ).view(-1, log_probs.size(-1) + 1) + input_lengths = supervisions[:, -1] + self.__supervisions_add[supervisions[:, 0].to(dtype=torch.long)] + if not torch.all(input_lengths[:-1] - input_lengths[1:] >= 0): + # have to reorder supervisions inplace + order = torch.argsort(input_lengths, descending=True) + # the second column is assumed to be zero + supervisions[:, 0] = supervisions[order, 0] + supervisions[:, -1] = input_lengths[order] + else: + supervisions[:, -1] = input_lengths + self.__step_indices = None + self.__supervisions_add = None + return k2.DenseFsaVec(log_probs_eps, supervisions) + + def _maybe_normalize_gradients(self, log_probs: torch.Tensor, input_lengths: torch.Tensor) -> torch.Tensor: + """Not required for RNNT. + """ + return log_probs diff --git a/nemo/collections/asr/parts/k2/map_loss.py b/nemo/collections/asr/parts/k2/map_loss.py index ec9b904344f2..c261a4f2ef6b 100644 --- a/nemo/collections/asr/parts/k2/map_loss.py +++ b/nemo/collections/asr/parts/k2/map_loss.py @@ -26,31 +26,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from abc import abstractmethod +from typing import Any, Optional, Tuple, Union import torch from omegaconf import DictConfig -from nemo.collections.asr.parts.k2.autograd import sparse_abs from nemo.collections.asr.parts.k2.classes import GraphIntersectDenseConfig +from nemo.collections.asr.parts.k2.loss_mixins import CtcK2Mixin +from nemo.collections.asr.parts.k2.ml_loss import MLLoss from nemo.collections.asr.parts.k2.utils import ( create_sparse_wrapped, - create_supervision, get_tot_objf_and_finite_mask, + invert_permutation, load_graph, - make_blank_first, - prep_padded_densefsavec, - shift_labels_inpl, ) from nemo.core.utils.k2_guard import k2 # import k2 from guard module from nemo.utils import logging -class MAPLoss(torch.nn.Module): +class MAPLoss(MLLoss): """ Maximum a Posteriori Probability criterion. - It is implemented as Lattice-Free Maximum Mutual Information (LF-MMI) - and LF-boosted-MMI (LF-bMMI) losses. + It implements Lattice-Free Maximum Mutual Information (LF-MMI) and LF-boosted-MMI (LF-bMMI) losses. Based on https://github.com/k2-fsa/snowfall/blob/master/snowfall/objectives/mmi.py @@ -58,6 +56,7 @@ class MAPLoss(torch.nn.Module): We keep explicit parameter setting to be able to create an instance without the need of a config. 
""" + @abstractmethod def __init__( self, num_classes: int, @@ -66,34 +65,30 @@ def __init__( cfg: Optional[DictConfig] = None, topo_type: str = "default", topo_with_self_loops: bool = True, - loss_type: str = "mmi", token_lm: Optional[Union['k2.Fsa', str]] = None, intersect_pruned: bool = False, intersect_conf: GraphIntersectDenseConfig = GraphIntersectDenseConfig(), boost_coeff: float = 0.0, ): - super().__init__() + super().__init__( + num_classes=num_classes, + blank=blank, + reduction=reduction, + cfg=cfg, + topo_type=topo_type, + topo_with_self_loops=topo_with_self_loops, + ) if cfg is not None: - topo_type = cfg.get("topo_type", topo_type) - topo_with_self_loops = cfg.get("topo_with_self_loops", topo_with_self_loops) - loss_type = cfg.get("loss_type", loss_type) token_lm = cfg.get("token_lm", token_lm) intersect_pruned = cfg.get("intersect_pruned", intersect_pruned) intersect_conf = cfg.get("intersect_conf", intersect_conf) boost_coeff = cfg.get("boost_coeff", boost_coeff) - self.num_classes = num_classes - self.blank = blank - self.reduction = reduction - self.loss_type = loss_type self.boost_coeff = boost_coeff - self.intersect_calc_scores = ( - self._intersect_calc_scores_mmi_pruned if intersect_pruned else self._intersect_calc_scores_mmi_exact + self._intersect_calc_scores_impl = ( + self._intersect_calc_scores_impl_pruned if intersect_pruned else self._intersect_calc_scores_impl_exact_opt ) self.intersect_conf = intersect_conf - self.topo_type = topo_type - self.topo_with_self_loops = topo_with_self_loops - self.pad_fsavec = topo_type == "compact" - self.graph_compiler = None + self.graph_compiler = None # expected to be initialized in .update_graph(...) if token_lm is None: logging.warning( f"""token_lm is empty. @@ -107,25 +102,19 @@ def __init__( else: self.update_graph(self.lm_graph) + @abstractmethod def update_graph(self, graph: 'k2.Fsa'): - self.lm_graph = graph - lm_graph = self.lm_graph.clone() - if hasattr(lm_graph, "aux_labels"): - delattr(lm_graph, "aux_labels") - labels = lm_graph.labels - if labels.max() != self.num_classes - 1: - raise ValueError(f"lm_graph is not compatible with the num_classes: {labels.unique()}, {self.num_classes}") - if self.pad_fsavec: - shift_labels_inpl([lm_graph], 1) - if self.loss_type == "mmi": - from nemo.collections.asr.parts.k2.graph_compilers import MmiGraphCompiler as compiler - else: - raise ValueError(f"Invalid value of `loss_type`: {self.loss_type}.") - self.graph_compiler = compiler(self.num_classes, self.topo_type, self.topo_with_self_loops, aux_graph=lm_graph) - - def _intersect_calc_scores_mmi_exact( - self, dense_fsa_vec: k2.DenseFsaVec, num_graphs: 'k2.Fsa', den_graph: 'k2.Fsa', return_lats: bool = True, - ): + # expected to be set in child classes + raise NotImplementedError + + def _intersect_calc_scores_impl_exact_opt( + self, dense_fsa_vec: 'k2.DenseFsaVec', num_graphs: 'k2.Fsa', den_graph: 'k2.Fsa', return_lats: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional['k2.Fsa'], Optional['k2.Fsa']]: + """Inner intersection method. + Does joint (simultaneous) exact intersection of dense_fsa_vec against num_graphs and den_graph. + + Optiolally returns the numerator and the denominator lattices. 
+ """ device = dense_fsa_vec.device assert device == num_graphs.device and device == den_graph.device @@ -165,7 +154,7 @@ def _intersect_calc_scores_mmi_exact( seqframe_idx_name="seqframe_idx" if return_lats else None, ) - num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True, use_double_scores=True) + num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True, use_double_scores=False) num_tot_scores = num_den_tot_scores[::2] den_tot_scores = num_den_tot_scores[1::2] @@ -180,9 +169,14 @@ def _intersect_calc_scores_mmi_exact( else: return num_tot_scores, den_tot_scores, None, None - def _intersect_calc_scores_mmi_pruned( - self, dense_fsa_vec: k2.DenseFsaVec, num_graphs: 'k2.Fsa', den_graph: 'k2.Fsa', return_lats: bool = True, - ): + def _intersect_calc_scores_impl_pruned( + self, dense_fsa_vec: 'k2.DenseFsaVec', num_graphs: 'k2.Fsa', den_graph: 'k2.Fsa', return_lats: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional['k2.Fsa'], Optional['k2.Fsa']]: + """Inner intersection method. + Does exact intersection of dense_fsa_vec against num_graphs and pruned intersection against den_graph. + + Optiolally returns the numerator and the denominator lattices. + """ device = dense_fsa_vec.device assert device == num_graphs.device and device == den_graph.device @@ -205,84 +199,122 @@ def _intersect_calc_scores_mmi_pruned( seqframe_idx_name="seqframe_idx" if return_lats else None, ) - # use_double_scores=True does matter - # since otherwise it sometimes makes rounding errors - num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True) - den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True) + num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=False) + den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=False) if return_lats: return num_tot_scores, den_tot_scores, num_lats, den_lats else: return num_tot_scores, den_tot_scores, None, None - def forward( - self, - log_probs: torch.Tensor, - targets: torch.Tensor, - input_lengths: torch.Tensor, - target_lengths: torch.Tensor, - ) -> torch.Tensor: - assert self.graph_compiler is not None - boosted = self.boost_coeff != 0.0 - if self.blank != 0: - # rearrange log_probs to put blank at the first place - # and shift targets to emulate blank = 0 - log_probs, targets = make_blank_first(self.blank, log_probs, targets) - supervisions, order = create_supervision(input_lengths) - order = order.long() - targets = targets[order] - target_lengths = target_lengths[order] - - if log_probs.device != self.graph_compiler.device: - self.graph_compiler.to(log_probs.device) - - num_graphs, den_graph = self.graph_compiler.compile( - targets + 1 if self.pad_fsavec else targets, target_lengths - ) + def _intersect_calc_scores( + self, emissions_graphs: 'k2.DenseFsaVec', supervision_graphs: Any, supervisions: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Intersects emissions_graphs with supervision_graphs and calculates lattice scores. + This version implicitly assumes supervision_graphs to be a pair of the numerator and the denominator FSAs. - dense_fsa_vec = ( - prep_padded_densefsavec(log_probs, supervisions) - if self.pad_fsavec - else k2.DenseFsaVec(log_probs, supervisions) - ) + It can also calculate accuracy between the numerator and the denominator lattices to use it as additional loss. 
- num_tot_scores, den_tot_scores, num_lats, den_lats = self.intersect_calc_scores( - dense_fsa_vec, num_graphs, den_graph, boosted + Can be overridden. + """ + boosted = self.boost_coeff != 0.0 + num_tot_scores, den_tot_scores, num_lats, den_lats = self._intersect_calc_scores_impl( + emissions_graphs, supervision_graphs[0], supervision_graphs[1], boosted ) - tot_scores = num_tot_scores - den_tot_scores + inverted_batch_order = invert_permutation(supervisions[:, 0].to(dtype=torch.long)) + self.__batch_order = None + tot_scores = (num_tot_scores - den_tot_scores)[inverted_batch_order] mmi_tot_scores, mmi_valid_mask = get_tot_objf_and_finite_mask(tot_scores, self.reduction) if boosted: assert num_lats is not None and den_lats is not None size = ( - dense_fsa_vec.dim0(), - dense_fsa_vec.scores.shape[0], - dense_fsa_vec.scores.shape[1] - 1, + emissions_graphs.dim0(), + emissions_graphs.scores.shape[0], + emissions_graphs.scores.shape[1] - 1, ) - row_ids = dense_fsa_vec.dense_fsa_vec.shape().row_ids(1) + row_ids = emissions_graphs.emissions_graphs.shape().row_ids(1) num_sparse = create_sparse_wrapped( indices=[k2.index_select(row_ids, num_lats.seqframe_idx), num_lats.seqframe_idx, num_lats.phones,], values=num_lats.get_arc_post(False, True).exp(), size=size, min_col_index=0, ) + del num_lats den_sparse = create_sparse_wrapped( indices=[k2.index_select(row_ids, den_lats.seqframe_idx), den_lats.seqframe_idx, den_lats.phones,], values=den_lats.get_arc_post(False, True).exp(), size=size, min_col_index=0, ) + del den_lats + + acc_loss = torch.sparse.sum((num_sparse - den_sparse).coalesce().abs(), (1, 2)).to_dense() + del num_sparse, den_sparse - # NOTE: Due to limited support of PyTorch's autograd for sparse tensors, - # we cannot use (num_sparse - den_sparse) here - # TODO (alaptev): propose sparse_abs to k2 - acc_loss = torch.sparse.sum(sparse_abs((num_sparse + (-den_sparse)).coalesce()), (1, 2)).to_dense() acc_tot_scores, acc_valid_mask = get_tot_objf_and_finite_mask(acc_loss, self.reduction) valid_mask = mmi_valid_mask & acc_valid_mask - total_loss = self.boost_coeff * acc_tot_scores[valid_mask] - mmi_tot_scores[valid_mask] + total_loss = ( + (self.boost_coeff * acc_tot_scores[inverted_batch_order][valid_mask] - mmi_tot_scores[valid_mask]) + if self.reduction == "none" + else self.boost_coeff * acc_tot_scores - mmi_tot_scores + ) else: valid_mask = mmi_valid_mask - total_loss = -mmi_tot_scores[mmi_valid_mask] + total_loss = -mmi_tot_scores[valid_mask] if self.reduction == "none" else -mmi_tot_scores return total_loss, valid_mask + + +class CtcMmiLoss(MAPLoss, CtcK2Mixin): + """MMI loss with custom CTC topologies. + Available topologies: + - `default`, with or without self-loops + - `compact`, with or without self-loops + - `shared_blank`, with or without self-loops + - `minimal`, without self-loops + + cfg takes precedence over all optional parameters + We keep explicit parameter setting to be able to create an instance without the need of a config. 
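    Hypothetical construction sketch; the token-LM path, sizes, and reduction value are placeholders, and k2 plus a prebuilt token-LM FSA are assumed:

        from nemo.collections.asr.parts.k2.map_loss import CtcMmiLoss

        loss_fn = CtcMmiLoss(
            num_classes=129,
            blank=128,
            reduction="mean",                  # assumed to accept the usual "mean"/"sum"/"none" values
            token_lm="/path/to/token_lm.pt",   # placeholder path; a k2.Fsa object may be passed as well
            intersect_pruned=False,
            boost_coeff=0.0,
        )
        # the forward interface follows MLLoss: (log_probs, targets, input_lengths, target_lengths)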
+ """ + + def __init__( + self, + num_classes: int, + blank: int, + reduction: str, + cfg: Optional[DictConfig] = None, + topo_type: str = "default", + topo_with_self_loops: bool = True, + token_lm: Optional[Union['k2.Fsa', str]] = None, + intersect_pruned: bool = False, + intersect_conf: GraphIntersectDenseConfig = GraphIntersectDenseConfig(), + boost_coeff: float = 0.0, + ): + super().__init__( + num_classes=num_classes, + blank=blank, + reduction=reduction, + cfg=cfg, + topo_type=topo_type, + topo_with_self_loops=topo_with_self_loops, + token_lm=token_lm, + intersect_pruned=intersect_pruned, + intersect_conf=intersect_conf, + boost_coeff=boost_coeff, + ) + + def update_graph(self, graph: 'k2.Fsa'): + self.lm_graph = graph + lm_graph = self.lm_graph.clone() + if hasattr(lm_graph, "aux_labels"): + delattr(lm_graph, "aux_labels") + labels = lm_graph.labels + if labels.max() != self.num_classes - 1: + raise ValueError(f"lm_graph is not compatible with the num_classes: {labels.unique()}, {self.num_classes}") + from nemo.collections.asr.parts.k2.graph_compilers import MmiGraphCompiler as compiler + + self.graph_compiler = compiler( + self.num_classes, self.blank, self.topo_type, self.topo_with_self_loops, aux_graph=lm_graph + ) diff --git a/nemo/collections/asr/parts/k2/ml_loss.py b/nemo/collections/asr/parts/k2/ml_loss.py index 0cd2fe5f6090..ef916ee2f69d 100644 --- a/nemo/collections/asr/parts/k2/ml_loss.py +++ b/nemo/collections/asr/parts/k2/ml_loss.py @@ -26,28 +26,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from abc import abstractmethod +from typing import Any, Optional, Tuple import torch from omegaconf import DictConfig -from nemo.collections.asr.parts.k2.grad_utils import GradExpNormalize -from nemo.collections.asr.parts.k2.utils import ( - create_supervision, - get_tot_objf_and_finite_mask, - load_graph, - make_blank_first, - prep_padded_densefsavec, -) - +from nemo.collections.asr.parts.k2.graph_compilers import CtcTopologyCompiler, RnntTopologyCompiler +from nemo.collections.asr.parts.k2.loss_mixins import CtcK2Mixin, RnntK2Mixin +from nemo.collections.asr.parts.k2.utils import get_tot_objf_and_finite_mask, invert_permutation from nemo.core.utils.k2_guard import k2 # import k2 from guard module class MLLoss(torch.nn.Module): """ Maximum Likelihood criterion. - It is implemented as Connectionist Temporal Classification (CTC) loss, - but can be extended to support other loss functions (ASG, HMM, ...). + It implements Connectionist Temporal Classification (CTC) loss, + but can be extended to support other loss functions (ASG, HMM, RNNT, ...). Based on https://github.com/k2-fsa/snowfall/blob/master/snowfall/objectives/ctc.py @@ -55,6 +50,7 @@ class MLLoss(torch.nn.Module): We keep explicit parameter setting to be able to create an instance without the need of a config. 
""" + @abstractmethod def __init__( self, num_classes: int, @@ -63,33 +59,59 @@ def __init__( cfg: Optional[DictConfig] = None, topo_type: str = "default", topo_with_self_loops: bool = True, - graph_type: str = "topo", - token_lm: Optional[Union['k2.Fsa', str]] = None, ): super().__init__() if cfg is not None: topo_type = cfg.get("topo_type", topo_type) topo_with_self_loops = cfg.get("topo_with_self_loops", topo_with_self_loops) - graph_type = cfg.get("graph_type", graph_type) - token_lm = cfg.get("token_lm", token_lm) self.blank = blank self.num_classes = num_classes self.reduction = reduction + self.topo_type = topo_type + self.topo_with_self_loops = topo_with_self_loops self.pad_fsavec = topo_type == "compact" - if graph_type == "topo": - from nemo.collections.asr.parts.k2.graph_compilers import CtcTopologyCompiler as compiler + self.graph_compiler = None # expected to be initialized in child classes + + def _prepare_graphs_for_intersection( + self, + log_probs: torch.Tensor, + targets: torch.Tensor, + input_lengths: torch.Tensor, + target_lengths: torch.Tensor, + ) -> Tuple['k2.DenseFsaVec', Any, torch.Tensor]: + """Converts input tensors to FST graphs: + log_probs to supervision_graphs (DenseFsaVec) + targets to supervision_graphs + Can be overridden. + """ + log_probs, supervisions, targets, target_lengths = self._prepare_log_probs_and_targets( + log_probs, input_lengths, targets, target_lengths + ) + log_probs = self._maybe_normalize_gradients(log_probs, supervisions[:, -1].to(dtype=torch.long)) + emissions_graphs = self._prepare_emissions_graphs(log_probs, supervisions) + del log_probs - self.graph_compiler = compiler(self.num_classes, topo_type, topo_with_self_loops) - elif graph_type == "token_lm": - from nemo.collections.asr.parts.k2.graph_compilers import CtcNumGraphCompiler as compiler + if emissions_graphs.device != self.graph_compiler.device: + self.graph_compiler.to(emissions_graphs.device) + order = supervisions[:, 0].to(dtype=torch.long) + supervision_graphs = self.graph_compiler.compile(targets[order], target_lengths[order]) - if isinstance(token_lm, str): - token_lm = load_graph(token_lm) - self.graph_compiler = compiler(self.num_classes, topo_type, topo_with_self_loops, aux_graph=token_lm) + return emissions_graphs, supervision_graphs, supervisions - raise NotImplementedError("Not tested yet") - else: - raise ValueError(f"Invalid value of `graph_type`: {graph_type}.") + def _intersect_calc_scores( + self, emissions_graphs: 'k2.DenseFsaVec', supervision_graphs: Any, supervisions: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Intersects emissions_graphs with supervision_graphs and calculates lattice scores. + Can be overridden. 
+ """ + lats = k2.intersect_dense(supervision_graphs, emissions_graphs, torch.finfo(torch.float32).max / 10) + del emissions_graphs + + num_tot_scores = lats.get_tot_scores(log_semiring=True, use_double_scores=False) + del lats + tot_scores = num_tot_scores[invert_permutation(supervisions[:, 0].to(dtype=torch.long))] + tot_scores, valid_mask = get_tot_objf_and_finite_mask(tot_scores, self.reduction) + return -tot_scores[valid_mask] if self.reduction == "none" else -tot_scores, valid_mask def forward( self, @@ -98,33 +120,101 @@ def forward( input_lengths: torch.Tensor, target_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - if self.blank != 0: - # rearrange log_probs to put blank at the first place - # and shift targets to emulate blank = 0 - log_probs, targets = make_blank_first(self.blank, log_probs, targets) - supervisions, order = create_supervision(input_lengths) - order = order.long() - targets = targets[order] - target_lengths = target_lengths[order] - # PyTorch is doing the log-softmax normalization as part of the CTC computation. - # More: https://github.com/k2-fsa/k2/issues/575 - log_probs = GradExpNormalize.apply(log_probs, input_lengths, "mean" if self.reduction != "sum" else "none") - - if log_probs.device != self.graph_compiler.device: - self.graph_compiler.to(log_probs.device) - num_graphs = self.graph_compiler.compile(targets + 1 if self.pad_fsavec else targets, target_lengths) - - dense_fsa_vec = ( - prep_padded_densefsavec(log_probs, supervisions) - if self.pad_fsavec - else k2.DenseFsaVec(log_probs, supervisions) + assert self.graph_compiler is not None + + emissions_graphs, supervision_graphs, supervisions = self._prepare_graphs_for_intersection( + log_probs, targets, input_lengths, target_lengths ) + scores, mask = self._intersect_calc_scores(emissions_graphs, supervision_graphs, supervisions) + return scores, mask - num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, torch.finfo(torch.float32).max) - # use_double_scores=True does matter - # since otherwise it sometimes makes rounding errors - num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True) - tot_scores = num_tot_scores - tot_scores, valid_mask = get_tot_objf_and_finite_mask(tot_scores, self.reduction) - return -tot_scores[valid_mask], valid_mask +class CtcLoss(MLLoss, CtcK2Mixin): + """Regular CTC loss with custom topologies. + Available topologies: + - `default`, with or without self-loops + - `compact`, with or without self-loops + - `shared_blank`, with or without self-loops + - `minimal`, without self-loops + cfg takes precedence over all optional parameters + We keep explicit parameter setting to be able to create an instance without the need of a config. + """ + + def __init__( + self, + num_classes: int, + blank: int, + reduction: str, + cfg: Optional[DictConfig] = None, + topo_type: str = "default", + topo_with_self_loops: bool = True, + ): + super().__init__( + num_classes=num_classes, + blank=blank, + reduction=reduction, + cfg=cfg, + topo_type=topo_type, + topo_with_self_loops=topo_with_self_loops, + ) + self.graph_compiler = CtcTopologyCompiler( + self.num_classes, self.blank, self.topo_type, self.topo_with_self_loops + ) + + +class RnntLoss(MLLoss, RnntK2Mixin): + """RNNT loss with the `minimal` topology. + If predictor_window_size is not provided, this loss works as regular RNNT. + With predictor_window_size provided, it applies uniform pruning when compiling Emission FSAs + to reduce memory and compute consumption. 
+ cfg takes precedence over all optional parameters + We keep explicit parameter setting to be able to create an instance without the need of a config. + """ + + def __init__( + self, + num_classes: int, + blank: int, + reduction: str, + cfg: Optional[DictConfig] = None, + topo_type: str = "minimal", + topo_with_self_loops: bool = True, + predictor_window_size: int = 0, + predictor_step_size: int = 1, + ): + super().__init__( + num_classes=num_classes, + blank=blank, + reduction=reduction, + cfg=cfg, + topo_type=topo_type, + topo_with_self_loops=topo_with_self_loops, + ) + if cfg is not None: + topo_type = cfg.get("topo_type", topo_type) + predictor_window_size = cfg.get("predictor_window_size", predictor_window_size) + predictor_step_size = cfg.get("predictor_step_size", predictor_step_size) + if topo_type != "minimal": + raise NotImplementedError(f"Only topo_type=`minimal` is supported at the moment.") + self.predictor_window_size = predictor_window_size + self.predictor_step_size = predictor_step_size + self.graph_compiler = RnntTopologyCompiler( + self.num_classes, + self.blank, + self.topo_type, + self.topo_with_self_loops, + max_adapter_length=self.predictor_window_size, + ) + + def forward( + self, + log_probs: torch.Tensor, + targets: torch.Tensor, + input_lengths: torch.Tensor, + target_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert self.predictor_window_size == 0 or log_probs.size(2) <= self.predictor_window_size + 1 + + return super().forward( + log_probs=log_probs, targets=targets, input_lengths=input_lengths, target_lengths=target_lengths + ) diff --git a/nemo/collections/asr/parts/k2/topologies.py b/nemo/collections/asr/parts/k2/topologies.py index 8a5b825d8391..c892b2643332 100644 --- a/nemo/collections/asr/parts/k2/topologies.py +++ b/nemo/collections/asr/parts/k2/topologies.py @@ -12,43 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from functools import lru_cache +from typing import List, Optional, Union + +import torch from nemo.core.utils.k2_guard import k2 # import k2 from guard module -def build_topo(name: str, tokens: List[int], with_self_loops: bool = True) -> 'k2.Fsa': +def build_topo(name: str, tokens: List[int], blank_num: int, with_self_loops: bool = True) -> 'k2.Fsa': """Helper function to build a topology. + It allows to build topologies with a non-zero blank ID. Args: name: The topology name. Choices: default, compact, shared_blank, minimal tokens: A list of tokens, e.g., phones, characters, etc. + blank_num: + Blank number. Must be in tokens with_self_loops: Whether to add token-to-epsilon self-loops to a topology Returns: Returns a topology FST. 
""" if name == "default": - return build_default_topo(tokens, with_self_loops) + ans = build_default_topo(tokens, with_self_loops) elif name == "compact": - return build_compact_topo(tokens, with_self_loops) + ans = build_compact_topo(tokens, with_self_loops) elif name == "shared_blank": - return build_shared_blank_topo(tokens, with_self_loops) + ans = build_shared_blank_topo(tokens, with_self_loops) elif name == "minimal": - return build_minimal_topo(tokens) + ans = build_minimal_topo(tokens) else: raise ValueError(f"Unknown topo name: {name}") + if blank_num != 0: + blank_mask = ans.labels == 0 + ans.labels[(ans.labels != -1) & (ans.labels <= blank_num)] -= 1 + ans.labels[blank_mask] = blank_num + ans = k2.arc_sort(ans) + return ans def build_default_topo(tokens: List[int], with_self_loops: bool = True) -> 'k2.Fsa': """Build the default CTC topology. + Zero is assumed to be the ID of the blank symbol. """ - assert 0 in tokens, "We assume 0 is ID of the blank symbol" + assert -1 not in tokens, "We assume -1 is ID of the final transition" + assert 0 in tokens, "We assume 0 is the ID of the blank symbol" num_states = len(tokens) final_state = num_states - arcs = "" if with_self_loops else "0 0 0 0 0.0\n" + arcs = "" if with_self_loops else f"0 0 0 0 0.0\n" for i in range(num_states): for j in range(num_states): if i == j: @@ -65,22 +79,24 @@ def build_default_topo(tokens: List[int], with_self_loops: bool = True) -> 'k2.F def build_compact_topo(tokens: List[int], with_self_loops: bool = True) -> 'k2.Fsa': """Build the compact CTC topology. + Zero is assumed to be the ID of the blank symbol. See https://arxiv.org/abs/2110.03098 """ - assert 0 in tokens, "We assume 0 is ID of the blank symbol" + assert -1 not in tokens, "We assume -1 is ID of the final transition" + assert 0 in tokens, "We assume 0 is the ID of the blank symbol" + eps_num = tokens[-1] + 1 selfloops_shift = int(with_self_loops) - blank_num = 1 num_states = len(tokens) + selfloops_shift final_state = num_states - arcs = f"0 {selfloops_shift} {blank_num} 0 0.0\n" - for i in range(blank_num + selfloops_shift, num_states): - arcs += f"0 {i} {tokens[i - selfloops_shift] + 1} {tokens[i - selfloops_shift] + 1} 0.0\n" + arcs = "" + for i in range(selfloops_shift, num_states): + arcs += f"0 {i} {tokens[i - selfloops_shift]} {tokens[i - selfloops_shift]} 0.0\n" arcs += f"0 {final_state} -1 -1 0.0\n" - for i in range(blank_num, num_states): - arcs += f"{i} 0 0 0 0.0\n" + for i in range(1, num_states): + arcs += f"{i} 0 {eps_num} 0 0.0\n" if with_self_loops: - arcs += f"{i} {i} {tokens[i - selfloops_shift] + 1} 0 0.0\n" + arcs += f"{i} {i} {tokens[i - selfloops_shift]} 0 0.0\n" arcs += f"{final_state}" ans = k2.Fsa.from_str(arcs, num_aux_labels=1) ans = k2.arc_sort(ans) @@ -89,8 +105,10 @@ def build_compact_topo(tokens: List[int], with_self_loops: bool = True) -> 'k2.F def build_shared_blank_topo(tokens: List[int], with_self_loops: bool = True) -> 'k2.Fsa': """Build the shared blank CTC topology. + Zero is assumed to be the ID of the blank symbol. See https://github.com/k2-fsa/k2/issues/746#issuecomment-856421616 """ + assert -1 not in tokens, "We assume -1 is ID of the final transition" assert 0 in tokens, "We assume 0 is the ID of the blank symbol" tokens = tokens.copy() @@ -120,9 +138,11 @@ def build_shared_blank_topo(tokens: List[int], with_self_loops: bool = True) -> def build_minimal_topo(tokens: List[int]) -> 'k2.Fsa': """Build the minimal topology. + Zero is assumed to be the ID of the blank symbol. 
    See https://arxiv.org/abs/2110.03098
     """
-    assert 0 in tokens, "We assume 0 is ID of the blank symbol"
+    assert -1 not in tokens, "We assume -1 is ID of the final transition"
+    assert 0 in tokens, "We assume 0 is the ID of the blank symbol"
 
     num_tokens = len(tokens)
     final_state = 1
@@ -134,3 +154,56 @@ def build_minimal_topo(tokens: List[int]) -> 'k2.Fsa':
     ans = k2.Fsa.from_str(arcs, num_aux_labels=1)
     ans = k2.arc_sort(ans)
     return ans
+
+
+class RnntEmissionAdapterBuilder(object):
+    """Builder class for RNNT Emission Adapters.
+
+    An Emission Adapter is an FSA used to emulate desired temporal Emissions FSA properties of a trivial Emissions FSA.
+    Temporal properties are emulated by epsilon-arcs with zero log-weight.
+    These additional arcs do not contribute to the lattice scores and can be easily removed from the best path.
+
+    k2 does not have Emissions FSAs. Instead, it has DenseFsaVec, which is not a real FSA.
+    Thus, Emission Adapters should be composed with Supervision FSAs.
+    IMPORTANT: epsilon-outputs are expected to be present in the DenseFsaVec.
+
+    These RNNT adapters only do the re-routing (emulating hopping over the U dimension).
+    Redundant non-epsilon arcs are not removed by these adapters.
+
+    At initialization, the builder expects a list of tokens, the blank number, and (optionally) the epsilon number.
+    When called, the builder returns adapters according to the provided text lengths.
+    """
+
+    def __init__(self, tokens: List[int], blank_num: int, eps_num: Optional[int] = None):
+        assert -1 not in tokens, "We assume -1 is ID of the final transition"
+        assert blank_num in tokens, "The blank ID must be in tokens"
+        assert eps_num is None or eps_num not in tokens, "The epsilon ID must not be in tokens"
+
+        self.tokens = tokens
+        self.blank_num = blank_num
+        self.eps_num = self.tokens[-1] + 1 if eps_num is None else eps_num
+
+    def __call__(self, adapter_lengths: Union[torch.Tensor, List[int]]) -> 'k2.Fsa':
+        # if adapter_lengths is not converted to a list beforehand,
+        # "i" will be implicitly converted to int, and this will always be considered a cache miss
+        return k2.create_fsa_vec([self._build_single_adapter(i) for i in adapter_lengths.tolist()])
+
+    @lru_cache(maxsize=1024)
+    def _build_single_adapter(self, adapter_length: int) -> 'k2.Fsa':
+        assert adapter_length >= 1, "`adapter_length` cannot be less than one"
+
+        first_eps_state = adapter_length + 1
+        final_state = adapter_length * 2 + 1
+        arcs = ""
+        for i in range(adapter_length):
+            for j in range(len(self.tokens)):
+                if j != self.blank_num:
+                    arcs += f"{i} {i + 1} {self.tokens[j]} 0.0\n"
+            arcs += f"{i} {first_eps_state} {self.blank_num} 0.0\n"
+        arcs += f"{adapter_length} {first_eps_state} {self.blank_num} 0.0\n"
+        for i in range(first_eps_state, final_state):
+            arcs += f"{i} {i + 1 if i < final_state - 1 else 0} {self.eps_num} 0.0\n"
+            arcs += f"{i} {final_state} -1 0.0\n"
+        arcs += f"{final_state}"
+
+        return k2.arc_sort(k2.Fsa.from_str(arcs, acceptor=True))
diff --git a/nemo/collections/asr/parts/k2/utils.py b/nemo/collections/asr/parts/k2/utils.py
index 45b71b6e3de3..f55620a81356 100644
--- a/nemo/collections/asr/parts/k2/utils.py
+++ b/nemo/collections/asr/parts/k2/utils.py
@@ -37,7 +37,7 @@
 from nemo.utils import logging
 
 
-def create_supervision(input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+def create_supervision(input_lengths: torch.Tensor) -> torch.Tensor:
     """Creates a special supervisions tensor from input lengths.
     These supervisions are required for some k2 methods.
""" @@ -45,8 +45,7 @@ def create_supervision(input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch (torch.tensor(range(input_lengths.shape[0])), torch.zeros(input_lengths.shape[0]), input_lengths.cpu(),), 1, ).to(dtype=torch.int32) # the duration column has to be sorted in decreasing order - order = torch.argsort(supervisions[:, -1], descending=True).to(dtype=torch.int32) - return supervisions[order.to(dtype=torch.long)], order + return supervisions[torch.argsort(supervisions[:, -1], descending=True)] def invert_permutation(indices: torch.Tensor) -> torch.Tensor: @@ -59,17 +58,41 @@ def invert_permutation(indices: torch.Tensor) -> torch.Tensor: return ans -def make_blank_first( - blank_idx: int, log_probs: torch.Tensor, targets: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - """Puts blank logits at the first place in input log_probs tensor. +def make_non_pad_mask(input_lengths: torch.Tensor, seq_len: int): + """Converts input_lengths to a non-padding mask. The mask is 2D. + """ + batch_size = input_lengths.shape[0] + seq_range = torch.arange(0, seq_len, device=input_lengths.device) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, seq_len) + seq_length_expand = input_lengths.clone().detach().to(seq_range_expand.device).unsqueeze(-1) + mask = seq_range_expand < seq_length_expand + return mask + + +def make_non_pad_mask_3d( + lengths_x: torch.Tensor, lengths_y: torch.Tensor, max_length_x: int, max_length_y: int +) -> torch.Tensor: + """Converts two orthogonal input_lengths to a non-padding mask. The mask is 3D. + """ + assert lengths_x.size() == lengths_y.size() + return make_non_pad_mask(lengths_x, max_length_x).unsqueeze(2) & make_non_pad_mask( + lengths_y, max_length_y + ).unsqueeze(1) + + +def ragged_to_tensor_2axes_simple(rt: k2.RaggedTensor) -> Optional[torch.Tensor]: + """Converts k2.RaggedTensor to torch.Tensor if the RaggedTensor is shallow (has two axes). """ - index = list(range(log_probs.shape[-1])) - del index[blank_idx] - index = torch.tensor([blank_idx] + index).to(log_probs.device) - new_log_probs = torch.index_select(log_probs, -1, index) - # TODO (alaptev): replace targets + 1 with torch.where to work for non-last blank_id - return new_log_probs, None if targets is None else targets + 1 + rt_list = rt.tolist() + result_list = [] + for e in rt_list: + if len(e) == 0: + result_list.append(0) + elif len(e) == 1: + result_list.append(e[0]) + else: + return None + return torch.tensor(result_list, dtype=torch.int32) def load_graph(graph_path: str) -> 'k2.Fsa': @@ -104,7 +127,7 @@ def intersect_with_self_loops(base_graph: 'k2.Fsa', aux_graph: 'k2.Fsa') -> 'k2. assert hasattr(base_graph, "aux_labels") assert not hasattr(aux_graph, "aux_labels") aux_graph_with_self_loops = k2.arc_sort(k2.add_epsilon_self_loops(aux_graph)).to(base_graph.device) - result = k2.intersect(k2.arc_sort(base_graph), aux_graph_with_self_loops, treat_epsilons_specially=False,) + result = k2.intersect(k2.arc_sort(base_graph), aux_graph_with_self_loops, treat_epsilons_specially=False) setattr(result, "phones", result.labels) return result @@ -113,7 +136,7 @@ def compose_with_self_loops(base_graph: 'k2.Fsa', aux_graph: 'k2.Fsa') -> 'k2.Fs """Composition helper function. 
""" aux_graph_with_self_loops = k2.arc_sort(k2.add_epsilon_self_loops(aux_graph)).to(base_graph.device) - return k2.compose(base_graph, aux_graph_with_self_loops, treat_epsilons_specially=False, inner_labels="phones",) + return k2.compose(base_graph, aux_graph_with_self_loops, treat_epsilons_specially=False, inner_labels="phones") def create_sparse_wrapped( @@ -154,18 +177,17 @@ def create_sparse_wrapped( def prep_padded_densefsavec(log_softmax: torch.Tensor, supervisions: torch.Tensor) -> 'k2.DenseFsaVec': """Performs special epsilon-padding required for composition with some of the topologies. """ - log_softmax_shifted = torch.cat( + log_softmax_eps = torch.cat( [ - torch.full((log_softmax.shape[0], log_softmax.shape[1], 1), -float("inf"), device=log_softmax.device,), log_softmax, + torch.full((log_softmax.shape[0], log_softmax.shape[1], 1), -float("inf"), device=log_softmax.device,), ], axis=-1, ) log_softmax_padded = torch.zeros( - (log_softmax_shifted.shape[0], log_softmax_shifted.shape[1] * 2, log_softmax_shifted.shape[2],), - device=log_softmax.device, + (log_softmax_eps.shape[0], log_softmax_eps.shape[1] * 2, log_softmax_eps.shape[2],), device=log_softmax.device, ) - log_softmax_padded[:, ::2] = log_softmax_shifted + log_softmax_padded[:, ::2] = log_softmax_eps supervisions_padded = supervisions.clone() supervisions_padded[:, 2] *= 2 dense_log_softmax_padded = k2.DenseFsaVec(log_softmax_padded, supervisions_padded) @@ -173,7 +195,8 @@ def prep_padded_densefsavec(log_softmax: torch.Tensor, supervisions: torch.Tenso def shift_labels_inpl(lattices: List['k2.Fsa'], shift: int): - """Shifts lattice labels and aux_labels by a given number. This is an in-place operation. + """Shifts lattice labels and aux_labels by a given number. + This is an in-place operation, if the lattice is on GPU. """ for lattice in lattices: mask = lattice.labels > 0 @@ -181,7 +204,34 @@ def shift_labels_inpl(lattices: List['k2.Fsa'], shift: int): if hasattr(lattice, "aux_labels"): mask = lattice.aux_labels > 0 lattice.aux_labels[mask] += shift - return lattices + return reset_properties_fsa(lattices) + + +def reset_properties_fsa(graph: 'k2.Fsa'): + """Resets properties of a graph. + In-place (does not create a new graph) if the graph is on GPU. + Use this every time you alter a graph in-place. + See https://github.com/k2-fsa/k2/issues/978 for more information.""" + graph.__dict__["_properties"] = None + # CPU graphs need to be sorted e.g. for intersection + if graph.device == torch.device("cpu"): + graph = k2.arc_sort(graph) + return graph + + +def add_self_loops(graph: 'k2.Fsa', label: int = 0, mode: str = "auto"): + """Adds self-loops with given label to a graph. 
+    Supported modes are ``input``, ``output``, and ``auto``, where ``input`` leaves aux_labels (if present) as zeros,
+    ``output`` leaves labels as zeros, and ``auto`` relabels both."""
+    assert mode in ("input", "output", "auto"), f"Supported modes are ``input``, ``output``, and ``auto``: {mode}"
+    assert mode != "output" or hasattr(graph, "aux_labels"), "Graph must have aux_labels for mode ``output``"
+    new_graph, arc_map = k2.add_epsilon_self_loops(graph, ret_arc_map=True)
+
+    if mode != "output":
+        new_graph.labels[arc_map == -1] = label
+    if mode != "input" and hasattr(graph, "aux_labels"):
+        new_graph.aux_labels[arc_map == -1] = label
+    return reset_properties_fsa(new_graph)
 
 
 def get_arc_weights(graph: 'k2.Fsa') -> torch.Tensor:
@@ -189,7 +239,7 @@ def get_arc_weights(graph: 'k2.Fsa') -> torch.Tensor:
     """
     if len(graph.shape) > 2:
         raise NotImplementedError("FsaVec is not supported at the moment.")
-    weights_int = graph.arcs_as_tensor()[:, -1].tolist()
+    weights_int = graph.arcs.values()[:, -1].tolist()
     weights_float = struct.unpack('%sf' % len(weights_int), struct.pack('%si' % len(weights_int), *weights_int))
     return torch.Tensor(weights_float)
 
@@ -214,3 +264,63 @@ def get_tot_objf_and_finite_mask(tot_scores: torch.Tensor, reduction: str) -> Tu
     elif reduction == "sum":
         tot_scores = tot_scores[finite_mask].sum()
     return tot_scores, finite_mask
+
+
+def get_uniform_rnnt_prune_ranges(
+    encoded_lengths: torch.Tensor,
+    target_lengths: torch.Tensor,
+    window_size_with_blank: int,
+    step: int = 1,
+    max_seq_len: Optional[int] = None,
+    begin_only: bool = False,
+) -> torch.Tensor:
+    """Creates the pruning ranges for the Encoder and Predictor of RNNT.
+    The ranges are similar to https://k2-fsa.github.io/k2/python_api/api.html#k2.get_rnnt_prune_ranges
+    but they are constructed under the assumption of a uniform distribution of token activations across time frames
+    and without any posterior knowledge.
+    """
+    assert window_size_with_blank > 1
+    assert step >= 1
+    assert window_size_with_blank > step
+    assert len(encoded_lengths) == len(target_lengths)
+    ranges_begin = torch.zeros(
+        (
+            len(encoded_lengths),
+            encoded_lengths.max() if max_seq_len is None else max(max_seq_len, encoded_lengths.max()),
+        ),
+        dtype=torch.long,
+    )
+    for i in (target_lengths >= window_size_with_blank).nonzero(as_tuple=True)[0]:
+        encoded_len = encoded_lengths[i]
+        ranges_begin_raw = torch.arange(int((target_lengths[i] - window_size_with_blank) / step + 2)) * step
+        ranges_begin_raw[-1] = target_lengths[i] - window_size_with_blank + 1
+        ranges_begin[i, :encoded_len] = torch.nn.functional.interpolate(
+            ranges_begin_raw.reshape(1, 1, -1).to(dtype=torch.float), encoded_len, mode="nearest-exact"
+        ).to(dtype=torch.long)
+        ranges_begin[i, encoded_len:] = ranges_begin[i, encoded_len - 1]
+    return (
+        ranges_begin
+        if begin_only
+        else ranges_begin.unsqueeze(-1).repeat(1, 1, window_size_with_blank) + torch.arange(window_size_with_blank)
+    )
+
+
+def apply_rnnt_prune_ranges(
+    encoder_outputs: torch.Tensor, decoder_outputs: torch.Tensor, ranges: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Prepares pruned encoder and decoder outputs according to the prune ranges.
+    Based on k2.do_rnnt_pruning(...)
+ """ + B, T, window_size_with_blank = ranges.size() + D1 = encoder_outputs.size(-1) + _, U, D2 = decoder_outputs.size() + assert B == encoder_outputs.size(0) + assert T == encoder_outputs.size(1) + assert B == decoder_outputs.size(0) + encoder_outputs_pruned = encoder_outputs.unsqueeze(2).expand((B, T, window_size_with_blank, D1)) + decoder_outputs_pruned = torch.gather( + decoder_outputs.unsqueeze(1).expand((B, T, U, D2)), + dim=2, + index=ranges.reshape((B, T, window_size_with_blank, 1)).expand((B, T, window_size_with_blank, D2)), + ) + return encoder_outputs_pruned, decoder_outputs_pruned diff --git a/nemo/core/utils/k2_guard.py b/nemo/core/utils/k2_guard.py index e818788a3ae7..df4a01b03963 100644 --- a/nemo/core/utils/k2_guard.py +++ b/nemo/core/utils/k2_guard.py @@ -20,12 +20,13 @@ """ import textwrap +from typing import Tuple from packaging.version import Version from pytorch_lightning.utilities.imports import package_available __K2_MINIMUM_MAJOR_VERSION = 1 -__K2_MINIMUM_MINOR_VERSION = 11 +__K2_MINIMUM_MINOR_VERSION = 14 __K2_MINIMUM_VERSION = Version(f"{__K2_MINIMUM_MAJOR_VERSION}.{__K2_MINIMUM_MINOR_VERSION}") @@ -40,11 +41,14 @@ ) if not package_available("k2"): - raise ModuleNotFoundError(K2_INSTALLATION_MESSAGE) + raise ModuleNotFoundError("Module k2 is not available.\n" + K2_INSTALLATION_MESSAGE) import k2 # noqa: E402 -__k2_version = Version(k2.__dev_version__) +try: + __k2_version = Version(k2.__dev_version__) +except AttributeError: + raise ImportError("Module k2 is corrupted.\n" + K2_INSTALLATION_MESSAGE) if __k2_version < __K2_MINIMUM_VERSION: raise ImportError( diff --git a/tests/collections/asr/k2/test_ctc.py b/tests/collections/asr/k2/test_ctc.py new file mode 100644 index 000000000000..24b9c8eb07e9 --- /dev/null +++ b/tests/collections/asr/k2/test_ctc.py @@ -0,0 +1,311 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
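As a usage note for the two pruning helpers added to utils.py above, here is a minimal sketch of how they are meant to fit together. Only the function names and signatures come from the diff; the shapes and lengths (2 utterances, up to 10 encoder frames, up to 6 target tokens, feature size 4) are made up for illustration.

import torch

from nemo.collections.asr.parts.k2.utils import apply_rnnt_prune_ranges, get_uniform_rnnt_prune_ranges

encoded_lengths = torch.tensor([10, 8])
target_lengths = torch.tensor([6, 4])
window_size_with_blank = 3

# (B, T_max, window_size_with_blank) indices into the U axis of the decoder output
ranges = get_uniform_rnnt_prune_ranges(encoded_lengths, target_lengths, window_size_with_blank)

encoder_outputs = torch.randn(2, 10, 4)  # (B, T, D)
decoder_outputs = torch.randn(2, 7, 4)   # (B, U, D) with U = max target length + 1
enc_pruned, dec_pruned = apply_rnnt_prune_ranges(encoder_outputs, decoder_outputs, ranges)
# both pruned tensors have shape (B, T, window_size_with_blank, D)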
+ +import numpy as np +import pytest +import torch +from torch.nn import CTCLoss as CTCLoss_Pytorch + +DEVICES = ['cpu'] + +if torch.cuda.is_available(): + DEVICES.append('cuda') + + +def wrap_and_call(fn, acts, labels, device): + if not torch.is_tensor(acts): + acts = torch.FloatTensor(acts) + + if 'cuda' in device: + acts = acts.cuda() + + if not acts.requires_grad: + acts.requires_grad = True + + lengths = [acts.shape[1]] * acts.shape[0] + label_lengths = [len(l) for l in labels] + labels = torch.LongTensor(labels) + lengths = torch.LongTensor(lengths) + label_lengths = torch.LongTensor(label_lengths) + log_probs = torch.nn.functional.log_softmax(acts.transpose(0, 1), -1) + if 'cuda' in device: + labels = labels.cuda() + lengths = lengths.cuda() + label_lengths = label_lengths.cuda() + + costs = fn(log_probs, labels, lengths, label_lengths) + cost = torch.sum(costs) + cost.backward() + + if 'cuda' in device: + torch.cuda.synchronize() + + if acts.grad is not None: + grad = acts.grad.data.cpu().numpy() + else: + grad = None + + return costs.data.cpu().numpy(), grad + + +def init_k2_ctc(**kwargs): + from nemo.collections.asr.parts.k2.ml_loss import CtcLoss + + ctc = CtcLoss(**kwargs) + return lambda log_probs, labels, lengths, label_lengths: ctc( + log_probs.transpose(0, 1), labels, lengths, label_lengths + )[0] + + +def skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled): + if device == 'cpu': + supported, msg = k2_is_appropriate + elif device == 'cuda': + supported, msg = k2_cuda_is_enabled + else: + raise ValueError(f"Unknown device: {device}") + if not supported: + pytest.skip(f"k2 test is skipped. Reason : {msg}") + + +class TestCTCLossK2: + @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + def test_case_small(self, device, k2_is_appropriate, k2_cuda_is_enabled): + skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled) + + acts = np.array( + [ + [ + [0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.6, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.8, 0.1], + [0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.1, 0.1], + [0.7, 0.1, 0.2, 0.1, 0.1], + ] + ] + ) + labels = [[1, 2, 3]] + + fn_k2 = init_k2_ctc(num_classes=acts.shape[-1], blank=0, reduction='sum') + k2_cost, k2_grads = wrap_and_call(fn_k2, acts, labels, device) + + expected_cost = 5.0279555 + expected_grads = np.array( + [ + [ + [0.00157518, -0.53266853, 0.17703111, 0.17703111, 0.17703111], + [-0.02431531, -0.17048728, -0.15925968, 0.17703113, 0.17703113], + [-0.06871005, 0.03236287, -0.2943067, 0.16722652, 0.16342735], + [-0.09178554, 0.25313747, -0.17673965, -0.16164337, 0.17703108], + [-0.10229809, 0.19587973, 0.05823242, -0.34769377, 0.19587973], + [-0.22203964, 0.1687112, 0.18645471, -0.30183747, 0.1687112], + ] + ] + ) + + assert np.allclose(k2_cost, expected_cost, rtol=1e-6), "small_test costs mismatch." + assert np.allclose(k2_grads, expected_grads, atol=1e-6), "small_test gradient mismatch." 
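As an aside on the topologies.py change above, build_topo now takes the blank ID explicitly and remaps arc labels after the topology is built with the conventional blank ID 0. A small illustrative sketch, assuming a working k2 installation and an arbitrary four-token vocabulary:

from nemo.collections.asr.parts.k2.topologies import build_topo

# four tokens with IDs 0..3; request the blank to be the last ID instead of 0
topo = build_topo("default", tokens=[0, 1, 2, 3], blank_num=3, with_self_loops=True)
# arcs that carried label 0 (the conventional blank) now carry label 3,
# and the remaining non-final labels are shifted down by one
print(topo.labels)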
+ + @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + def test_case_small_blank_last(self, device, k2_is_appropriate, k2_cuda_is_enabled): + skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled) + + acts = np.array( + [ + [ + [0.0, 1.0, 3.0], + [0.0, 2.0, 3.0], + [1.0, 1.0, 3.0], + [2.0, 3.0, 2.0], + [0.0, 0.0, 1.0], + [0.0, 1.0, 1.0], + [1.0, 0.0, 1.0], + [2.0, 2.0, 0.0], + [0.0, 2.0, 5.0], + [0.0, 3.0, 5.0], + [1.0, 2.0, 5.0], + [2.0, 4.0, 4.0], + [0.0, 3.0, 4.0], + [0.0, 4.0, 4.0], + [1.0, 3.0, 4.0], + [2.0, 5.0, 3.0], + [2.0, 2.0, 1.0], + [2.0, 3.0, 1.0], + [3.0, 2.0, 1.0], + [4.0, 4.0, 0.0], + ] + ] + ) + labels = [[0, 1, 0, 0, 1, 0]] + + fn_k2 = init_k2_ctc(num_classes=acts.shape[-1], blank=acts.shape[-1] - 1, reduction='sum') + k2_cost, k2_grads = wrap_and_call(fn_k2, acts, labels, device) + + expected_cost = 6.823422 + expected_grads = np.array( + [ + [ + [-0.09792291, 0.11419516, -0.01627225], + [-0.08915664, 0.22963384, -0.14047718], + [-0.19687234, 0.06477807, 0.13209426], + [-0.22838503, 0.1980845, 0.03030053], + [-0.07985485, -0.0589368, 0.13879165], + [-0.04722299, 0.01424287, 0.03298012], + [0.01492161, 0.02710512, -0.04202673], + [-0.43219852, 0.4305843, 0.00161422], + [-0.00332598, 0.0440818, -0.04075582], + [-0.01329869, 0.11521607, -0.10191737], + [-0.03721291, 0.04389342, -0.00668051], + [-0.2723349, 0.43273386, -0.16039898], + [-0.03499417, 0.1896997, -0.15470551], + [-0.02911933, 0.29706067, -0.26794133], + [-0.04593367, -0.04479058, 0.09072424], + [-0.07227867, 0.16096972, -0.08869105], + [0.13993078, -0.20230117, 0.06237038], + [-0.05889719, 0.04007925, 0.01881794], + [-0.09667239, 0.07077749, 0.0258949], + [-0.49002117, 0.4954626, -0.00544143], + ] + ] + ) + + assert np.allclose(k2_cost, expected_cost, rtol=1e-6), "small_test_blank_last costs mismatch." + assert np.allclose(k2_grads, expected_grads, atol=1e-6), "small_test_blank_last gradient mismatch." + + @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + def test_case_small_random(self, device, k2_is_appropriate, k2_cuda_is_enabled): + skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled) + + rng = np.random.RandomState(0) + acts = rng.randn(1, 4, 3) + labels = [[1, 2]] + + fn_k2 = init_k2_ctc(num_classes=acts.shape[-1], blank=0, reduction='sum') + k2_cost, k2_grads = wrap_and_call(fn_k2, acts, labels, device) + + fn_pt = CTCLoss_Pytorch(reduction='sum', zero_infinity=True) + pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) + + assert np.allclose(k2_cost, pt_cost, rtol=1e-6), "small_random_test costs mismatch." + assert np.allclose(k2_grads, pt_grads, atol=1e-6), "small_random_test gradient mismatch." 
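For completeness, a rough sketch of constructing the refactored CtcLoss directly, with a DictConfig taking precedence over the keyword defaults as stated in its docstring. The class and its signature come from ml_loss.py above; the sizes, targets, and lengths are invented for illustration.

import torch
from omegaconf import DictConfig

from nemo.collections.asr.parts.k2.ml_loss import CtcLoss

cfg = DictConfig({"topo_type": "default", "topo_with_self_loops": True})  # cfg overrides the kwargs
ctc = CtcLoss(num_classes=5, blank=0, reduction="sum", cfg=cfg)

log_probs = torch.randn(2, 10, 5).log_softmax(dim=-1)  # (B, T, C)
targets = torch.tensor([[1, 2, 3], [2, 3, 1]])
input_lengths = torch.tensor([10, 8])
target_lengths = torch.tensor([3, 3])
loss, valid_mask = ctc(log_probs, targets, input_lengths, target_lengths)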
+ + @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + def test_case_big_tensor(self, device, k2_is_appropriate, k2_cuda_is_enabled): + skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled) + + # minibatch x T x alphabet_size + acts = [ + [ + [0.06535690384862791, 0.7875301411923206, 0.08159176605666074], + [0.5297155426466327, 0.7506749639230854, 0.7541348379087998], + [0.6097641124736383, 0.8681404965673826, 0.6225318186056529], + [0.6685222872103057, 0.8580392805336061, 0.16453892311765583], + [0.989779515236694, 0.944298460961015, 0.6031678586829663], + [0.9467833543605416, 0.666202507295747, 0.28688179752461884], + [0.09418426230195986, 0.3666735970751962, 0.736168049462793], + [0.1666804425271342, 0.7141542198635192, 0.3993997272216727], + [0.5359823524146038, 0.29182076440286386, 0.6126422611507932], + [0.3242405528768486, 0.8007644367291621, 0.5241057606558068], + [0.779194617063042, 0.18331417220174862, 0.113745182072432], + [0.24022162381327106, 0.3394695622533106, 0.1341595066017014], + ], + [ + [0.5055615569388828, 0.051597282072282646, 0.6402903936686337], + [0.43073311517251, 0.8294731834714112, 0.1774668847323424], + [0.3207001991262245, 0.04288308912457006, 0.30280282975568984], + [0.6751777088333762, 0.569537369330242, 0.5584738347504452], + [0.08313242153985256, 0.06016544344162322, 0.10795752845152584], + [0.7486153608562472, 0.943918041459349, 0.4863558118797222], + [0.4181986264486809, 0.6524078485043804, 0.024242983423721887], + [0.13458171554507403, 0.3663418070512402, 0.2958297395361563], + [0.9236695822497084, 0.6899291482654177, 0.7418981733448822], + [0.25000547599982104, 0.6034295486281007, 0.9872887878887768], + [0.5926057265215715, 0.8846724004467684, 0.5434495396894328], + [0.6607698886038497, 0.3771277082495921, 0.3580209022231813], + ], + ] + + expected_costs = [6.388067, 5.2999153] + expected_grads = [ + [ + [0.06130501, -0.3107036, 0.24939862], + [0.08428053, -0.07131141, -0.01296911], + [-0.04510102, 0.21943177, -0.17433074], + [-0.1970142, 0.37144178, -0.17442757], + [-0.08807078, 0.35828218, -0.2702114], + [-0.24209887, 0.33242193, -0.09032306], + [-0.07871056, 0.3116736, -0.23296304], + [-0.27552277, 0.43320477, -0.157682], + [-0.16173504, 0.27361175, -0.1118767], + [-0.13012655, 0.42030025, -0.2901737], + [-0.2378576, 0.26685005, -0.02899244], + [0.08487711, 0.36765888, -0.45253596], + ], + [ + [-0.14147596, -0.2702151, 0.41169107], + [-0.05323913, -0.18442528, 0.23766442], + [-0.24160458, -0.11692462, 0.3585292], + [-0.1004294, -0.17919227, 0.27962166], + [-0.01819841, -0.12625945, 0.14445786], + [-0.00131121, 0.06060241, -0.0592912], + [-0.09093696, 0.2536721, -0.16273515], + [-0.08962183, 0.34198248, -0.25236064], + [-0.19668606, 0.25176668, -0.05508063], + [0.0232805, 0.1351273, -0.1584078], + [0.09494846, -0.17026341, 0.07531495], + [0.00775955, -0.30424336, 0.29648378], + ], + ] + + acts = np.array(acts) + expected_costs = np.array(expected_costs) + labels = [[1, 2, 2, 2, 2], [1, 1, 2, 2, 1]] + + fn_k2 = init_k2_ctc(num_classes=acts.shape[-1], blank=0, reduction='none') + k2_costs, k2_grads = wrap_and_call(fn_k2, acts, labels, device) + + assert np.allclose(k2_costs, expected_costs), "big_test average costs mismatch." + assert np.allclose(k2_grads, expected_grads, rtol=1e-3), "big_test grads for average cost mismatch." 
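Related to the k2_guard.py change further above, downstream code is expected to import k2 only through the guard module so that the availability and minimum-version checks run first. A minimal sketch:

# Raises ModuleNotFoundError if k2 is missing, and ImportError if it is corrupted
# or older than the minimum supported version (now 1.14).
from nemo.core.utils.k2_guard import k2

print(k2.__dev_version__, k2.with_cuda)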
+ + @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + def test_case_large_random(self, device, k2_is_appropriate, k2_cuda_is_enabled): + skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled) + + rng = np.random.RandomState(0) + acts = rng.randn(4, 80, 5) + labels = [ + [1, 2, 4, 3, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 3, 3, 1, 1, 1], + [3, 2, 2, 3, 4, 1, 1, 1, 1, 1, 4, 4, 1, 2, 1, 3, 4, 3, 1, 2], + [4, 4, 1, 2, 1, 3, 4, 3, 1, 2, 3, 2, 2, 3, 4, 1, 1, 1, 1, 1], + [1, 1, 2, 1, 2, 3, 3, 1, 1, 1, 1, 2, 4, 3, 2, 2, 1, 1, 1, 1], + ] + + fn_k2 = init_k2_ctc(num_classes=acts.shape[-1], blank=0, reduction='sum') + k2_costs, k2_grads = wrap_and_call(fn_k2, acts, labels, device) + + fn_pt = CTCLoss_Pytorch(reduction='sum', zero_infinity=True) + pt_costs, pt_grads = wrap_and_call(fn_pt, acts, labels, device) + + assert np.allclose(k2_costs, pt_costs, atol=1e-5, rtol=1e-3), "large_random_test costs mismatch." + assert np.allclose(k2_grads, pt_grads, atol=1e-5, rtol=1e-3), "large_random_test gradient mismatch." + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/collections/asr/k2/test_rnnt.py b/tests/collections/asr/k2/test_rnnt.py new file mode 100644 index 000000000000..8b2dabb44d73 --- /dev/null +++ b/tests/collections/asr/k2/test_rnnt.py @@ -0,0 +1,341 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
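Before the RNNT tests below, a quick sketch of how the new RnntLoss is expected to be called on dense joint log-probs. It follows the (B, T, U, C) layout used by the tests; every value here is illustrative, and predictor_window_size is left at its default of 0 (no pruning).

import torch

from nemo.collections.asr.parts.k2.ml_loss import RnntLoss

rnnt = RnntLoss(num_classes=5, blank=0, reduction="sum")

log_probs = torch.randn(1, 4, 3, 5).log_softmax(dim=-1)  # (B, T, U, C) with U = target length + 1
targets = torch.tensor([[1, 2]], dtype=torch.long)
input_lengths = torch.tensor([4], dtype=torch.long)
target_lengths = torch.tensor([2], dtype=torch.long)
loss, valid_mask = rnnt(log_probs, targets, input_lengths, target_lengths)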
+ +import numpy as np +import pytest +import torch + +from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_numpy import RNNTLoss as RNNTLoss_Numpy + +DEVICES = ['cpu'] + +if torch.cuda.is_available(): + DEVICES.append('cuda') + + +def wrap_and_call(fn, acts, labels, device): + if not torch.is_tensor(acts): + acts = torch.FloatTensor(acts) + + if 'cuda' in device: + acts = acts.cuda() + + if not acts.requires_grad: + acts.requires_grad = True + + lengths = [acts.shape[1]] * acts.shape[0] + label_lengths = [len(l) for l in labels] + labels = torch.LongTensor(labels) + lengths = torch.LongTensor(lengths) + label_lengths = torch.LongTensor(label_lengths) + if 'cuda' in device: + labels = labels.cuda() + lengths = lengths.cuda() + label_lengths = label_lengths.cuda() + + costs = fn(acts, labels, lengths, label_lengths) + cost = torch.sum(costs) + cost.backward() + + if 'cuda' in device: + torch.cuda.synchronize() + + if acts.grad is not None: + grad = acts.grad.data.cpu().numpy() + else: + grad = None + + return costs.data.cpu().numpy(), grad + + +def init_k2_rnnt(**kwargs): + from nemo.collections.asr.parts.k2.ml_loss import RnntLoss + + rnnt = RnntLoss(**kwargs) + return lambda acts, labels, lengths, label_lengths: rnnt( + torch.nn.functional.log_softmax(acts, -1), + labels.to(dtype=torch.long), + lengths.to(dtype=torch.long), + label_lengths.to(dtype=torch.long), + )[0] + + +def skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled): + if device == 'cpu': + supported, msg = k2_is_appropriate + elif device == 'cuda': + supported, msg = k2_cuda_is_enabled + else: + raise ValueError(f"Unknown device: {device}") + if not supported: + pytest.skip(f"k2 test is skipped. Reason : {msg}") + + +class TestRNNTLossK2: + @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + def test_case_small(self, device, k2_is_appropriate, k2_cuda_is_enabled): + skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled) + + acts = np.array( + [ + [ + [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1], [0.1, 0.1, 0.2, 0.8, 0.1]], + [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]], + ] + ] + ) + labels = [[1, 2]] + + fn_k2 = init_k2_rnnt(num_classes=acts.shape[-1], blank=0, reduction='sum') + k2_cost, k2_grads = wrap_and_call(fn_k2, acts, labels, device) + + expected_cost = 4.495666 + expected_grads = np.array( + [ + [ + [ + [-0.13116688, -0.3999269, 0.17703125, 0.17703125, 0.17703125], + [-0.18572757, 0.12247056, -0.18168412, 0.12247056, 0.12247056], + [-0.32091254, 0.06269141, 0.06928472, 0.12624499, 0.06269141], + ], + [ + [0.05456069, -0.21824276, 0.05456069, 0.05456069, 0.05456069], + [0.12073959, 0.12073959, -0.48295835, 0.12073959, 0.12073959], + [-0.6925882, 0.16871116, 0.18645467, 0.16871116, 0.16871116], + ], + ] + ] + ) + + assert np.allclose(k2_cost, expected_cost, rtol=1e-6), "small_test costs mismatch." + assert np.allclose(k2_grads, expected_grads, atol=1e-6), "small_test gradient mismatch." 
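A short usage sketch for the RnntEmissionAdapterBuilder added in topologies.py above. The token set, blank ID, and adapter lengths are arbitrary; note that the current __call__ implementation expects a tensor because it calls .tolist() on its argument.

import torch

from nemo.collections.asr.parts.k2.topologies import RnntEmissionAdapterBuilder

# five tokens with blank ID 0; eps_num defaults to tokens[-1] + 1, i.e. 5 here
builder = RnntEmissionAdapterBuilder(tokens=[0, 1, 2, 3, 4], blank_num=0)

# one adapter FSA per utterance, sized by its text length; repeated lengths hit the lru_cache
adapters = builder(torch.tensor([2, 3, 2]))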
+ + @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + def test_case_small_blank_last(self, device, k2_is_appropriate, k2_cuda_is_enabled): + skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled) + + acts = np.array( + [ + [ + [[0.0, 1.0, 3.0], [0.0, 2.0, 3.0], [1.0, 1.0, 3.0], [2.0, 3.0, 2.0]], + [[0.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [2.0, 2.0, 0.0]], + [[0.0, 2.0, 5.0], [0.0, 3.0, 5.0], [1.0, 2.0, 5.0], [2.0, 4.0, 4.0]], + [[0.0, 3.0, 4.0], [0.0, 4.0, 4.0], [1.0, 3.0, 4.0], [2.0, 5.0, 3.0]], + [[2.0, 2.0, 1.0], [2.0, 3.0, 1.0], [3.0, 2.0, 1.0], [4.0, 4.0, 0.0]], + ] + ] + ) + labels = [[0, 1, 0]] + + fn_k2 = init_k2_rnnt(num_classes=acts.shape[-1], blank=acts.shape[-1] - 1, reduction='sum') + k2_cost, k2_grads = wrap_and_call(fn_k2, acts, labels, device) + + expected_cost = 6.789285182952881 + expected_grads = np.array( + [ + [ + [ + [-0.03551076725125313, 0.11419519782066345, -0.07868456840515137], + [0.0027224558871239424, 0.00704305712133646, -0.009765520691871643], + [0.0013856772566214204, 0.0013924005907028913, -0.0027780719101428986], + [1.4249643527364242e-06, 3.873454716085689e-06, -5.298420546751004e-06], + ], + [ + [-0.1934257447719574, 0.19551163911819458, -0.0020859241485595703], + [0.07043898105621338, 0.05738453567028046, -0.12782356142997742], + [0.061031512916088104, 0.02286236733198166, -0.08389391005039215], + [0.0005252412520349026, 0.0005252412520349026, -0.0010504829697310925], + ], + [ + [-0.007841046899557114, 0.025142310187220573, -0.017301201820373535], + [0.0019501042552292347, 0.0005148053169250488, -0.0024650096893310547], + [0.0027856370434165, 0.008609085343778133, -0.01139475405216217], + [9.526080975774676e-05, 0.0007038871408440173, -0.000799147819634527], + ], + [ + [-0.01533521432429552, 0.1386115401983261, -0.12327653169631958], + [0.002850571647286415, -0.006693005561828613, 0.003842458128929138], + [0.009236274287104607, 0.08995233476161957, -0.0991886705160141], + [0.0001865450612967834, 0.0037468576338142157, -0.003933403175324202], + ], + [ + [-0.2888762652873993, 0.211185485124588, 0.07769080251455307], + [0.15952755510807037, -0.2182144820690155, 0.05868690833449364], + [-0.3332723379135132, 0.2436419129371643, 0.0896308496594429], + [0.4954628646373749, 0.4954628646373749, -0.9909257292747498], + ], + ] + ] + ) + + assert np.allclose(k2_cost, expected_cost, rtol=1e-6), "small_test_blank_last costs mismatch." + assert np.allclose(k2_grads, expected_grads, atol=1e-6), "small_test_blank_last gradient mismatch." + + @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + def test_case_small_random(self, device, k2_is_appropriate, k2_cuda_is_enabled): + skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled) + + rng = np.random.RandomState(0) + acts = rng.randn(1, 4, 3, 3) + labels = [[1, 2]] + + fn_k2 = init_k2_rnnt(num_classes=acts.shape[-1], blank=0, reduction='sum') + k2_cost, k2_grads = wrap_and_call(fn_k2, acts, labels, device) + + fn_np = RNNTLoss_Numpy() + np_cost, np_grads = wrap_and_call(fn_np, acts, labels, device) + + assert np.allclose(k2_cost, np_cost, rtol=1e-6), "small_random_test costs mismatch." + assert np.allclose(k2_grads, np_grads, atol=1e-6), "small_random_test gradient mismatch." 
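Finally, a brief sketch of the new padding-mask helpers from utils.py; the lengths below are arbitrary.

import torch

from nemo.collections.asr.parts.k2.utils import make_non_pad_mask, make_non_pad_mask_3d

lengths = torch.tensor([4, 2])
mask_2d = make_non_pad_mask(lengths, seq_len=5)  # (2, 5) boolean, True on non-padded frames

text_lengths = torch.tensor([3, 1])
mask_3d = make_non_pad_mask_3d(lengths, text_lengths, 5, 4)  # (2, 5, 4) boolean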
+ + @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + def test_case_big_tensor(self, device, k2_is_appropriate, k2_cuda_is_enabled): + skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled) + + # minibatch x T x U x alphabet_size + acts = [ + [ + [ + [0.06535690384862791, 0.7875301411923206, 0.08159176605666074], + [0.5297155426466327, 0.7506749639230854, 0.7541348379087998], + [0.6097641124736383, 0.8681404965673826, 0.6225318186056529], + ], + [ + [0.6685222872103057, 0.8580392805336061, 0.16453892311765583], + [0.989779515236694, 0.944298460961015, 0.6031678586829663], + [0.9467833543605416, 0.666202507295747, 0.28688179752461884], + ], + [ + [0.09418426230195986, 0.3666735970751962, 0.736168049462793], + [0.1666804425271342, 0.7141542198635192, 0.3993997272216727], + [0.5359823524146038, 0.29182076440286386, 0.6126422611507932], + ], + [ + [0.3242405528768486, 0.8007644367291621, 0.5241057606558068], + [0.779194617063042, 0.18331417220174862, 0.113745182072432], + [0.24022162381327106, 0.3394695622533106, 0.1341595066017014], + ], + ], + [ + [ + [0.5055615569388828, 0.051597282072282646, 0.6402903936686337], + [0.43073311517251, 0.8294731834714112, 0.1774668847323424], + [0.3207001991262245, 0.04288308912457006, 0.30280282975568984], + ], + [ + [0.6751777088333762, 0.569537369330242, 0.5584738347504452], + [0.08313242153985256, 0.06016544344162322, 0.10795752845152584], + [0.7486153608562472, 0.943918041459349, 0.4863558118797222], + ], + [ + [0.4181986264486809, 0.6524078485043804, 0.024242983423721887], + [0.13458171554507403, 0.3663418070512402, 0.2958297395361563], + [0.9236695822497084, 0.6899291482654177, 0.7418981733448822], + ], + [ + [0.25000547599982104, 0.6034295486281007, 0.9872887878887768], + [0.5926057265215715, 0.8846724004467684, 0.5434495396894328], + [0.6607698886038497, 0.3771277082495921, 0.3580209022231813], + ], + ], + ] + + expected_costs = [4.2806528590890736, 3.9384369822503591] + expected_grads = [ + [ + [ + [-1.86843902e-01, -6.25548810e-02, 2.49398798e-01], + [-2.03376666e-01, 2.02399328e-01, 9.77333169e-04], + [-1.41016081e-01, 7.91234672e-02, 6.18926100e-02], + ], + [ + [-1.15517676e-02, -8.12802389e-02, 9.28319991e-02], + [-1.54257029e-01, 2.29432687e-01, -7.51756504e-02], + [-2.46593088e-01, 1.46404594e-01, 1.00188486e-01], + ], + [ + [-1.29182907e-02, -6.15932420e-02, 7.45115355e-02], + [-5.59857301e-02, 2.19830811e-01, -1.63845062e-01], + [-4.97626871e-01, 2.09239945e-01, 2.88386941e-01], + ], + [ + [1.36048580e-02, -3.02196294e-02, 1.66147724e-02], + [1.13924511e-01, 6.27811998e-02, -1.76705718e-01], + [-6.67078257e-01, 3.67658824e-01, 2.99419403e-01], + ], + ], + [ + [ + [-3.56343776e-01, -5.53474613e-02, 4.11691219e-01], + [-9.69219357e-02, 2.94591039e-02, 6.74628317e-02], + [-6.35175705e-02, 2.76544970e-02, 3.58630717e-02], + ], + [ + [-1.54499024e-01, -7.39420280e-02, 2.28441030e-01], + [-1.66789949e-01, -8.78955179e-05, 1.66877866e-01], + [-1.72369644e-01, 1.05565332e-01, 6.68043196e-02], + ], + [ + [2.38748826e-02, -1.18255816e-01, 9.43809375e-02], + [-1.04707085e-01, -1.08934477e-01, 2.13641584e-01], + [-3.69844258e-01, 1.80118099e-01, 1.89726159e-01], + ], + [ + [2.57137045e-02, -7.94617534e-02, 5.37480488e-02], + [1.22328237e-01, -2.38788679e-01, 1.16460443e-01], + [-5.98686993e-01, 3.02203178e-01, 2.96483815e-01], + ], + ], + ] + + acts = np.array(acts) + expected_costs = np.array(expected_costs) + labels = [[1, 2], [1, 1]] + + fn_k2 = init_k2_rnnt(num_classes=acts.shape[-1], blank=0, reduction='none') + 
k2_costs, k2_grads = wrap_and_call(fn_k2, acts, labels, device) + + assert np.allclose(k2_costs, expected_costs), "big_test average costs mismatch." + assert np.allclose(k2_grads, expected_grads, rtol=1e-3), "big_test grads for average cost mismatch." + + @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + def test_case_large_random(self, device, k2_is_appropriate, k2_cuda_is_enabled): + skip_test_if_unsupported(device, k2_is_appropriate, k2_cuda_is_enabled) + + rng = np.random.RandomState(0) + acts = rng.randn(4, 8, 11, 5) + labels = [ + [1, 2, 4, 3, 2, 2, 1, 1, 1, 1], + [3, 2, 2, 3, 4, 1, 1, 1, 1, 1], + [4, 4, 1, 2, 1, 3, 4, 3, 1, 2], + [1, 1, 2, 1, 2, 3, 3, 1, 1, 1], + ] + + fn_k2 = init_k2_rnnt(num_classes=acts.shape[-1], blank=0, reduction='sum') + k2_costs, k2_grads = wrap_and_call(fn_k2, acts, labels, device) + + fn_np = RNNTLoss_Numpy() + np_costs, np_grads = wrap_and_call(fn_np, acts, labels, device) + + assert np.allclose(k2_costs, np_costs, atol=1e-5, rtol=1e-3), "large_random_test costs mismatch." + assert np.allclose(k2_grads, np_grads, atol=1e-5, rtol=1e-3), "large_random_test gradient mismatch." + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/conftest.py b/tests/conftest.py index 5987416e22e3..4bccc40a9a15 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os.path import shutil import tarfile @@ -20,6 +21,7 @@ from os.path import dirname, exists, getsize, join from pathlib import Path from shutil import rmtree +from typing import Tuple import pytest @@ -140,6 +142,33 @@ def extract_data_from_tar(test_dir, test_data_archive, url=None, local_data=Fals tar.close() +@pytest.fixture(scope="session") +def k2_is_appropriate() -> Tuple[bool, str]: + try: + from nemo.core.utils.k2_guard import k2 # noqa: E402 + + return True, "k2 is appropriate." + except Exception as e: + logging.exception(e, exc_info=True) + return False, "k2 is not available or does not meet the requirements." + + +@pytest.fixture(scope="session") +def k2_cuda_is_enabled(k2_is_appropriate) -> Tuple[bool, str]: + if not k2_is_appropriate[0]: + return k2_is_appropriate + + import torch # noqa: E402 + from nemo.core.utils.k2_guard import k2 # noqa: E402 + + if torch.cuda.is_available() and k2.with_cuda: + return True, "k2 supports CUDA." + elif torch.cuda.is_available(): + return False, "k2 does not support CUDA. Consider using a k2 build with CUDA support." + else: + return False, "k2 needs CUDA to be available in torch." + + def pytest_configure(config): """ Initial configuration of conftest.