diff --git a/data_utils.py b/data_utils.py index e16c895..88a2115 100644 --- a/data_utils.py +++ b/data_utils.py @@ -66,10 +66,10 @@ def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size, tokens = tokenizer(line) if tokenizer else basic_tokenizer(line) for w in tokens: word = re.sub(_DIGIT_RE, b"0", w) if normalize_digits else w - if word in vocab: - vocab[word] += 1 - else: + if word not in vocab: vocab[word] = 1 + else: + vocab[word] += 1 vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True) print('>> Full Vocabulary Size :',len(vocab_list)) if len(vocab_list) > max_vocabulary_size: diff --git a/seq2seq_model.py b/seq2seq_model.py index 5b9f39b..bbb6407 100644 --- a/seq2seq_model.py +++ b/seq2seq_model.py @@ -21,6 +21,7 @@ import random +from tqdm import tqdm import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf