Hello, I've noticed that in the `generate_latent_batch` function in `model.py`, the attention mask prepended for the past key-value (KV) cache is an all-ones matrix. Won't this affect the results? When sentences are packed into a batch, shorter sentences are padded up to the length of the longest one, so the attention-mask positions corresponding to padded entries in the past KV cache should be zeros rather than `past_mask = torch.ones(...)`, shouldn't they?
```python
if past_key_values is not None:
    past_len = _past_length(past_key_values)
    if past_len > 0:
        past_mask = torch.ones(
            (attention_mask.shape[0], past_len),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        attention_mask = torch.cat([past_mask, attention_mask], dim=-1)
```
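
For what it's worth, here is a minimal sketch of the fix I have in mind, assuming the per-sequence padding mask for the cached positions is still available at this point (I call it `past_pad_mask` below; that name and the helper function are hypothetical, not part of the repo):

```python
import torch

def extend_attention_mask(attention_mask: torch.Tensor,
                          past_pad_mask: torch.Tensor) -> torch.Tensor:
    """Prepend a padding-aware mask for the cached positions.

    attention_mask: (batch, new_len)  -- mask for the current step's tokens
    past_pad_mask:  (batch, past_len) -- 1 where a cached position holds a real
                    token, 0 where it holds padding (hypothetical input; the
                    current code uses torch.ones(...) here instead)
    """
    past_pad_mask = past_pad_mask.to(dtype=attention_mask.dtype,
                                     device=attention_mask.device)
    return torch.cat([past_pad_mask, attention_mask], dim=-1)

# Batch of 2: sequence 0 has one padded slot in the cache, sequence 1 has none.
past_pad_mask = torch.tensor([[0, 1, 1],
                              [1, 1, 1]])
step_mask = torch.ones(2, 1, dtype=torch.long)  # one new token per sequence
print(extend_attention_mask(step_mask, past_pad_mask))
# tensor([[0, 1, 1, 1],
#         [1, 1, 1, 1]])
```

With the all-ones mask, sequence 0 in this example would attend to its padded cache slot as if it were a real token, which is exactly the behavior I'm asking about.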