diff --git a/model/sequencer.py b/model/sequencer.py index a70fcb6..a04e5fc 100644 --- a/model/sequencer.py +++ b/model/sequencer.py @@ -158,8 +158,12 @@ def pad_token_ids(self, token_ids: Tensor, pad_id: int) -> Tensor: """ pad_size = self.window_size - token_ids.size(1) - padding = full((1, pad_size), pad_id, dtype=long_) - return cat([token_ids, padding.to(device=self.device)], dim=1) + if pad_size > 0: + padding = full((1, pad_size), pad_id, dtype=long_, device=self.device) + token_ids = cat([token_ids, padding], dim=1) + elif pad_size < 0: + token_ids = token_ids[:, -self.window_size:] + return token_ids def generate_text(self, tokens: List[str]) -> str: