
Use and benchmark FlashAttention#47

Closed
rubencart wants to merge 3 commits into `main` from `benchmark-flash-attn`

Conversation

rubencart (Collaborator) commented Oct 30, 2025

Hi y'all

We can reject this PR; it's just for testing :).

Summary:

  • For short and long sequences, in inference and training, setting attn_mask to None when all of its elements are True usually gives a small speed increase and sometimes a memory reduction,
  • especially for bfloat16 and for equal-length (long) sequences, since FlashAttention is then used.
  • Using bfloat16 gives quite large speed and memory-efficiency gains, in inference and training, especially for longer sequences, though it also reduces memory for shorter ones.
  • Using flash_attn_varlen_qkvpacked_func (which needs bfloat16 or float16) instead of scaled_dot_product_attention with bfloat16 only gives a further advantage for long, variable-length sequences.
  • With no other changes, upgrading to PyTorch 2.2 gives a small speed increase (but no memory reduction).

So,

  • It seems worth it to at least set the mask to None when all of its elements are True (I can do that in a separate PR).
  • You could consider upgrading to PyTorch 2.2 (or further).
  • We could test whether performance degrades when using bfloat16 instead of full precision (in pretraining and/or finetuning and/or inference); there are large potential efficiency gains even for Presto's short sequences.
  • Adding flash-attn as a dependency (it takes a while to compile, and I had to resolve a libc version inconsistency) does not seem worth it for Presto. If Galileo has much longer sequences, it could be worth it there.
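The first point could be a one-line guard before the attention call; a minimal sketch, with the helper name `maybe_drop_mask` being mine rather than code from this branch:

```python
import torch

def maybe_drop_mask(attn_mask):
    # If the boolean mask is all True it is a no-op, so returning None
    # lets F.scaled_dot_product_attention dispatch to FlashAttention
    # instead of falling back to the slower masked kernels.
    if attn_mask is not None and attn_mask.dtype == torch.bool and bool(attn_mask.all()):
        return None
    return attn_mask
```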

Long version

So I noticed that F.scaled_dot_product_attention does not use FlashAttention when its attn_mask argument is not None. We need the attn_mask argument for batches of sequences of different lengths.

You can verify this by wrapping the call in the context manager below to force the FlashAttention backend: with a tensor mask an error is raised, while with attn_mask=None it runs fine.

```python
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    x = F.scaled_dot_product_attention(
        q,
        k,
        v,
        # a value of True indicates that the element should take part in attention
        attn_mask=None,
        dropout_p=self.attn_drop.p,
    )
```

There seem to be a few options to work around this limitation and use a faster attention implementation anyway.

  • Use torch.nested tensors, which gain scaled_dot_product_attention support in PyTorch 2.2 according to the release notes; how to do attention with them is shown here. I tried for a bit, but I didn't manage to convert a padded tensor into a nested tensor without a for-loop. Returning nested tensors directly from the dataloader seems like a pain, because it would require updating a lot of upstream code, and nested tensors support far fewer operations.
    Note that the documentation starting from 2.7 also states "NJTs are designed to be used with torch.compile() for optimal performance".
  • Use FlexAttention, which requires PyTorch 2.5 and torch.compile(). It lets you define all kinds of attention-score biases, such as relative positions, but that might be a bit overkill for Presto.
  • The simplest option seems to be the original FlashAttention implementation in the flash-attn package. It requires PyTorch 2.2 and CUDA 12.0, supports only certain GPU types, and compiling/installing takes a while.
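For orientation: flash_attn_varlen_qkvpacked_func takes unpadded, concatenated sequences plus cumulative sequence lengths rather than a padded batch, so a packing step is needed first. A minimal sketch of that step (`pack_varlen` is my illustration, not code from this branch; it assumes a packed qkv tensor of shape (batch, seqlen, 3, heads, dim) and a boolean mask where True marks real tokens):

```python
import torch

def pack_varlen(qkv_padded, key_padding_mask):
    # Per-sequence lengths and their cumulative offsets (int32, as flash-attn expects).
    seqlens = key_padding_mask.sum(dim=1).to(torch.int32)
    cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    # Boolean indexing drops the padding tokens: (batch, seqlen, 3, H, D) -> (total, 3, H, D).
    qkv = qkv_padded[key_padding_mask]
    return qkv, cu_seqlens, int(seqlens.max())
```

The returned qkv, cu_seqlens, and max sequence length are the inputs flash_attn_varlen_qkvpacked_func consumes; the output can be scattered back into the padded layout with the same mask.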

Training

I implemented the latter option and ran a few tests comparing training speed and peak GPU memory usage on the supplied dw_144_mini_shard_44.tar data.
The file train_time.py contains code to time 100 forward+backward passes with torch.utils.benchmark.
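The timing loop is roughly of this shape (a sketch, not the actual train_time.py; the placeholder loss is mine):

```python
import torch
import torch.utils.benchmark as benchmark

def time_forward_backward(model, batch, n=100):
    def step():
        model.zero_grad(set_to_none=True)
        loss = model(batch).sum()  # placeholder loss, for illustration only
        loss.backward()
    # torch.utils.benchmark handles warmup and CUDA synchronization for us,
    # unlike naive time.time() loops.
    timer = benchmark.Timer(stmt="step()", globals={"step": step})
    return timer.timeit(n).mean * 1000  # mean milliseconds per pass
```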

Training always uses sequences of equal length (I think?), so the mask actually isn't necessary in that case. We can add a line of code to set it to None when it only contains True (though with short sequences the gain is minimal; see the set to None timings below).
I also added code that modifies the mask to simulate sequences of different lengths (these rows are marked var. len).

Rows marked bfloat16 first cast the tensors and the model weights to bfloat16, which is necessary for FlashAttention to work. All runs are on an A100.
Note that var. len is always faster because more elements are masked out. This mirrors inference with sequences of different lengths, while eq. len mirrors training with equal-length sequences.
Also note that setting the attention mask to None for equal-length sequences should make scaled_dot_product_attention use FlashAttention behind the scenes (unless it judges that one of the other implementations would be even faster).
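The cast itself is trivial; a sketch of what the bfloat16 rows do (the helper is mine, for illustration):

```python
import torch

def to_bf16(model, *tensors):
    # FlashAttention requires float16 or bfloat16 inputs, so cast both
    # the model weights and the input tensors.
    model = model.to(torch.bfloat16)
    return (model, *(t.to(torch.bfloat16) for t in tensors))
```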

Pytorch 2.0:

| Attn | extra | dtype | mask | time (ms) | peak mem (GB) |
| --- | --- | --- | --- | --- | --- |
| scaled_dot_product_attention | / | float32 | eq. len | 176.05 | 6.2 |
| scaled_dot_product_attention | set to None | float32 | eq. len | 158.21 | 5.2 |
| scaled_dot_product_attention | / | bfloat16 | eq. len | 107.86 | 3.2 |
| scaled_dot_product_attention | set to None | bfloat16 | eq. len | 98.34 | 2.7 |
| scaled_dot_product_attention | / | float32 | var. len | 143.34 | 4.4 |
| scaled_dot_product_attention | / | bfloat16 | var. len | 92.19 | 2.3 |

Pytorch 2.2:

| Attn | extra | dtype | mask | time (ms) | peak mem (GB) |
| --- | --- | --- | --- | --- | --- |
| scaled_dot_product_attention | / | float32 | eq. len | 144.21 | 6.2 |
| scaled_dot_product_attention | set to None | float32 | eq. len | 137.90 | 6.2 |
| scaled_dot_product_attention | / | bfloat16 | eq. len | 90.60 | 3.2 |
| scaled_dot_product_attention | set to None | bfloat16 | eq. len | 87.40 | 3.2 |
| flash_attn_varlen_qkvpacked_func | / | bfloat16 | eq. len | 101.68 | 2.8 |
| scaled_dot_product_attention | / | float32 | var. len | 122.47 | 4.4 |
| scaled_dot_product_attention | / | bfloat16 | var. len | 86.09 | 2.3 |
| flash_attn_varlen_qkvpacked_func | / | bfloat16 | var. len | 91.36 | 2.1 |

Casting to bfloat16 gives more speedup than using FlashAttention (so upgrading to PyTorch 2.2 on its own might also be worth it?).
I think this is mainly because sequence lengths in Presto are quite short. You can see in the plots here that FlashAttention is especially fast for longer sequences.
A forward pass of 12 months with a mask_ratio of 0.0 gives a sequence length of 110 for x in the Attention forward method, which is quite short compared to the minimum length of 512 they test in the plots in the flash-attn repo.

To verify this, I added code to train_time.py that makes the sequences 10 or 20 times longer (2162/4324 without masking, lowering the batch size to 64/32). This gives the following; all runs use PyTorch 2.2 (and include the set to None change).

| Attn | extra | dtype | mask | time (ms) 10x | time (ms) 20x | mem (GB) 10x | mem (GB) 20x |
| --- | --- | --- | --- | --- | --- | --- | --- |
| scaled_dot_product_attention | / | float32 | eq. len | 160.55 | 290.70 | 7.4 | 16.4 |
| scaled_dot_product_attention | set to None | float32 | eq. len | 171.42 | 290.70 | 8.7 (?) | 16.4 |
| scaled_dot_product_attention | / | bfloat16 | eq. len | 72.61 | 112.61 | 4.7 | 8.8 |
| scaled_dot_product_attention | set to None | bfloat16 | eq. len | 41.73 | 49.38 | 1.4 | 1.4 |
| flash_attn_varlen_qkvpacked_func | / | bfloat16 | eq. len | 46.50 | 52.24 | 1.4 | 1.4 |
| scaled_dot_product_attention | / | float32 | var. len | 101.55 | 150.77 | 2.8 | 3.5 |
| scaled_dot_product_attention | / | bfloat16 | var. len | 39.53 | 47.12 | 1.5 | 1.8 |
| flash_attn_varlen_qkvpacked_func | / | bfloat16 | var. len | 38.40 | 41.25 | 1.1 | 1.1 |

The gain only becomes really significant for very long sequences.

Inference

inference_time.py does the same, but for inference. For the default length:

| Attn | extra | dtype | mask | time (ms) | peak mem (MB) |
| --- | --- | --- | --- | --- | --- |
| scaled_dot_product_attention | / | float32 | eq. len | 26.02 | 1564 |
| scaled_dot_product_attention | set to None | float32 | eq. len | 34.53 | 888 |
| scaled_dot_product_attention | / | bfloat16 | eq. len | 25.75 | 848 |
| scaled_dot_product_attention | set to None | bfloat16 | eq. len | 28.15 | 449 |
| flash_attn_varlen_qkvpacked_func | / | bfloat16 | eq. len | 30.44 | 452 |
| scaled_dot_product_attention | / | float32 | var. len | 31.52 | 650 |
| scaled_dot_product_attention | / | bfloat16 | var. len | 27.96 | 326 |
| flash_attn_varlen_qkvpacked_func | / | bfloat16 | var. len | 28.53 | 268 |

And for 10x and 20x length:

| Attn | extra | dtype | mask | time (ms) 10x | time (ms) 20x | mem (MB) 10x | mem (MB) 20x |
| --- | --- | --- | --- | --- | --- | --- | --- |
| scaled_dot_product_attention | set to None | float32 | eq. len | 23.38 | 35.58 | 454 | 454 |
| scaled_dot_product_attention | set to None | bfloat16 | eq. len | 7.94 | 8.81 | 231 | 231 |
| flash_attn_varlen_qkvpacked_func | / | bfloat16 | eq. len | 9.42 | 10.00 | 233 | 233 |
| scaled_dot_product_attention | / | float32 | var. len | 15.39 | 21.52 | 1121 (*) | 1976 (*) |
| scaled_dot_product_attention | / | bfloat16 | var. len | 9.21 | 11.00 | 593 | 1072 |
| flash_attn_varlen_qkvpacked_func | / | bfloat16 | var. len | 7.56 | 7.28 | 120 | 117 |

(*) not sure why these are higher, but I'm tired of running this code 😬

