diff --git a/preprocess.py b/preprocess.py index 54e7efa0..0aebd64a 100644 --- a/preprocess.py +++ b/preprocess.py @@ -119,7 +119,7 @@ def unk_consumer(word, idx): tokens_list.append(tokens.numpy()) with open(output_file, 'wb') as outf: - pickle.dump(tokens_list, outf, protocol=pickle.HIGHEST_PROTOCOL) + pickle.dump(tokens_list, outf, protocol=pickle.DEFAULT_PROTOCOL) if not args.quiet: logging.info('Built a binary dataset for {}: {} sentences, {} tokens, {:.3f}% replaced by unknown token'.format( input_file, nsent, ntok, 100.0 * sum(unk_counter.values()) / ntok, dictionary.unk_word)) diff --git a/seq2seq/models/lstm.py b/seq2seq/models/lstm.py index 0ce28a37..95077a36 100644 --- a/seq2seq/models/lstm.py +++ b/seq2seq/models/lstm.py @@ -117,6 +117,7 @@ def forward(self, src_tokens, src_lengths): batch_size, src_time_steps = src_tokens.size() if self.is_cuda: src_tokens = utils.move_to_cuda(src_tokens) + src_lengths = utils.move_to_cuda(src_lengths) src_embeddings = self.embedding(src_tokens) _src_embeddings = F.dropout(src_embeddings, p=self.dropout_in, training=self.training) @@ -124,7 +125,7 @@ def forward(self, src_tokens, src_lengths): src_embeddings = _src_embeddings.transpose(0, 1) # Pack embedded tokens into a PackedSequence - packed_source_embeddings = nn.utils.rnn.pack_padded_sequence(src_embeddings, src_lengths) + packed_source_embeddings = nn.utils.rnn.pack_padded_sequence(src_embeddings, src_lengths.cpu()) # Pass source input through the recurrent layer(s) packed_outputs, (final_hidden_states, final_cell_states) = self.lstm(packed_source_embeddings) diff --git a/train.py b/train.py index ff8c1e56..fe4a6afe 100644 --- a/train.py +++ b/train.py @@ -16,7 +16,7 @@ def get_args(): """ Defines training-specific hyper-parameters. """ parser = argparse.ArgumentParser('Sequence to Sequence Model') - parser.add_argument('--cuda', default=False, help='Use a GPU') + parser.add_argument('--cuda', action='store_true', help='Use a GPU') # Add data arguments parser.add_argument('--data', default='indomain/preprocessed_data/', help='path to data directory') diff --git a/translate.py b/translate.py index cf70da83..ce632d87 100644 --- a/translate.py +++ b/translate.py @@ -15,7 +15,7 @@ def get_args(): """ Defines generation-specific hyper-parameters. """ parser = argparse.ArgumentParser('Sequence to Sequence Model') - parser.add_argument('--cuda', default=False, help='Use a GPU') + parser.add_argument('--cuda', action='store_true', help='Use a GPU') parser.add_argument('--seed', default=42, type=int, help='pseudo random number generator seed') # Add data arguments