diff --git a/model/dataset.py b/model/dataset.py index 6f60a14..73e4e0c 100644 --- a/model/dataset.py +++ b/model/dataset.py @@ -31,7 +31,7 @@ def __init__(self, datapath: str, window_size: int, vocab_size: int, def __iter__(self): for line_idx in range(len(self.data)): - line_tokens = self.data[line_idx].strip().split(' ') + line_tokens = self.data[line_idx].split() if len(line_tokens) <= self.window_size: continue