about past_KV #28

@analytistic

Description

Hello, I've noticed that in the generate_latent_batch function in model.py, the attention mask for the past key-value (KV) pairs is an all-ones matrix. Won't this affect the results? When sentences are packed into a batch, shorter sentences are padded to a common length, so the attention mask over the past KV positions should contain zeros at the padded positions rather than being built with past_mask = torch.ones(...), shouldn't it?

    if past_key_values is not None:
        past_len = _past_length(past_key_values)
        if past_len > 0:
            past_mask = torch.ones(
                (attention_mask.shape[0], past_len),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            attention_mask = torch.cat([past_mask, attention_mask], dim=-1)
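
For illustration, here is a minimal sketch of the fix being suggested, assuming the padding mask that covered the prompt tokens is still available at generation time (the name prompt_mask below is hypothetical, not from the repo): the cached past positions would reuse that mask instead of an all-ones matrix.

    import torch

    def build_past_attention_mask(prompt_mask, new_mask):
        # prompt_mask: (batch, past_len), 0 where the prompt was padded
        #              (hypothetical name; assumed kept from the prefill step).
        # new_mask:    (batch, new_len), mask for the newly appended tokens.
        # Reuse the prompt's padding mask for the cached KV positions so
        # padded positions stay masked out, then append the new-token mask.
        return torch.cat([prompt_mask.to(new_mask.dtype), new_mask], dim=-1)

    # Example: a batch of two sentences, the second left-padded by one token.
    prompt_mask = torch.tensor([[1, 1, 1],
                                [0, 1, 1]])
    new_mask = torch.ones((2, 1), dtype=prompt_mask.dtype)  # one new token each
    attention_mask = build_past_attention_mask(prompt_mask, new_mask)
    # tensor([[1, 1, 1, 1],
    #         [0, 1, 1, 1]])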
