From b1d439377f52ed4400d02e1d818ec16547ec04de Mon Sep 17 00:00:00 2001 From: SYSTEMS-OPERATOR <155610697+SYSTEMS-OPERATOR@users.noreply.github.com> Date: Thu, 26 Jun 2025 23:54:52 -0400 Subject: [PATCH] Fix dataset window bug and tokenizer type --- model/dataset.py | 12 ++++++++---- model/tokenizer.py | 7 +++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/model/dataset.py b/model/dataset.py index 12a35a8..6f60a14 100644 --- a/model/dataset.py +++ b/model/dataset.py @@ -31,12 +31,16 @@ def __init__(self, datapath: str, window_size: int, vocab_size: int, def __iter__(self): for line_idx in range(len(self.data)): - line = self.data[line_idx].strip().split(' ') - start = randint(0, len(line)-self.window_size-1) + line_tokens = self.data[line_idx].strip().split(' ') + + if len(line_tokens) <= self.window_size: + continue + + start = randint(0, len(line_tokens) - self.window_size - 1) end = start + self.window_size + 1 - ids = LongTensor([int(x) for x in line[start:end]]) - ignore = (ids==self.unk_token).float() + ids = LongTensor([int(x) for x in line_tokens[start:end]]) + ignore = (ids == self.unk_token).float() yield ids[:-1], ids[1:], ignore[:-1] diff --git a/model/tokenizer.py b/model/tokenizer.py index 9f6cfa9..092b04e 100644 --- a/model/tokenizer.py +++ b/model/tokenizer.py @@ -220,8 +220,11 @@ def load(path: str) -> 'BytePairTokenizer': @staticmethod - def train_bpe(filepaths: List[str], mincount: int, merges: int) \ - -> 'BytePairtokenizer': + def train_bpe( + filepaths: List[str], + mincount: int, + merges: int, + ) -> 'BytePairTokenizer': """ Create trained byte pair tokenizer Args: