Llama2_7b Example Will Crash When the Model Outputs Too Many Words #378

@shenzhiy21

Description

How to Reproduce

Just make the model keep generating new words non-stop, until the generated sequence length exceeds the default seq_len.

For example, change the prompt to

prompt = 'a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a'

and it will crash after generating 1022 tokens:

    local_cache = val_cache.select(0, l).narrow(0, pos, 3)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: start (1022) + length (3) exceeds dimension size (1024).

How to Fix

The bug is due to the construction of local_cache:

local_cache = val_cache.select(0, l).narrow(0, pos, 3)

When pos = seq_len - 2, this narrow requests rows pos through pos + 2 of the layer's cache, but pos + 3 exceeds the seq_len rows that val_cache was allocated with, which raises the RuntimeError above.

For a quick (though perhaps not "beautiful") fix, change line 74 to

val_cache = torch.zeros([n_layers, seq_len + 3, dim], dtype=data_type, device=device).clone()

to reserve extra rows for local_cache.
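The failure and the padding fix can be reproduced in isolation. Below is a minimal sketch (small tensor sizes chosen here for illustration; the real cache uses seq_len = 1024) showing that narrow(0, pos, 3) fails on an exactly seq_len-sized cache at pos = seq_len - 2, while the padded allocation stays in bounds:

```python
import torch

# Small sizes for illustration; the real code uses seq_len = 1024.
n_layers, seq_len, dim = 2, 8, 4
pos = seq_len - 2  # the position at which the crash occurs

# Original allocation: exactly seq_len rows per layer.
val_cache = torch.zeros([n_layers, seq_len, dim])
try:
    # Asks for rows pos..pos+2, but pos + 3 > seq_len.
    val_cache.select(0, 0).narrow(0, pos, 3)
except RuntimeError as e:
    print(e)  # start (...) + length (3) exceeds dimension size (...)

# Padded allocation: seq_len + 3 rows, so narrow(0, pos, 3)
# stays in bounds for every valid pos.
val_cache_padded = torch.zeros([n_layers, seq_len + 3, dim])
local_cache = val_cache_padded.select(0, 0).narrow(0, pos, 3)
print(local_cache.shape)  # torch.Size([3, 4])
```

Note that narrow returns a view into val_cache_padded, so writes to local_cache still land in the shared cache; the extra three rows are only scratch space that legitimate positions never read back.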
