[AIROCMLIR-499] Add support for sliding window masking for attention#2240
justinrosner wants to merge 12 commits into develop
Pull request overview
This PR implements sliding window attention support combined with KVCache in rocMLIR. Sliding window attention masks key positions before max(0, currentSeqLen - windowSize) with -inf, effectively limiting attention to the last W positions relative to the current decoded position. This is achieved through pattern detection in TosaToRock conversion and masking logic in GridwiseGemmToBlockwise lowering.
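As a rough illustration of the masking rule described above, the following host-side sketch masks one row of attention scores. It is not the rocMLIR lowering; `slidingWindowLowerBound` and `applySlidingWindowMask` are hypothetical names introduced only for this example, and only the formula `max(0, currentSeqLen - windowSize)` comes from the PR description.

```cpp
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical helper: first key position that stays visible,
// i.e. max(0, currentSeqLen - windowSize).
inline int64_t slidingWindowLowerBound(int64_t currentSeqLen,
                                       int64_t windowSize) {
  return std::max<int64_t>(0, currentSeqLen - windowSize);
}

// Hypothetical helper: overwrite every score whose key position lies before
// the lower bound with -inf. Positions at or beyond currentSeqLen are not
// touched here; in this sketch they are assumed to be handled by the
// separate KV-cache mask.
inline void applySlidingWindowMask(std::vector<float> &scores,
                                   int64_t currentSeqLen,
                                   int64_t windowSize) {
  const int64_t lowerBound = slidingWindowLowerBound(currentSeqLen, windowSize);
  const float negInf = -std::numeric_limits<float>::infinity();
  const int64_t numKeys = static_cast<int64_t>(scores.size());
  for (int64_t keyPos = 0; keyPos < numKeys; ++keyPos) {
    if (keyPos < lowerBound)
      scores[keyPos] = negInf;
  }
}
```

With `currentSeqLen = 6` and `windowSize = 4`, the lower bound is 2, so key positions 0 and 1 are masked and positions 2 through 5 remain visible. When `currentSeqLen` is smaller than `windowSize`, the bound clamps to 0 and nothing is masked.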
Changes:
- Added optional `slidingWindowSize` attribute to `rock.attention` and `rock.gridwise_attention_accel` operations with validation requiring `currentSeqLen` to be set
- Implemented pattern detection in TosaToRock for sliding window masks (detecting `greater(seqLen + negative_offset, col_indices)`) and clip patterns on `currentSeqLen`
- Extended GridwiseGemmToBlockwise to apply sliding window masking (independent of causal/KVCache masking) using precomputed lower bound
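To make the detected comparison concrete, here is a small sketch that evaluates `greater(seqLen + negative_offset, col_indices)` elementwise on the host. The function names are hypothetical, and the mapping `windowSize = -negative_offset` is an assumption made for illustration; the PR text does not spell out how the window size is recovered from the matched constant.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Elementwise evaluation of greater(seqLen + negativeOffset, colIndices).
// A true entry means the corresponding key column is masked.
inline std::vector<bool>
evalGreaterPattern(int64_t seqLen, int64_t negativeOffset,
                   const std::vector<int64_t> &colIndices) {
  std::vector<bool> masked;
  masked.reserve(colIndices.size());
  for (int64_t col : colIndices)
    masked.push_back(seqLen + negativeOffset > col);
  return masked;
}

// Assumption for this sketch: the window size is the negated offset, and a
// non-negative offset is not treated as a sliding window pattern.
inline std::optional<int64_t> windowSizeFromOffset(int64_t negativeOffset) {
  if (negativeOffset >= 0)
    return std::nullopt;
  return -negativeOffset;
}
```

Because column indices are non-negative, the comparison gives the same result as testing against `max(0, seqLen - windowSize)`: when `seqLen + negative_offset` is negative, no column satisfies it and nothing is masked.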
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| mlir/include/mlir/Dialect/Rock/IR/RockOps.td | Added slidingWindowSize optional attribute to AttentionOp and GridwiseAttentionAccelOp definitions with documentation |
| mlir/lib/Dialect/Rock/IR/RockDialect.cpp | Added validation logic for slidingWindowSize attribute (must be positive and requires currentSeqLen) |
| mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | Implemented pattern detection for sliding window masks, clip pattern detection for currentSeqLen, and clip application during rewrite |
| mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | Added sliding window masking logic with lower bound computation and masking on every iteration |
| mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp | Updated to pass slidingWindowSize attribute through the lowering pipeline |
| mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp | Propagated slidingWindowSize attribute when creating new AttentionOp |
| mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp | Propagated slidingWindowSize attribute during dimension sorting transformation |
| mlir/tools/rocmlir-gen/rocmlir-gen.cpp | Added nullptr for slidingWindowSize in generated attention kernels (feature not yet supported in codegen) |
| mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir | Added comprehensive test with FileCheck patterns validating sliding window masking behavior |
| mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-kvcache.mlir | Added test for sliding window pattern detection and clip pattern handling |
| mlir/test/fusion/pr-e2e/attention/mixr-attention-sliding-window-kvcache.mlir | Added end-to-end test for sliding window attention with KVCache |
    // Validate sliding window constraints
    if (getSlidingWindowSize()) {
      int32_t windowSize = static_cast<int32_t>(*getSlidingWindowSize());
      if (windowSize <= 0)
        return emitError("slidingWindowSize must be positive");
      if (!getCurrentSeqLen())
        return emitError("slidingWindowSize requires currentSeqLen to be set");
    }
Missing validation: slidingWindowSize should only be allowed when enableSoftmax is true (for attention operations). Similar to the checks for currentSeqLen (line 2826), prefixOffset (line 2829), and causal (line 2832), there should be a check that rejects slidingWindowSize when enableSoftmax is false, since sliding window masking only makes sense in the context of attention operations.
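One way the suggested check could sit next to the existing ones is sketched below as a standalone function, outside of MLIR. `AttentionAttrs` and `verifySlidingWindow` are hypothetical stand-ins for the op's accessors and verifier, and the wording of the `enableSoftmax` error message is invented for this example; the other two messages are copied from the snippet under review.

```cpp
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical stand-in for the attributes the verifier inspects.
struct AttentionAttrs {
  std::optional<int32_t> slidingWindowSize;
  bool hasCurrentSeqLen = false;
  bool enableSoftmax = true;
};

// Returns an error message, or std::nullopt when the attributes are valid.
inline std::optional<std::string>
verifySlidingWindow(const AttentionAttrs &attrs) {
  if (!attrs.slidingWindowSize)
    return std::nullopt; // Nothing to validate without the attribute.
  if (*attrs.slidingWindowSize <= 0)
    return std::string("slidingWindowSize must be positive");
  if (!attrs.hasCurrentSeqLen)
    return std::string("slidingWindowSize requires currentSeqLen to be set");
  // Check proposed in the review; the message text is an assumption.
  if (!attrs.enableSoftmax)
    return std::string("slidingWindowSize requires enableSoftmax");
  return std::nullopt;
}
```

Keeping the new condition last preserves the order of the two existing diagnostics, so existing tests that match on those messages would presumably be unaffected.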
    // Sliding window masking: mask when key_pos < max(0, currentSeqLen -
    // windowSize). This is independent of causal masking and applies
    // alongside KV-cache masking.
    if (slidingWindowSize > 0) {
I think all the "ifs" around setGemm0OutputOutOfScope() are not necessary, because setGemm0OutputOutOfScope() internally checks if it needs to run.
True, but do you find that it makes the code a bit more unreadable to do that? I'm fine either way, just wondering what your opinion is.
I'm fine either way, I just think we need to remove one of the ifs, either the outer one or the inner one.
I've removed it in all of the cases except for the if/else block for prefix causal. Because prefixCausal also requires that causal be set, running the logic for both masks would not be correct in this case.
umangyadav left a comment
I had some similar concerns as daniels. But apart from that looks good to me.
b5382ea to 10cd27e
Motivation
This PR adds support for sliding window attention combined with KVCache in rocMLIR. Sliding window attention limits each query to attend only to the last W key positions (relative to the current decoded position), masking earlier positions with `-inf`.

This implements: https://amd-hub.atlassian.net/browse/AIROCMLIR-499
Technical Details
- Added `slidingWindowSize` attribute to `rock.attention` ops
- Validation that `currentSeqLen` is also set
- Added `trySlidingWindowPattern` in TosaToRock to detect the sliding window mask pattern
- Added `tryClipPattern` for KVCache pattern detection in TosaToRock
- `clip(arg, lo, hi)`

Test Plan
Test Result
Submission Checklist