diff --git a/model.py b/model.py
index 4c7d7d267..4c96f4cb2 100644
--- a/model.py
+++ b/model.py
@@ -200,6 +200,41 @@ def inference(self, x):
         return outputs
 
 
+class DecoderState:
+    """ Holds the decoder's recurrent state: attention RNN states, decoder
+    RNN states, attention weights, cumulative attention weights and
+    attention context, along with the memory and processed memory
+    PARAMS
+    ------
+    memory: Encoder outputs
+    mask: Mask for padded data if training, expects None for inference
+    """
+    def __init__(self, memory, mask,
+                 attention_rnn_dim, decoder_rnn_dim,
+                 encoder_embedding_dim, attention_layer):
+        B = memory.size(0)
+        MAX_TIME = memory.size(1)
+
+        self.attention_hidden = Variable(memory.data.new(
+            B, attention_rnn_dim).zero_())
+        self.attention_cell = Variable(memory.data.new(
+            B, attention_rnn_dim).zero_())
+
+        self.decoder_hidden = Variable(memory.data.new(
+            B, decoder_rnn_dim).zero_())
+        self.decoder_cell = Variable(memory.data.new(
+            B, decoder_rnn_dim).zero_())
+
+        self.attention_weights = Variable(memory.data.new(
+            B, MAX_TIME).zero_())
+        self.attention_weights_cum = Variable(memory.data.new(
+            B, MAX_TIME).zero_())
+        self.attention_context = Variable(memory.data.new(
+            B, encoder_embedding_dim).zero_())
+
+        self.memory = memory
+        self.processed_memory = attention_layer.memory_layer(memory)
+        self.mask = mask
+
+
 class Decoder(nn.Module):
     def __init__(self, hparams):
@@ -263,30 +298,15 @@ def initialize_decoder_states(self, memory, mask):
         ------
         memory: Encoder outputs
         mask: Mask for padded data if training, expects None for inference
-        """
-        B = memory.size(0)
-        MAX_TIME = memory.size(1)
-
-        self.attention_hidden = Variable(memory.data.new(
-            B, self.attention_rnn_dim).zero_())
-        self.attention_cell = Variable(memory.data.new(
-            B, self.attention_rnn_dim).zero_())
-
-        self.decoder_hidden = Variable(memory.data.new(
-            B, self.decoder_rnn_dim).zero_())
-        self.decoder_cell = Variable(memory.data.new(
-            B, self.decoder_rnn_dim).zero_())
-
-        self.attention_weights = Variable(memory.data.new(
-            B, MAX_TIME).zero_())
-        self.attention_weights_cum = Variable(memory.data.new(
-            B, MAX_TIME).zero_())
-        self.attention_context = Variable(memory.data.new(
-            B, self.encoder_embedding_dim).zero_())
-
-        self.memory = memory
-        self.processed_memory = self.attention_layer.memory_layer(memory)
-        self.mask = mask
+        RETURNS
+        -------
+        decoder_state: a freshly initialized DecoderState
+        """
+        decoder_state = DecoderState(memory, mask,
+                                     self.attention_rnn_dim, self.decoder_rnn_dim,
+                                     self.encoder_embedding_dim, self.attention_layer)
+        return decoder_state
 
     def parse_decoder_inputs(self, decoder_inputs):
         """ Prepares decoder inputs, i.e. mel outputs
@@ -337,46 +357,51 @@ def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
 
         return mel_outputs, gate_outputs, alignments
 
-    def decode(self, decoder_input):
+    def decode(self, decoder_input, decoder_state):
         """ Decoder step using stored states, attention and memory
         PARAMS
         ------
         decoder_input: previous mel output
-
+        decoder_state: DecoderState with the attention and decoder RNN states
+
         RETURNS
         -------
         mel_output:
         gate_output: gate output energies
         attention_weights:
         """
-        cell_input = torch.cat((decoder_input, self.attention_context), -1)
-        self.attention_hidden, self.attention_cell = self.attention_rnn(
-            cell_input, (self.attention_hidden, self.attention_cell))
-        self.attention_hidden = F.dropout(
-            self.attention_hidden, self.p_attention_dropout, self.training)
+        cell_input = torch.cat((decoder_input, decoder_state.attention_context), -1)
+        decoder_state.attention_hidden, decoder_state.attention_cell = self.attention_rnn(
+            cell_input, (decoder_state.attention_hidden, decoder_state.attention_cell))
+        decoder_state.attention_hidden = F.dropout(
+            decoder_state.attention_hidden, self.p_attention_dropout, self.training)
+        decoder_state.attention_cell = F.dropout(
+            decoder_state.attention_cell, self.p_attention_dropout, self.training)
 
         attention_weights_cat = torch.cat(
-            (self.attention_weights.unsqueeze(1),
-             self.attention_weights_cum.unsqueeze(1)), dim=1)
-        self.attention_context, self.attention_weights = self.attention_layer(
-            self.attention_hidden, self.memory, self.processed_memory,
-            attention_weights_cat, self.mask)
+            (decoder_state.attention_weights.unsqueeze(1),
+             decoder_state.attention_weights_cum.unsqueeze(1)), dim=1)
+        decoder_state.attention_context, decoder_state.attention_weights = self.attention_layer(
+            decoder_state.attention_hidden, decoder_state.memory, decoder_state.processed_memory,
+            attention_weights_cat, decoder_state.mask)
 
-        self.attention_weights_cum += self.attention_weights
+        decoder_state.attention_weights_cum += decoder_state.attention_weights
         decoder_input = torch.cat(
-            (self.attention_hidden, self.attention_context), -1)
-        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
-            decoder_input, (self.decoder_hidden, self.decoder_cell))
-        self.decoder_hidden = F.dropout(
-            self.decoder_hidden, self.p_decoder_dropout, self.training)
+            (decoder_state.attention_hidden, decoder_state.attention_context), -1)
+        decoder_state.decoder_hidden, decoder_state.decoder_cell = self.decoder_rnn(
+            decoder_input, (decoder_state.decoder_hidden, decoder_state.decoder_cell))
+        decoder_state.decoder_hidden = F.dropout(
+            decoder_state.decoder_hidden, self.p_decoder_dropout, self.training)
+        decoder_state.decoder_cell = F.dropout(
+            decoder_state.decoder_cell, self.p_decoder_dropout, self.training)
 
         decoder_hidden_attention_context = torch.cat(
-            (self.decoder_hidden, self.attention_context), dim=1)
+            (decoder_state.decoder_hidden, decoder_state.attention_context), dim=1)
 
         decoder_output = self.linear_projection(
             decoder_hidden_attention_context)
 
         gate_prediction = self.gate_layer(decoder_hidden_attention_context)
-        return decoder_output, gate_prediction, self.attention_weights
+        return decoder_output, gate_prediction, decoder_state.attention_weights
 
     def forward(self, memory, decoder_inputs, memory_lengths):
         """ Decoder forward pass for training
@@ -398,14 +423,14 @@ def forward(self, memory, decoder_inputs, memory_lengths):
         decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
         decoder_inputs = self.prenet(decoder_inputs)
 
-        self.initialize_decoder_states(
+        decoder_state = self.initialize_decoder_states(
             memory, mask=~get_mask_from_lengths(memory_lengths))
 
         mel_outputs, gate_outputs, alignments = [], [], []
         while len(mel_outputs) < decoder_inputs.size(0) - 1:
             decoder_input = decoder_inputs[len(mel_outputs)]
             mel_output, gate_output, attention_weights = self.decode(
-                decoder_input)
+                decoder_input, decoder_state)
             mel_outputs += [mel_output.squeeze(1)]
             gate_outputs += [gate_output.squeeze()]
             alignments += [attention_weights]
@@ -429,12 +454,12 @@ def inference(self, memory):
         """
         decoder_input = self.get_go_frame(memory)
 
-        self.initialize_decoder_states(memory, mask=None)
+        decoder_state = self.initialize_decoder_states(memory, mask=None)
 
         mel_outputs, gate_outputs, alignments = [], [], []
         while True:
             decoder_input = self.prenet(decoder_input)
-            mel_output, gate_output, alignment = self.decode(decoder_input)
+            mel_output, gate_output, alignment = self.decode(decoder_input, decoder_state)
 
             mel_outputs += [mel_output.squeeze(1)]
             gate_outputs += [gate_output]
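A minimal usage sketch, not part of the diff: since decode() no longer reads or mutates attributes on the Decoder module, each call site builds its own DecoderState and threads it through every step. The loop below mirrors the refactored Decoder.inference; greedy_decode is a hypothetical helper, and the max_steps and gate_threshold parameters stand in for the hparams values the surrounding repo normally supplies. Batch size 1 is assumed for the gate check.

    import torch

    def greedy_decode(decoder, memory, max_steps=1000, gate_threshold=0.5):
        # memory: encoder outputs, shape (1, MAX_TIME, encoder_embedding_dim)
        decoder_input = decoder.get_go_frame(memory)
        # Fresh, self-contained state; nothing is stashed on `decoder` itself.
        decoder_state = decoder.initialize_decoder_states(memory, mask=None)

        mel_outputs = []
        for _ in range(max_steps):
            decoder_input = decoder.prenet(decoder_input)
            mel_output, gate_output, _ = decoder.decode(decoder_input, decoder_state)
            mel_outputs.append(mel_output.squeeze(1))
            # Stop once the gate predicts end of utterance.
            if torch.sigmoid(gate_output.data) > gate_threshold:
                break
            decoder_input = mel_output
        return mel_outputs

Because the per-utterance state now lives in an explicit object rather than in module attributes, two sequences can in principle be decoded concurrently from the same Decoder instance, each holding its own DecoderState.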