preprocess.py (2 changes: 1 addition & 1 deletion)

@@ -119,7 +119,7 @@ def unk_consumer(word, idx):
         tokens_list.append(tokens.numpy())

    with open(output_file, 'wb') as outf:
-        pickle.dump(tokens_list, outf, protocol=pickle.HIGHEST_PROTOCOL)
+        pickle.dump(tokens_list, outf, protocol=pickle.DEFAULT_PROTOCOL)
     if not args.quiet:
         logging.info('Built a binary dataset for {}: {} sentences, {} tokens, {:.3f}% replaced by unknown token'.format(
             input_file, nsent, ntok, 100.0 * sum(unk_counter.values()) / ntok, dictionary.unk_word))
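Switching from pickle.HIGHEST_PROTOCOL to pickle.DEFAULT_PROTOCOL trades the newest pickle format for cross-version compatibility: on Python 3.8+, HIGHEST_PROTOCOL is 5, which older interpreters cannot read, while DEFAULT_PROTOCOL is 4, readable back to Python 3.4. A minimal sketch of the difference (the output path is a placeholder, not one from this repository):

import pickle

# On Python 3.8+: HIGHEST_PROTOCOL is 5 (readable only by Python >= 3.8),
# DEFAULT_PROTOCOL is 4 (readable by Python >= 3.4).
print(pickle.HIGHEST_PROTOCOL, pickle.DEFAULT_PROTOCOL)

tokens_list = [[1, 2, 3], [4, 5]]
with open('dataset.bin', 'wb') as outf:
    # Favour compatibility: any reasonably recent Python can load this file.
    pickle.dump(tokens_list, outf, protocol=pickle.DEFAULT_PROTOCOL)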
seq2seq/models/lstm.py (3 changes: 2 additions & 1 deletion)

@@ -117,14 +117,15 @@ def forward(self, src_tokens, src_lengths):
         batch_size, src_time_steps = src_tokens.size()
         if self.is_cuda:
             src_tokens = utils.move_to_cuda(src_tokens)
+            src_lengths = utils.move_to_cuda(src_lengths)
         src_embeddings = self.embedding(src_tokens)
         _src_embeddings = F.dropout(src_embeddings, p=self.dropout_in, training=self.training)

         # Transpose batch: [batch_size, src_time_steps, num_features] -> [src_time_steps, batch_size, num_features]
         src_embeddings = _src_embeddings.transpose(0, 1)

         # Pack embedded tokens into a PackedSequence
-        packed_source_embeddings = nn.utils.rnn.pack_padded_sequence(src_embeddings, src_lengths)
+        packed_source_embeddings = nn.utils.rnn.pack_padded_sequence(src_embeddings, src_lengths.cpu())

         # Pass source input through the recurrent layer(s)
         packed_outputs, (final_hidden_states, final_cell_states) = self.lstm(packed_source_embeddings)
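In recent PyTorch releases (1.7 and later), nn.utils.rnn.pack_padded_sequence requires its lengths argument to be a CPU tensor even when the embeddings are on the GPU; passing CUDA lengths raises a RuntimeError. The .cpu() call above satisfies that requirement, while the added move_to_cuda keeps src_lengths available on the device for any later use. A self-contained sketch of the call (shapes and sizes are illustrative, not taken from this model):

import torch
import torch.nn as nn

# Toy batch: 2 sequences, padded to 3 time steps, embedding dim 4,
# laid out [src_time_steps, batch_size, num_features] (batch_first=False).
src_embeddings = torch.randn(3, 2, 4)
src_lengths = torch.tensor([3, 2])  # true lengths, longest first (enforce_sorted=True)

# Lengths must live on the CPU even if src_embeddings is on the GPU.
packed = nn.utils.rnn.pack_padded_sequence(src_embeddings, src_lengths.cpu())

lstm = nn.LSTM(input_size=4, hidden_size=8)
packed_outputs, (final_hidden_states, final_cell_states) = lstm(packed)
outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs)
print(outputs.shape)  # torch.Size([3, 2, 8])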
train.py (2 changes: 1 addition & 1 deletion)

@@ -16,7 +16,7 @@
 def get_args():
     """ Defines training-specific hyper-parameters. """
     parser = argparse.ArgumentParser('Sequence to Sequence Model')
-    parser.add_argument('--cuda', default=False, help='Use a GPU')
+    parser.add_argument('--cuda', action='store_true', help='Use a GPU')

     # Add data arguments
     parser.add_argument('--data', default='indomain/preprocessed_data/', help='path to data directory')
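With default=False and no action, argparse treats --cuda as an option that consumes a value, and any string it receives is truthy: --cuda False yields args.cuda == 'False', which still takes the GPU branch under a later if args.cuda: check. action='store_true' makes it a proper boolean flag: present means True, absent means False. A quick demonstration of both behaviours; the identical fix appears in translate.py below:

import argparse

parser = argparse.ArgumentParser('Sequence to Sequence Model')
parser.add_argument('--cuda', action='store_true', help='Use a GPU')

print(parser.parse_args(['--cuda']).cuda)  # True
print(parser.parse_args([]).cuda)          # False

# The old definition: '--cuda False' stores the *string* 'False',
# which is truthy, so the GPU branch would run either way.
buggy = argparse.ArgumentParser('Sequence to Sequence Model')
buggy.add_argument('--cuda', default=False, help='Use a GPU')
print(bool(buggy.parse_args(['--cuda', 'False']).cuda))  # True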
translate.py (2 changes: 1 addition & 1 deletion)

@@ -15,7 +15,7 @@
 def get_args():
     """ Defines generation-specific hyper-parameters. """
     parser = argparse.ArgumentParser('Sequence to Sequence Model')
-    parser.add_argument('--cuda', default=False, help='Use a GPU')
+    parser.add_argument('--cuda', action='store_true', help='Use a GPU')
     parser.add_argument('--seed', default=42, type=int, help='pseudo random number generator seed')

     # Add data arguments