Description
How to Reproduce
Keep the model generating new tokens non-stop, until the generated sequence length exceeds the default seq_len.
For example, change the prompt to
prompt = 'a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a'
and it will crash after generating 1022 tokens:
local_cache = val_cache.select(0, l).narrow(0, pos, 3)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: start (1022) + length (3) exceeds dimension size (1024).
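The failure is plain bounds arithmetic: narrow(0, pos, 3) requests 3 rows starting at pos, so it needs pos + 3 <= seq_len along that dimension. A minimal sketch of that check (plain Python standing in for torch.Tensor.narrow; seq_len = 1024 as in the traceback):

```python
def narrow_bounds_ok(dim_size: int, start: int, length: int) -> bool:
    # torch.narrow requires start + length <= dim_size along the narrowed dim
    return start + length <= dim_size

seq_len = 1024
# Generation is fine while the 3-row window stays inside the cache...
assert narrow_bounds_ok(seq_len, 1021, 3)       # 1021 + 3 = 1024, in bounds
# ...but fails at pos = seq_len - 2, exactly as in the traceback
assert not narrow_bounds_ok(seq_len, 1022, 3)   # 1022 + 3 = 1025 > 1024
```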
How to Fix
The bug is due to the construction of local_cache:
local_cache = val_cache.select(0, l).narrow(0, pos, 3)
When pos = seq_len - 2, this call requests 3 rows starting at pos, so the slice would extend to pos + 3 = seq_len + 1 rows, running past the end of val_cache and raising the error above.
For a quick (but perhaps not "beautiful") fix, just change line 74 to
val_cache = torch.zeros([n_layers, seq_len + 3, dim], dtype=data_type, device=device).clone()
to reserve three extra rows for local_cache.
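With those three padding rows, the worst-case window fits. A quick check of the arithmetic, assuming seq_len = 1024 as in the traceback:

```python
seq_len, pad = 1024, 3
padded_len = seq_len + pad

# Worst case from the traceback: pos = seq_len - 2, window length 3
pos, length = seq_len - 2, 3

assert pos + length > seq_len        # overflows the original cache: 1025 > 1024
assert pos + length <= padded_len    # fits in the padded cache: 1025 <= 1027
```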