From 2a7ce926575a793351525c8161538a8e611b73fc Mon Sep 17 00:00:00 2001
From: Guo Liyong
Date: Tue, 7 Jun 2022 19:39:58 +0800
Subject: [PATCH] streaming model distillation with codebook loss

---
 .../pruned_transducer_stateless2/conformer.py | 43 +++++++++++---
 .../ASR/pruned_transducer_stateless2/model.py | 58 ++++++++++++++++++-
 .../ASR/pruned_transducer_stateless4/train.py | 55 ++++++++++++++++--
 3 files changed, 140 insertions(+), 16 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index e28b5034dc..371d4028c1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -86,6 +86,7 @@ def __init__(
         short_chunk_size: int = 25,
         num_left_chunks: int = -1,
         causal: bool = False,
+        middle_output_layer: Optional[int] = None,  # 0-based layer index
     ) -> None:
         super(Conformer, self).__init__()
@@ -121,12 +122,27 @@ def __init__(
             cnn_module_kernel,
             causal,
         )
-        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
+
+        output_layers = []
+        if middle_output_layer is not None:
+            assert (
+                middle_output_layer >= 0
+                and middle_output_layer < num_encoder_layers
+            )
+            output_layers.append(middle_output_layer)
+
+        # The last layer is always needed.
+        output_layers.append(num_encoder_layers - 1)
+
+        self.encoder = ConformerEncoder(
+            encoder_layer, num_encoder_layers, output_layers=output_layers
+        )
+
         self._init_state: List[torch.Tensor] = [torch.empty(0)]

     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
         """
         Args:
           x:
@@ -176,7 +192,7 @@ def forward(
                 num_left_chunks=self.num_left_chunks,
                 device=x.device,
             )
-            x = self.encoder(
+            layer_results = self.encoder(
                 x,
                 pos_emb,
                 mask=mask,
@@ -184,7 +200,7 @@ def forward(
                 warmup=warmup,
             )  # (T, N, C)
         else:
-            x = self.encoder(
+            layer_results = self.encoder(
                 x,
                 pos_emb,
                 mask=None,
@@ -192,8 +208,7 @@ def forward(
             warmup=warmup,
         )  # (T, N, C)

-        x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
-        return x, lengths
+        return layer_results, lengths

     @torch.jit.export
     def get_init_state(
@@ -647,12 +662,18 @@ class ConformerEncoder(nn.Module):
         >>> out = conformer_encoder(src, pos_emb)
     """

-    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+    def __init__(
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        output_layers: List[int],
+    ) -> None:
         super().__init__()
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
         self.num_layers = num_layers
+        self.output_layers = output_layers

     def forward(
         self,
@@ -661,7 +682,7 @@ def forward(
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-    ) -> Tensor:
+    ) -> List[Tensor]:
         r"""Pass the input through the encoder layers in turn.
        Args:
@@ -682,6 +703,7 @@ def forward(
         """
         output = src

+        layer_results = []
         for layer_index, mod in enumerate(self.layers):
             output = mod(
                 output,
@@ -690,8 +712,11 @@ def forward(
                 src_key_padding_mask=src_key_padding_mask,
                 warmup=warmup,
             )
+            if layer_index in self.output_layers:
+                # (T, N, C) --> (N, T, C)
+                layer_results.append(output.permute(1, 0, 2))

-        return output
+        return layer_results

     @torch.jit.export
     def chunk_forward(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
index 2434fd41d6..43109b59ff 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
@@ -23,6 +23,7 @@
 from icefall.utils import add_sos

+from multi_quantization.prediction import JointCodebookLoss

 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
@@ -38,6 +39,7 @@ def __init__(
         decoder_dim: int,
         joiner_dim: int,
         vocab_size: int,
+        num_codebooks: int = 0,
     ):
         """
         Args:
@@ -55,6 +57,8 @@ def __init__(
             (N, U, decoder_dim). Its output shape is (N, T, U, vocab_size).
             Note that its output contains unnormalized probs, i.e., not
             processed by log-softmax.
+          num_codebooks:
+            Used by the distillation loss.
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -68,6 +72,10 @@ def __init__(
             encoder_dim, vocab_size, initial_speed=0.5
         )
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
+        if num_codebooks > 0:
+            self.codebook_loss_net = JointCodebookLoss(
+                predictor_channels=encoder_dim, num_codebooks=num_codebooks
+            )

     def forward(
         self,
@@ -78,6 +86,7 @@ def forward(
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
         warmup: float = 1.0,
+        codebook_indexes: torch.Tensor = None,
     ) -> torch.Tensor:
         """
         Args:
@@ -101,6 +110,8 @@ def forward(
           warmup:
             A value warmup >= 0 that determines which modules are active, values
             warmup > 1 "are fully warmed up" and all modules will be active.
+          codebook_indexes:
+            Codebook indexes extracted from a teacher model.
         Returns:
           Return the transducer loss.
@@ -116,7 +127,22 @@ def forward(
         assert x.size(0) == x_lens.size(0) == y.dim0

-        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        layer_results, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        encoder_out = layer_results[-1]
+        middle_layer_output = layer_results[0]
+        if self.training and codebook_indexes is not None:
+            assert hasattr(self, "codebook_loss_net")
+            if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
+                codebook_indexes = self.concat_successive_codebook_indexes(
+                    middle_layer_output, codebook_indexes
+                )
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+        else:
+            # when codebook indexes are not available.
+            codebook_loss = None
+
         assert torch.all(x_lens > 0)

         # Now for the decoder, i.e., the prediction network
@@ -191,4 +217,32 @@ def forward(
             reduction="sum",
         )

-        return (simple_loss, pruned_loss)
+        return (simple_loss, pruned_loss, codebook_loss)
+
+    @staticmethod
+    def concat_successive_codebook_indexes(
+        middle_layer_output, codebook_indexes
+    ):
+        # The output rate of the hubert teacher is 50 frames per second,
+        # while that of the current encoder is 25.
+        # The following code handles two issues:
+        # 1.
+        #    Roughly speaking, to generate one extra output frame,
+        #    hubert needs two extra input frames,
+        #    while the current encoder needs four extra frames.
+        #    So if only three extra frames are provided, hubert generates
+        #    one extra frame while the current encoder generates nothing.
+ # 2. + # codebook loss is a frame-wise loss, to enalbe 25 frames studnet output + # learns from 50 frames teacher output, two successive frames of teacher model + # output is concatenated together. + t_expected = middle_layer_output.shape[1] + N, T, C = codebook_indexes.shape + + # Handling issue 1. + if T >= t_expected * 2: + codebook_indexes = codebook_indexes[:, : t_expected * 2, :] + # Handling issue 2. + codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2) + assert middle_layer_output.shape[1] == codebook_indexes.shape[1] + return codebook_indexes diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 5cf92c77e3..e4f013607a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -26,7 +26,7 @@ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ - --exp-dir pruned_transducer_stateless2/exp \ + --exp-dir pruned_transducer_stateless4/exp \ --full-libri 1 \ --max-duration 300 @@ -37,7 +37,7 @@ --num-epochs 30 \ --start-epoch 1 \ --use-fp16 1 \ - --exp-dir pruned_transducer_stateless2/exp \ + --exp-dir pruned_transducer_stateless4/exp \ --full-libri 1 \ --max-duration 550 @@ -74,9 +74,10 @@ from conformer import Conformer from decoder import Decoder from joiner import Joiner -from lhotse.cut import Cut +from lhotse.cut import Cut, MonoCut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed +from lhotse.dataset.collation import collate_custom_field from model import Transducer from optim import Eden, Eve from torch import Tensor @@ -235,6 +236,13 @@ def get_parser(): "with this parameter before adding to the final loss.", ) + parser.add_argument( + "--codebook-loss-scale", + type=float, + default=0.1, + help="The scale of codebook loss.", + ) + parser.add_argument( "--seed", type=int, @@ -398,6 +406,13 @@ def get_params() -> AttributeDict: # parameters for Noam "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), + # parameters for distillation with codebook indexes. + "enable_distiallation": True, + "distillation_layer": 5, # 0-based index + # Since output rate of hubert is 50, while that of encoder is 8, + # two successive codebook_index are concatenated together. + # Detailed in function Transducer::concat_sucessive_codebook_indexes. + "num_codebooks": 16, # used to construct distillation loss } ) @@ -417,6 +432,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: short_chunk_size=params.short_chunk_size, num_left_chunks=params.num_left_chunks, causal=params.causal_convolution, + middle_output_layer=params.distillation_layer + if params.enable_distiallation + else None, ) return encoder @@ -454,6 +472,9 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, + num_codebooks=params.num_codebooks + if params.enable_distiallation + else 0, ) return model @@ -577,6 +598,18 @@ def save_checkpoint( copyfile(src=filename, dst=best_valid_filename) +def extract_codebook_indexes(batch): + cuts = batch["supervisions"]["cut"] + # -100 is identical to ignore_value in CE loss computation. 
+    cuts_pre_mixed = [
+        c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+    ]
+    codebook_indexes, codebook_indexes_lens = collate_custom_field(
+        cuts_pre_mixed, "codebook_indexes", pad_value=-100
+    )
+    return codebook_indexes, codebook_indexes_lens
+
+
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
@@ -620,8 +653,15 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)

+    info = MetricsTracker()
+    if is_training and params.enable_distillation:
+        codebook_indexes, _ = extract_codebook_indexes(batch)
+        codebook_indexes = codebook_indexes.to(device)
+    else:
+        codebook_indexes = None
+
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -629,6 +669,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            codebook_indexes=codebook_indexes,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
@@ -643,10 +684,12 @@ def compute_loss(
             params.simple_loss_scale * simple_loss
             + pruned_loss_scale * pruned_loss
         )
+        if is_training and params.enable_distillation:
+            assert codebook_loss is not None
+            loss += params.codebook_loss_scale * codebook_loss

     assert loss.requires_grad == is_training

-    info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         info["frames"] = (
@@ -657,6 +700,8 @@ def compute_loss(
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    if is_training and params.enable_distillation:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()

     return loss, info
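
The conformer changes above reduce to a single pattern: run the layer stack once and collect outputs at a chosen subset of layers, always including the last one. Below is a minimal standalone sketch of that pattern; it uses a generic nn.TransformerEncoderLayer as a stand-in for the icefall ConformerEncoderLayer, so the class name and dimensions here are illustrative only:

from typing import List

import torch
import torch.nn as nn


class TappedEncoder(nn.Module):
    """A layer stack that returns the outputs of selected layers."""

    def __init__(self, num_layers: int, dim: int, output_layers: List[int]):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.TransformerEncoderLayer(dim, nhead=4) for _ in range(num_layers)]
        )
        # 0-based indices of layers to tap; the last layer is always
        # included, mirroring the patch above.
        self.output_layers = sorted(set(output_layers + [num_layers - 1]))

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        layer_results = []
        for layer_index, layer in enumerate(self.layers):
            x = layer(x)
            if layer_index in self.output_layers:
                layer_results.append(x)
        return layer_results  # layer_results[-1] is the top-layer output


encoder = TappedEncoder(num_layers=6, dim=64, output_layers=[2])
outputs = encoder(torch.randn(10, 8, 64))  # (T, N, C)
assert len(outputs) == 2  # middle layer 2 and final layer 5

The distillation branch then reads layer_results[0] while the transducer head keeps using layer_results[-1], so the extra tap adds no additional forward pass.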
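concat_successive_codebook_indexes can be exercised in isolation. A minimal sketch, assuming a 50 fps teacher and a 25 fps student encoder; the batch size, frame counts, and codebook counts below are made up for illustration:

import torch


def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes):
    """Pair up 50 fps teacher frames to match a 25 fps student."""
    t_expected = middle_layer_output.shape[1]
    N, T, C = codebook_indexes.shape
    # Issue 1: drop any trailing teacher frames the student never produced.
    if T >= t_expected * 2:
        codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
    # Issue 2: fold every two successive teacher frames into one student
    # frame, doubling the per-frame codebook dimension.
    codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
    assert codebook_indexes.shape[1] == middle_layer_output.shape[1]
    return codebook_indexes


# 2 utterances, 4 student frames, 9 teacher frames, 16 codebooks each:
student = torch.zeros(2, 4, 512)             # (N, T_student, encoder_dim)
teacher = torch.randint(0, 256, (2, 9, 16))  # (N, T_teacher, num_codebooks)
out = concat_successive_codebook_indexes(student, teacher)
print(out.shape)  # torch.Size([2, 4, 32])

Note how the teacher's odd ninth frame is discarded by the issue-1 slice before the issue-2 reshape pairs the remaining frames.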
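Finally, the pad_value=-100 in extract_codebook_indexes lines up with the default ignore_index of PyTorch's cross-entropy, which is what the comment there alludes to: frames padded during collation contribute nothing to a CE-style distillation loss (assuming JointCodebookLoss uses a cross-entropy internally, as that comment suggests). A quick self-contained demonstration of the convention, independent of JointCodebookLoss itself:

import torch
import torch.nn.functional as F

logits = torch.randn(5, 256)  # 5 frames, 256 codebook classes
targets = torch.tensor([3, 7, 250, -100, -100])  # last two frames are padding

# Targets equal to ignore_index (-100 by default) are skipped entirely,
# so padded frames affect neither the loss value nor its gradients.
loss = F.cross_entropy(logits, targets)
loss_unpadded = F.cross_entropy(logits[:3], targets[:3])
assert torch.isclose(loss, loss_unpadded)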