Of course, the current solution only applies when `BLOCK_M == BLOCK_N`. When `BLOCK_N` differs from `BLOCK_M`, the following approach can be used instead:

```python
max_m = tl.minimum((start_m + 1) * BLOCK_M, seq_len)
num_kv_blocks = tl.cdiv(max_m, BLOCK_N)
for block_n in range(num_kv_blocks):
```
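The block-count arithmetic above can be checked in pure Python (a minimal sketch; `cdiv` here mirrors Triton's `tl.cdiv`, and `num_kv_blocks` is a hypothetical helper name, not part of the kernel):

```python
def cdiv(a, b):
    # Ceiling division, equivalent to tl.cdiv in Triton.
    return (a + b - 1) // b

def num_kv_blocks(start_m, BLOCK_M, BLOCK_N, seq_len):
    # Highest query row index (exclusive) covered by Q block `start_m`,
    # clamped to the sequence length.
    max_m = min((start_m + 1) * BLOCK_M, seq_len)
    # Number of KV blocks whose keys can be attended to under causal masking.
    return cdiv(max_m, BLOCK_N)

# With BLOCK_M == BLOCK_N the count reduces to start_m + 1 (until clamped):
assert num_kv_blocks(start_m=2, BLOCK_M=64, BLOCK_N=64, seq_len=1024) == 3
# With mismatched block sizes the cdiv form still gives the right bound:
assert num_kv_blocks(start_m=2, BLOCK_M=128, BLOCK_N=64, seq_len=1024) == 6
```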
I noticed that the FlashAttention kernel performs some redundant computation when calculating `qk`. I have modified this and provided test code to support it: attention.py
The original KV-block traversal logic is as follows: each Q block traverses all KV blocks. With variable-length sequences, if the lengths differ greatly between sequences, this produces a large amount of redundant work. In the prefill stage, under causal masking, the current Q block only needs to traverse its own KV block and the preceding ones; it never needs to continue over subsequent KV blocks.
optimized:
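As a rough illustration (not the kernel from attention.py; block sizes and the helper name are assumptions), the difference in traversal cost per sequence can be modeled in plain Python using the loop bound from the comment above:

```python
def cdiv(a, b):
    # Ceiling division, equivalent to tl.cdiv in Triton.
    return (a + b - 1) // b

def traversed_blocks(seq_len, BLOCK_M=128, BLOCK_N=64):
    # Total KV blocks visited, summed over all Q blocks, for one sequence.
    # Returns (original, optimized) counts.
    total_kv = cdiv(seq_len, BLOCK_N)
    original = optimized = 0
    for start_m in range(cdiv(seq_len, BLOCK_M)):
        original += total_kv  # original: every Q block visits every KV block
        max_m = min((start_m + 1) * BLOCK_M, seq_len)
        optimized += cdiv(max_m, BLOCK_N)  # optimized: causal prefix only
    return original, optimized

# For seq_len=512 with these block sizes, the causal loop visits
# 20 of the original 32 KV-block iterations:
assert traversed_blocks(512) == (32, 20)
```

For a single Q block (one M block per sequence), the two counts coincide, so short sequences see no change; the savings grow with the number of Q blocks.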
I ran performance tests on this; the test code is shown below. Multiple test cases with sequences of different lengths were set up and compared against the original FlashAttention. The output is shown below:
When the input sequence lengths do not differ much, performance is essentially unchanged; the larger the length difference within a group, the more pronounced the optimization becomes.
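To make that trend concrete, here is a back-of-the-envelope tile-count model (assumed `BLOCK_M == BLOCK_N`; hypothetical batch shapes, not the author's benchmark): the fraction of `qk` tiles skipped by the causal early exit grows when a batch is dominated by long sequences, since each sequence with `b` KV blocks skips a fraction `(b - 1) / (2b)` of its tiles.

```python
def cdiv(a, b):
    # Ceiling division, equivalent to tl.cdiv in Triton.
    return (a + b - 1) // b

def saved_fraction(seq_lens, block=64):
    # Fraction of qk tiles the causal early exit skips across a batch,
    # assuming BLOCK_M == BLOCK_N == block.
    orig = sum(cdiv(n, block) ** 2 for n in seq_lens)  # full Q x KV grid
    opt = sum(b * (b + 1) // 2  # lower-triangular (causal) part only
              for b in (cdiv(n, block) for n in seq_lens))
    return 1 - opt / orig

# A skewed batch (one long outlier) skips a larger fraction of tiles
# than a uniform batch of the same count:
assert saved_fraction([2048, 64, 64, 64]) > saved_fraction([512, 512, 512, 512])
```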