
[Bug]: Attention Mask Ignored in transformer_engine Backend with Packed Sequences (Attention Leakage) #2357

@BlackSamorez


Summary

When training pretrain_gpt.py with sequence packing enabled (--reset-position-ids and --reset-attention-mask) and using the --transformer-impl transformer_engine backend, the custom block-diagonal attention mask generated by GPTDataset is effectively ignored.

The Transformer Engine (TE) layer defaults to attn_mask_type='causal', which causes it to disregard the attention_mask tensor passed during the forward pass. This results in silent attention leakage between unrelated documents within a packed sequence.

Reproduction Steps

Run pretrain_gpt.py with the following combination of flags:

python pretrain_gpt.py \
    --transformer-impl transformer_engine \
    --reset-position-ids \
    --reset-attention-mask

Root Cause Analysis

1. Dataset Does Not Provide cu_seqlens

GPTDataset generates a dense boolean (or FP8) attention_mask tensor to handle document boundaries. It does not calculate or return cu_seqlens (cumulative sequence lengths) or PackedSeqParams.
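To illustrate what the dataset would need to provide: with --reset-position-ids, each packed document restarts its position IDs at 0, so cu_seqlens can be derived directly from those resets. The helper below is a hypothetical sketch (not existing Megatron-LM code) of that derivation, using NumPy for clarity.

```python
import numpy as np

def cu_seqlens_from_position_ids(position_ids: np.ndarray) -> np.ndarray:
    """Derive cumulative sequence lengths from reset position IDs.

    Document boundaries are exactly the indices where position_ids
    restarts at 0. The result is the cu_seqlens format expected by
    varlen attention kernels: [0, end_doc_0, end_doc_1, ..., seq_len].
    """
    boundaries = np.flatnonzero(position_ids == 0)
    return np.append(boundaries, len(position_ids)).astype(np.int32)

# Two packed documents of lengths 3 and 5 in one 8-token buffer.
pos = np.array([0, 1, 2, 0, 1, 2, 3, 4])
print(cu_seqlens_from_position_ids(pos))  # [0 3 8]
```

GPTDataset performs no such computation today; it emits only the dense mask, which TE then ignores under 'causal'.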

2. TE Defaults to Causal Masking

The Transformer Engine TransformerLayer is initialized with the default attn_mask_type='causal', so the mask tensor passed at forward time is never consulted.

3. API Contract Violation

According to the Transformer Engine documentation, the attention_mask argument in the forward pass is conditional:

Argument attention_mask in the forward call is only used when attn_mask_type includes "padding" or "arbitrary".

Because the configuration remains 'causal', TE invokes the underlying kernel (FlashAttention) with is_causal=True and no custom mask. This applies a standard lower-triangular mask over the entire packed sequence buffer (0..args.seq_length), allowing tokens in Document B to attend to tokens in Document A.

Moreover, when --reset-position-ids is used, documents within the packed buffer have overlapping position IDs, which compounds the problem: the leaked cross-document attention operates on positions the model cannot disambiguate.
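The leakage is easy to demonstrate with boolean masks. The snippet below (a self-contained NumPy illustration, not Megatron-LM code) compares the plain causal mask TE applies with the block-diagonal causal mask the dataset intends, for two packed documents:

```python
import numpy as np

seq_len = 8
doc_boundaries = [0, 3, 8]  # two packed documents: tokens 0-2 and 3-7

# Standard lower-triangular causal mask over the whole buffer
# (True = query may attend to key). This is what is_causal=True applies.
causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))

# Block-diagonal causal mask: causal attention *within* each document only.
block = np.zeros((seq_len, seq_len), dtype=bool)
for start, end in zip(doc_boundaries[:-1], doc_boundaries[1:]):
    n = end - start
    block[start:end, start:end] = np.tril(np.ones((n, n), dtype=bool))

# Token 4 (Document B) attending to token 1 (Document A):
print(causal[4, 1])  # True  -> leakage under plain causal masking
print(block[4, 1])   # False -> correctly blocked by the intended mask
```

Every True entry in `causal` that is False in `block` is a cross-document attention edge that silently contributes to the loss.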

Impact

  • Correctness: The autoregressive independence assumption is violated for packed sequences.
  • Silent Failure: The model trains without error, but gradients are computed based on invalid context.

Proposed Solution

The model initialization logic needs to detect if the user has requested a custom mask (via --reset-attention-mask) and configure the TE layer accordingly.

Suggested Logic:
If args.reset_attention_mask is True, the attn_mask_type passed to te.pytorch.TransformerLayer should be forced to 'arbitrary', so that TE actually consumes the attention_mask tensor provided by the dataset. Alternatively, pretrain_gpt.py could be brought up to date to use PackedSeqParams (cu_seqlens) for packed sequences. At minimum, an assertion on the Transformer Engine arguments would turn this silent failure into an explicit error.
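A minimal sketch of the proposed selection logic, assuming a hypothetical helper name (the real Megatron-LM initialization path and argument plumbing may differ):

```python
def select_attn_mask_type(reset_attention_mask: bool) -> str:
    """Choose the TE attn_mask_type based on whether the dataset
    supplies a custom block-diagonal mask.

    Per the TE docs, the attention_mask forward argument is only
    honored when attn_mask_type includes 'padding' or is 'arbitrary';
    under 'causal' the tensor is ignored entirely.
    """
    return "arbitrary" if reset_attention_mask else "causal"

print(select_attn_mask_type(True))   # arbitrary
print(select_attn_mask_type(False))  # causal
```

The returned value would then be passed as attn_mask_type when constructing the TE layer, instead of the current unconditional 'causal'.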
