Hi y'all
We can reject this PR, it's just for testing :).
Summary:

`flash_attn_varlen_qkvpacked_func` (which needs `bfloat16` or `float16`) only gives a further advantage over `scaled_dot_product_attention` + `bfloat16` in the case of long & variable-length sequences.
Long version

So I noticed that `F.scaled_dot_product_attention` does not use FlashAttention if its `attn_mask` argument is not `None`. We need the `attn_mask` argument for batches of sequences of different lengths. You can verify this by wrapping the function call in the `sdp_kernel` context manager to force the use of FlashAttention: an error will be raised.
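To make the mask requirement concrete, here is a minimal sketch of building a key-padding mask for a variable-length batch and passing it to SDPA (hypothetical shapes, not Presto's actual code):

```python
import torch
import torch.nn.functional as F

# Hypothetical batch: three sequences of lengths 5, 3, 2, padded to length 5.
lengths = torch.tensor([5, 3, 2])
max_len = 5
# Boolean key-padding mask: True = attend, False = padding.
attn_mask = torch.arange(max_len)[None, :] < lengths[:, None]  # (B, L)
# Broadcastable to (B, heads, L_q, L_k), as scaled_dot_product_attention expects.
sdpa_mask = attn_mask[:, None, None, :]                        # (B, 1, 1, L)

q = k = v = torch.randn(3, 2, max_len, 8)  # (B, heads, L, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=sdpa_mask)

# Equal-length batches: an all-True mask can be dropped entirely, which
# leaves SDPA free to dispatch to the FlashAttention kernel.
if bool(sdpa_mask.all()):
    sdpa_mask = None
print(out.shape)  # torch.Size([3, 2, 5, 8])
```

The all-True check at the end is the "set it to `None`" trick discussed under Training below.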
There seem to be a few options to work around the limitation and use a faster attention anyway:

- `torch.nested` tensors, which get `scaled_dot_product_attention` support starting from PyTorch 2.2 according to the release notes. How to do attention with them is shown here. I tried for a bit, but I didn't manage to convert a padded tensor into a nested tensor without a for-loop. Returning nested tensors directly from your dataloader seems like a pain, because this would require you to update a lot of upstream code, and nested tensors have far fewer operations defined. Note that the documentation starting from 2.7 also states that "NJTs are designed to be used with torch.compile() for optimal performance".
- `torch.compile()`. Allows you to define all kinds of attention score biases like relative positions, but this might be a bit overkill for Presto.
- The `flash-attn` package. Requires PyTorch 2.2 and CUDA 12.0, with only certain GPU types supported, and the compilation/installation takes a while.

Training
I implemented the latter option and ran a few tests to compare training speed and peak GPU memory usage with the supplied `dw_144_mini_shard_44.tar` data. The file train_time.py contains code to time 100 forward+backward passes with `torch.utils.benchmark`.

The training always has sequences of equal length (I think?), so the mask actually isn't necessary in that case. We can add a line of code to set it to `None` when it only contains `True` (but with short sequences the gain is minimal, see the timings with the mask set to `None` below). I added code to set the mask so it results in sequences of different length (these rows are marked with `varlen mask`). Rows with `bfloat16` first cast the tensors and the model weights to `bfloat16`, which is necessary for FlashAttention to work. All runs are on an A100.

Note that var. len is always faster because more elements are masked out. This mirrors the situation during inference, if sequences are of different length. Eq. length mirrors training, with sequences of equal length.
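One practical detail of the flash-attn route: `flash_attn_varlen_qkvpacked_func` consumes packed (unpadded) tokens plus int32 cumulative sequence lengths rather than a padded batch. Building those inputs from a padding mask can be sketched as follows (illustrative shapes only; the flash-attn call itself needs a supported GPU, so only the repacking is shown):

```python
import torch
import torch.nn.functional as F

# Padding mask: True marks real tokens. Sequences of lengths 5, 3, 2.
mask = torch.zeros(3, 5, dtype=torch.bool)
mask[0, :5] = True
mask[1, :3] = True
mask[2, :2] = True

seqlens = mask.sum(dim=1)
# Cumulative offsets into the packed token stream, as the varlen kernels expect.
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).to(torch.int32)
max_seqlen = int(seqlens.max())

hidden = torch.randn(3, 5, 8)  # (B, L, D) padded batch
packed = hidden[mask]          # (total_tokens, D), padding removed
print(cu_seqlens.tolist())  # [0, 5, 8, 10]
print(packed.shape)         # torch.Size([10, 8])
```

`cu_seqlens` and `max_seqlen` are then passed alongside the packed QKV tensor; the output is scattered back into the padded layout with the same boolean mask.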
Also note that setting the attention mask to `None` in the case of equal-length sequences with `scaled_dot_product_attention` should result in PyTorch using FlashAttention behind the scenes (unless it judges that one of the other implementations would be even faster).

PyTorch 2.0:

PyTorch 2.2:
Casting to `bfloat16` gives more speedup than using FlashAttention (also, upgrading to PyTorch 2.2 might be worth it?). I think this is mainly caused by the fact that sequence lengths in Presto are quite short. You can see in the plots here that FlashAttention is especially fast for longer sequences.
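The `bfloat16` cast itself is cheap to try; a minimal sketch (assuming the backend supports bf16 matmuls — this is a stand-in layer, not the Presto model):

```python
import torch
import torch.nn as nn

# Cast both the weights and the inputs; mixing dtypes raises an error.
model = nn.Linear(16, 16).to(torch.bfloat16)
x = torch.randn(4, 16).to(torch.bfloat16)
out = model(x)
print(out.dtype)  # torch.bfloat16
```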
A forward pass of 12 months with a `mask_ratio` of 0.0 gives a sequence length of 110 for `x` in the Attention forward method, which is quite short compared to the minimum length of 512 that they test for in the plots in the flash-attn repo. To verify this, I added code to train_time.py to make the sequence 10 or 20 times longer (2162/4324 without masking, lowering the batch size to 64/32). This gives the following. All runs with PyTorch 2.2 (and with the mask set to `None` added).

The gain only becomes really significant for very long sequences.
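For reference, the timing approach with `torch.utils.benchmark` looks roughly like this (hypothetical stand-in model; train_time.py times the real Presto model):

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

# Hypothetical stand-in model and batch.
model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 8))
x = torch.randn(16, 8)

def step():
    # One forward+backward pass.
    model.zero_grad(set_to_none=True)
    loss = model(x).pow(2).mean()
    loss.backward()

# timeit(100) runs the statement 100 times and reports per-run statistics;
# unlike a bare timer, it handles warmup and CUDA synchronization.
measurement = benchmark.Timer(stmt="step()", globals={"step": step}).timeit(100)
print(measurement.mean)  # mean seconds per forward+backward pass
```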
Inference
inference_time.py does the same but for inference. For the default length:

And for 10x and 20x length:
(*) not sure why these are higher, but I'm tired of running this code 😬