Skip to content

Comments

[AIROCMLIR-499] Add support for sliding window masking for attention#2240

Open
justinrosner wants to merge 12 commits intodevelopfrom
499-sliding-window-tosa-to-rock
Open

[AIROCMLIR-499] Add support for sliding window masking for attention#2240
justinrosner wants to merge 12 commits intodevelopfrom
499-sliding-window-tosa-to-rock

Conversation

@justinrosner
Copy link
Contributor

@justinrosner justinrosner commented Feb 17, 2026

Motivation

This PR adds support for sliding window attention combined with KVCache in rocMLIR. Sliding window attention limits each query to attent only to the last W key positions (relative to the current decoded position), masking earlier positions with -inf.

This implements: https://amd-hub.atlassian.net/browse/AIROCMLIR-499

Technical Details

  • Added an optional slidingWindowSize attribute to rock.attention ops
    • This value must be positive, and can only be set when currentSeqLen is also set
  • Added new trySlidingWindowPattern in TosaToRock to detect the sliding window mask pattern
  • Also added tryClipPattern for KVCache pattern detection in TosaToRock
    • The currentSeqLen value in MIGraphX IR may be wrapped in a clip(arg, lo, hi)
    • This pattern matcher tries to match those bounds, which then allows for the rewrite phase to broadcast the proper clipped values before passing the tensor to the attention op
  • Added sliding window logic to GridwiseGemmToBlockwise

Test Plan

  • MIGraphX IR for sliding window attention can compile and produce correct results
  • Nightly CI

Test Result

  • MIGraphX sliding window attention
  • Nightly CI

Submission Checklist

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR implements sliding window attention support combined with KVCache in rocMLIR. Sliding window attention masks key positions before max(0, currentSeqLen - windowSize) with -inf, effectively limiting attention to the last W positions relative to the current decoded position. This is achieved through pattern detection in TosaToRock conversion and masking logic in GridwiseGemmToBlockwise lowering.

Changes:

  • Added optional slidingWindowSize attribute to rock.attention and rock.gridwise_attention_accel operations with validation requiring currentSeqLen to be set
  • Implemented pattern detection in TosaToRock for sliding window masks (detecting greater(seqLen + negative_offset, col_indices)) and clip patterns on currentSeqLen
  • Extended GridwiseGemmToBlockwise to apply sliding window masking (independent of causal/KVCache masking) using precomputed lower bound

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
mlir/include/mlir/Dialect/Rock/IR/RockOps.td Added slidingWindowSize optional attribute to AttentionOp and GridwiseAttentionAccelOp definitions with documentation
mlir/lib/Dialect/Rock/IR/RockDialect.cpp Added validation logic for slidingWindowSize attribute (must be positive and requires currentSeqLen)
mlir/lib/Conversion/TosaToRock/TosaToRock.cpp Implemented pattern detection for sliding window masks, clip pattern detection for currentSeqLen, and clip application during rewrite
mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Added sliding window masking logic with lower bound computation and masking on every iteration
mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp Updated to pass slidingWindowSize attribute through the lowering pipeline
mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp Propagated slidingWindowSize attribute when creating new AttentionOp
mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp Propagated slidingWindowSize attribute during dimension sorting transformation
mlir/tools/rocmlir-gen/rocmlir-gen.cpp Added nullptr for slidingWindowSize in generated attention kernels (feature not yet supported in codegen)
mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir Added comprehensive test with FileCheck patterns validating sliding window masking behavior
mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-kvcache.mlir Added test for sliding window pattern detection and clip pattern handling
mlir/test/fusion/pr-e2e/attention/mixr-attention-sliding-window-kvcache.mlir Added end-to-end test for sliding window attention with KVCache

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 2843 to 2850
// Validate sliding window constraints
if (getSlidingWindowSize()) {
int32_t windowSize = static_cast<int32_t>(*getSlidingWindowSize());
if (windowSize <= 0)
return emitError("slidingWindowSize must be positive");
if (!getCurrentSeqLen())
return emitError("slidingWindowSize requires currentSeqLen to be set");
}
Copy link

Copilot AI Feb 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing validation: slidingWindowSize should only be allowed when enableSoftmax is true (for attention operations). Similar to the checks for currentSeqLen (line 2826), prefixOffset (line 2829), and causal (line 2832), there should be a check that rejects slidingWindowSize when enableSoftmax is false, since sliding window masking only makes sense in the context of attention operations.

Copilot uses AI. Check for mistakes.
// Sliding window masking: mask when key_pos < max(0, currentSeqLen -
// windowSize). This is independent of causal masking and applies
// alongside KV-cache masking.
if (slidingWindowSize > 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all the "ifs" around setGemm0OutputOutOfScope() are not necessary, because setGemm0OutputOutOfScope() internally checks if it needs to run.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, but do you find that it makes the code a bit more unreadable to do that? I'm fine either way, just wondering what your opinion is.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine either way, I just think we need to remove of of the ifs, either the outer one of the inner one

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed it in all of the cases except for the if/else block for prefix causal. Because prefixCausal also requires that causal be set, running the logic for both masks would not be correct in this case.

Copy link
Member

@umangyadav umangyadav left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had some similar concerns as daniels. But apart from that looks good to me.

@justinrosner justinrosner force-pushed the 499-sliding-window-tosa-to-rock branch from b5382ea to 10cd27e Compare February 24, 2026 14:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants