diff --git a/model/dataset.py b/model/dataset.py
index 12a35a8..6f60a14 100644
--- a/model/dataset.py
+++ b/model/dataset.py
@@ -31,12 +31,16 @@ def __init__(self, datapath: str, window_size: int, vocab_size: int,
 
     def __iter__(self):
 
         for line_idx in range(len(self.data)):
-            line = self.data[line_idx].strip().split(' ')
-            start = randint(0, len(line)-self.window_size-1)
+            line_tokens = self.data[line_idx].strip().split(' ')
+
+            if len(line_tokens) <= self.window_size:
+                continue
+
+            start = randint(0, len(line_tokens) - self.window_size - 1)
             end = start + self.window_size + 1
-            ids = LongTensor([int(x) for x in line[start:end]])
-            ignore = (ids==self.unk_token).float()
+            ids = LongTensor([int(x) for x in line_tokens[start:end]])
+            ignore = (ids == self.unk_token).float()
             yield ids[:-1], ids[1:], ignore[:-1]
 
 
diff --git a/model/tokenizer.py b/model/tokenizer.py
index 98a0619..449e8fc 100644
--- a/model/tokenizer.py
+++ b/model/tokenizer.py
@@ -220,8 +220,12 @@ def load(path: str) -> 'BytePairTokenizer':
 
 
     @staticmethod
-    def train_bpe(filepaths: List[str], mincount: int, merges: int) \
-            -> 'BytePairTokenizer':
+    def train_bpe(
+        filepaths: List[str],
+        mincount: int,
+        merges: int,
+    ) -> 'BytePairTokenizer':
+
        """ Create trained byte pair tokenizer
 
            Args:
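Not part of the patch: a minimal sketch of the failure mode the new length guard in __iter__ avoids, assuming randint is random.randint (the import is outside the hunk); window_size and the token line below are made-up values.

    from random import randint

    window_size = 8
    line_tokens = '12 7 93 4'.split(' ')  # 4 token ids, shorter than the window

    # Without the guard, the upper bound passed to randint goes negative
    # (4 - 8 - 1 = -5), so randint raises ValueError (empty range).
    try:
        start = randint(0, len(line_tokens) - window_size - 1)
    except ValueError as err:
        print('crash without guard:', err)

    # With the guard from the patch, such lines are skipped instead;
    # any surviving line has at least window_size + 1 tokens, so the
    # slice always yields one full input/target window.
    if len(line_tokens) <= window_size:
        print('line skipped: too short for one window')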