From 65c86ef45e6b03acbf0566cb07017ce7fc2757d2 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Thu, 12 Feb 2026 21:52:59 +0000 Subject: [PATCH 01/12] Initial TosaToRock changes for sliding window attention --- mlir/include/mlir/Dialect/Rock/IR/RockOps.td | 12 +- mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | 153 ++++++++++++++---- mlir/lib/Dialect/Rock/IR/RockDialect.cpp | 20 +++ .../Rock/Transforms/DetectFlashDecoding.cpp | 3 +- .../Rock/Transforms/GemmToGridwise.cpp | 15 +- .../Transforms/SortDimensionsMemoryLayout.cpp | 3 +- mlir/tools/rocmlir-gen/rocmlir-gen.cpp | 1 + 7 files changed, 168 insertions(+), 39 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index 8d70078b06f0..e4b0ee7496d6 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -220,7 +220,8 @@ def Rock_AttentionOp Optional>:$lse, I32Attr:$numHeadsQ, I32Attr:$numHeadsKV, UnitAttr:$qTransposed, UnitAttr:$kTransposed, UnitAttr:$vTransposed, UnitAttr:$oTransposed, UnitAttr:$causal, - I32Attr:$splitKV, OptionalAttr:$features, + I32Attr:$splitKV, OptionalAttr:$slidingWindowSize, + OptionalAttr:$features, StoreMethodAttr:$storeMethod, OptionalAttr:$softmaxType, OptionalAttr:$params0, OptionalAttr:$params1, @@ -253,6 +254,11 @@ def Rock_AttentionOp - A tensor of shape [G]: per-group/batch offsets, allowing different prefix lengths for each sequence in the batch + If slidingWindowSize is set, we implement sliding window attention where + only the last `slidingWindowSize` key positions (relative to currentSeqLen) + are attended to. Positions before `max(0, currentSeqLen - slidingWindowSize)` + are masked with -inf. This requires currentSeqLen to be set. + LSE (log-sum-exp) is an optional output typically used for flash decoding. For flash decoding, you can pass splitKV > 1, the default value is 1, which means flash decoding is disabled. 
Flash decoding multiplies the number of blocks by splitKV. Note that "lse" has to be non-null for splitKV > 1. @@ -278,6 +284,7 @@ def Rock_AttentionOp ` ` `qk` `=` (`tr` $qTransposed^)? $queries `*` (`tr` $kTransposed^)? $keys `:` type($queries) `,` type($keys) `\n` (`currentSeqLen` `=` `(` $currentSeqLen^ `:` type($currentSeqLen) `)` `\n`)? (`prefixOffset` `=` `(` $prefixOffset^ `:` type($prefixOffset) `)` `\n`)? + (`slidingWindowSize` `=` $slidingWindowSize^ `\n`)? (`causal` `\n` $causal^)? (`lse` `=` $lse^ `:` type($lse) `\n`)? (`qk` `=` `elementwise` (`otherIns` `(` $preSoftmaxElemWiseInputs^ `:` type($preSoftmaxElemWiseInputs) `)`)? $preSoftmaxBody^ `\n`)? @@ -583,7 +590,8 @@ def Rock_GridwiseAttentionAccelOp Optional>:$prefixOffset, MemRefRankOf<[F32, F16, BF16], [3]>:$out, Optional>:$lse, UnitAttr:$causal, - I32Attr:$splitKV, OptionalAttr:$features, + I32Attr:$splitKV, OptionalAttr:$slidingWindowSize, + OptionalAttr:$features, StoreMethodAttr:$storeMethod, I32Attr:$blockSize, I32Attr:$gridSize, UnitAttr:$disableQBypassLDS, OptionalAttr:$prePadG0M, OptionalAttr:$prePadG0N, diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index 2b074fe65d62..83c312fef7e2 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -1625,6 +1625,7 @@ struct AttentionMatcherValues { Value currentSeqLen; bool isCausal; Value prefixOffset; + std::optional slidingWindowSize; Type softmaxType; ElementwiseRegionFinder preSoftmaxElementwiseFinder; }; @@ -1971,8 +1972,9 @@ struct AttentionRewritePattern : public OpRewritePattern { // Result struct for sequence length mask detection struct SeqLenMaskResult { Value inputToContinue; // The value to continue pattern matching with - Value seqLen; // The sequence length - Value prefixOffset; // The prefix offset value + Value seqLen; // The sequence length + Value prefixOffset; // The prefix offset value + std::optional 
slidingWindowSize; // The sliding window size }; // Helper to try detecting prefix causal pattern: add(row_indices, offset) @@ -2035,7 +2037,9 @@ struct AttentionRewritePattern : public OpRewritePattern { tryKVCachePattern(Value input, const DenseSet &seqLenSkip) const { DenseSet expandAndCollapse{ tensor::CollapseShapeOp::getOperationName(), - tensor::ExpandShapeOp::getOperationName()}; + tensor::ExpandShapeOp::getOperationName(), + tosa::MaximumOp::getOperationName(), + tosa::MinimumOp::getOperationName()}; FailureOr maybeNonOne = mulBroadcast(input); if (failed(maybeNonOne)) return failure(); @@ -2062,6 +2066,57 @@ struct AttentionRewritePattern : public OpRewritePattern { return currentSeqLen; } + // Helper to try detecting sliding window pattern: + // greater(add(seqLen, negative_const_offset) * broadcast, col_indices) + FailureOr + trySlidingWindowPattern(Value input, + const DenseSet &seqLenSkip) const { + DenseSet expandAndCollapse{ + tensor::CollapseShapeOp::getOperationName(), + tensor::ExpandShapeOp::getOperationName()}; + + // Trace through broadcast multiplication (mul by 1) + FailureOr maybeNonOne = mulBroadcast(input); + if (failed(maybeNonOne)) + maybeNonOne = input; + + // Look for add(seqLen, constant_offset) + auto maybeAdd = getDefiningOpSkipping(maybeNonOne.value(), + expandAndCollapse); + if (failed(maybeAdd)) + return failure(); + + auto add = maybeAdd.value(); + + // One operand of the add is currentSeqLen (already tracked by KV-cache), + // the other is a negative constant (-windowSize). Try both operands. + auto tryExtractNegativeConst = [&](Value candidate) -> FailureOr { + auto maybeSkipped = + getValueSkipping(candidate, expandAndCollapse); + Value constVal = + succeeded(maybeSkipped) ? 
maybeSkipped.value() : candidate; + + DenseElementsAttr constAttr; + if (!matchPattern(constVal, m_Constant(&constAttr))) + return failure(); + if (!constAttr.getElementType().isInteger(32) || !constAttr.isSplat()) + return failure(); + + int32_t offset = constAttr.getSplatValue(); + if (offset >= 0) + return failure(); + return -static_cast(offset); + }; + + auto maybeWindowSize = tryExtractNegativeConst(add.getInput2()); + if (failed(maybeWindowSize)) + maybeWindowSize = tryExtractNegativeConst(add.getInput1()); + if (failed(maybeWindowSize)) + return failure(); + + return maybeWindowSize.value(); + } + /* LSE pattern for seqLen1 would be simplified from log(sum(exp(sub(x, x)))) + max(x) @@ -2231,26 +2286,42 @@ struct AttentionRewritePattern : public OpRewritePattern { auto greater = maybeGreater.value(); - // input1 must be column indices (constant range from 0) - if (failed(isConstantRange(greater.getInput1(), 0))) - return; + // Standard direction: greater(col_indices, value) + // Used for KV-cache and prefix-causal masks + if (succeeded(isConstantRange(greater.getInput1(), 0))) { + Value input2 = greater.getInput2(); - Value input2 = greater.getInput2(); + // Try KV-cache pattern (scalar seqLen) if not already found + if (!result.seqLen) { + auto maybeKVCache = tryKVCachePattern(input2, seqLenSkip); + if (succeeded(maybeKVCache)) { + result.seqLen = maybeKVCache.value(); + } + } - // Try KV-cache pattern (scalar seqLen) if not already found - if (!result.seqLen) { - auto maybeKVCache = tryKVCachePattern(input2, seqLenSkip); - if (succeeded(maybeKVCache)) { - result.seqLen = maybeKVCache.value(); + // Try prefix causal pattern (row_indices + offset) if not already found + if (!result.prefixOffset) { + auto maybePrefixCausal = tryPrefixCausalPattern(input2, seqLenSkip); + if (succeeded(maybePrefixCausal)) { + result.prefixOffset = maybePrefixCausal.value(); + } } + return; } - // Try prefix causal pattern (row_indices + offset) if not already found - if 
(!result.prefixOffset) { - auto maybePrefixCausal = tryPrefixCausalPattern(input2, seqLenSkip); - if (succeeded(maybePrefixCausal)) { - result.prefixOffset = maybePrefixCausal.value(); + // Reversed direction: greater(value, col_indices) + // Used for sliding window mask where value = seqLen + negative_offset + if (succeeded(isConstantRange(greater.getInput2(), 0))) { + Value input1 = greater.getInput1(); + + // Try sliding window pattern if not already found + if (!result.slidingWindowSize) { + auto maybeSlidingWindow = + trySlidingWindowPattern(input1, seqLenSkip); + if (succeeded(maybeSlidingWindow)) + result.slidingWindowSize = maybeSlidingWindow.value(); } + return; } } @@ -2270,31 +2341,39 @@ struct AttentionRewritePattern : public OpRewritePattern { DenseSet seqLenSkip{tensor::CollapseShapeOp::getOperationName(), tensor::ExpandShapeOp::getOperationName(), tosa::TransposeOp::getOperationName(), - tosa::MulOp::getOperationName()}; + tosa::MulOp::getOperationName(), + tosa::MaximumOp::getOperationName(), + tosa::MinimumOp::getOperationName()}; Value inputToContinue = select.getInput3(); - SeqLenMaskResult currentResult{inputToContinue, nullptr, nullptr}; + SeqLenMaskResult currentResult{inputToContinue, nullptr, nullptr, + std::nullopt}; // Analyze the first (outer) select analyzeSelectForSeqLenMask(select, currentResult, opsToSkip, seqLenSkip); - // Check if the inputToContinue (input3) is another chained select with -inf - // This handles the case where KVCache and prefix causal use separate - // selects. + // Check if the inputToContinue (input3) is another chained select with + // -inf. This handles cases where multiple mask patterns (KVCache, prefix + // causal, sliding window) use separate selects. 
bool haveSeqLen = currentResult.seqLen != nullptr; bool havePrefixOffset = currentResult.prefixOffset != nullptr; + bool haveSlidingWindow = currentResult.slidingWindowSize.has_value(); - if (haveSeqLen != havePrefixOffset) { + // Try chaining if we found at least one pattern but not all + bool foundAny = haveSeqLen || havePrefixOffset || haveSlidingWindow; + bool foundAll = haveSeqLen && havePrefixOffset && haveSlidingWindow; + if (foundAny && !foundAll) { auto maybeChainedSelect = getSelectWithNegInf(inputToContinue); if (succeeded(maybeChainedSelect)) { auto chainedSelect = maybeChainedSelect.value(); // Try to analyze the chained select for the missing pattern analyzeSelectForSeqLenMask(chainedSelect, currentResult, opsToSkip, seqLenSkip); - // Only update inputToContinue if we found the complementary pattern + // Only update inputToContinue if we found a complementary pattern bool foundComplementary = (!haveSeqLen && currentResult.seqLen) || - (!havePrefixOffset && currentResult.prefixOffset); + (!havePrefixOffset && currentResult.prefixOffset) || + (!haveSlidingWindow && currentResult.slidingWindowSize.has_value()); if (foundComplementary) { currentResult.inputToContinue = chainedSelect.getInput3(); } @@ -2302,7 +2381,8 @@ struct AttentionRewritePattern : public OpRewritePattern { } // We need at least one pattern to be detected - if (!currentResult.seqLen && !currentResult.prefixOffset) + if (!currentResult.seqLen && !currentResult.prefixOffset && + !currentResult.slidingWindowSize) return failure(); return currentResult; @@ -2890,16 +2970,18 @@ struct AttentionRewritePattern : public OpRewritePattern { return failure(); } - // Detect sequence length masking patterns (KV-cache or prefix causal) - // Note that non KV-Cache fusions might have tosa.select - // so, if the checks fail, we just keep going + // Detect sequence length masking patterns (KV-cache, prefix causal, + // or sliding window). 
Note that non KV-Cache fusions might have + // tosa.select so, if the checks fail, we just keep going Value kvCacheInput, currentSeqLen, prefixOffset; + std::optional slidingWindowSize; auto maybeSeqLenMask = getSeqLenMask(softmaxInput); if (succeeded(maybeSeqLenMask)) { auto result = maybeSeqLenMask.value(); kvCacheInput = result.inputToContinue; currentSeqLen = result.seqLen; prefixOffset = result.prefixOffset; + slidingWindowSize = result.slidingWindowSize; } else { kvCacheInput = softmaxInput; } @@ -2946,6 +3028,12 @@ struct AttentionRewritePattern : public OpRewritePattern { LLVM_DEBUG(llvm::dbgs() << "isCausal = " << isCausal << "\n"); LLVM_DEBUG(llvm::dbgs() << "isPrefixCausal = " << (bool)prefixOffset << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "isSlidingWindow = " << slidingWindowSize.has_value() + << (slidingWindowSize ? + " (size=" + std::to_string(*slidingWindowSize) + ")" : + "") + << "\n"); if (isDotProduct && hasReduceOp) return failure(); if (!isDotProduct && !hasReduceOp) @@ -2972,6 +3060,7 @@ struct AttentionRewritePattern : public OpRewritePattern { AttentionMatcherValues attentionMatcherValues; attentionMatcherValues.isCausal = isCausal; attentionMatcherValues.prefixOffset = prefixOffset; + attentionMatcherValues.slidingWindowSize = slidingWindowSize; attentionMatcherValues.softmaxType = softmaxType; attentionMatcherValues.softmaxValues = softmaxMatcherValues; attentionMatcherValues.lse = lse; @@ -3072,6 +3161,11 @@ struct AttentionRewritePattern : public OpRewritePattern { std::tie(queries, keys, values, numHeadsQ, numHeadsKV) = getGQAValues( rewriter, firstMatMulOp.getA(), firstMatMulOp.getB(), op.getB()); + IntegerAttr slidingWindowSizeAttr; + if (attentionMatcherValues.slidingWindowSize.has_value()) + slidingWindowSizeAttr = rewriter.getI32IntegerAttr( + attentionMatcherValues.slidingWindowSize.value()); + rock::AttentionOp attnOp = rock::AttentionOp::create( rewriter, loc, outputType, lseType, queries, keys, values, elementwiseOtherArgs, 
currentSeqLen, prefixOffset, output, lseOut, @@ -3082,6 +3176,7 @@ struct AttentionRewritePattern : public OpRewritePattern { /*vTransposed=*/nullptr, /*oTransposed=*/nullptr, causalAttr, /*splitKV=*/rewriter.getI32IntegerAttr(1), + slidingWindowSizeAttr, /*features=*/nullptr, rewriter.getAttr(rock::StoreMethod::Set), softmaxTypeAttr, diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 5b7c3131817e..a9437f34eca5 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -2840,6 +2840,16 @@ LogicalResult GridwiseAttentionAccelOp::verify() { "prefixOffset requires causal to be enabled. " "Prefix causal attention is causal masking with an offset."); + // Validate sliding window constraints + if (getSlidingWindowSize()) { + int32_t windowSize = static_cast(*getSlidingWindowSize()); + if (windowSize <= 0) + return emitError("slidingWindowSize must be positive"); + if (!getCurrentSeqLen()) + return emitError( + "slidingWindowSize requires currentSeqLen to be set"); + } + return success(); } @@ -3382,6 +3392,16 @@ LogicalResult AttentionOp::verify() { "prefixOffset requires causal to be enabled. 
" "Prefix causal attention is causal masking with an offset."); + // Validate sliding window constraints + if (getSlidingWindowSize()) { + int32_t windowSize = static_cast(*getSlidingWindowSize()); + if (windowSize <= 0) + return emitError("slidingWindowSize must be positive"); + if (!getCurrentSeqLen()) + return emitError( + "slidingWindowSize requires currentSeqLen to be set"); + } + return verifyGemmPlusGemmLikeOp(*this, getCurrentSeqLen(), getLse(), getNumHeadsQ(), getNumHeadsKV()); } diff --git a/mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp b/mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp index 1067f4abe6ef..7414b5ef3369 100644 --- a/mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp @@ -462,7 +462,8 @@ struct DetectFlashDecodingPattern : public OpRewritePattern { op.getNumHeadsKVAttr(), op.getQTransposedAttr(), op.getKTransposedAttr(), op.getVTransposedAttr(), op.getOTransposedAttr(), op.getCausalAttr(), - rewriter.getI32IntegerAttr(splitKVFromQ), op.getFeaturesAttr(), + rewriter.getI32IntegerAttr(splitKVFromQ), + op.getSlidingWindowSizeAttr(), op.getFeaturesAttr(), op.getStoreMethodAttr(), op.getSoftmaxTypeAttr(), op.getParams0Attr(), op.getParams1Attr(), op.getFirstGemmIndicesAttr(), /*preSoftmaxHasSplitKVTransforms=*/rewriter.getBoolAttr(true)); diff --git a/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp index 52420e10f2bd..506a0be67118 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp @@ -488,8 +488,9 @@ static LogicalResult commonAttentionGemmElmtGemm( ConversionPatternRewriter &rw, RockGemmGemmWrapperInterface op, Value a, Value b, Value c, Value out, Value lse, Value currentSeqLen, Value prefixOffset, UnitAttr causal, IntegerAttr splitKV, - ValueRange elementwiseInputs, Region &preSecondOpRegion, bool enableSoftmax, - TypeAttr 
softmaxType, int64_t numHeadsQ, int64_t numHeadsKV, + IntegerAttr slidingWindowSize, ValueRange elementwiseInputs, + Region &preSecondOpRegion, bool enableSoftmax, TypeAttr softmaxType, + int64_t numHeadsQ, int64_t numHeadsKV, std::optional> bufferDeps, BoolAttr preSoftmaxHasSplitKVTransforms) { @@ -611,7 +612,8 @@ static LogicalResult commonAttentionGemmElmtGemm( auto newOp = GridwiseAttentionAccelOp::create( rw, loc, a, b, c, elementwiseInputs, currentSeqLen, prefixOffset, out, - lse, causal, splitKV, op.getGemmFeaturesAttr(), op.getStoreMethodAttr(), + lse, causal, splitKV, slidingWindowSize, + op.getGemmFeaturesAttr(), op.getStoreMethodAttr(), blockSizeAttr, gridSizeAttr, /*disableQBypassLDS=*/nullptr, prePadG0MAttr, prePadG0NAttr, numRepeatsGQA, softmaxType, params0, params1, @@ -1088,8 +1090,8 @@ AttentionRewritePattern::matchAndRewrite(AttentionOp op, rw, op, adaptor.getQueries(), adaptor.getKeys(), adaptor.getValues(), adaptor.getOut(), adaptor.getLse(), adaptor.getCurrentSeqLen(), adaptor.getPrefixOffset(), adaptor.getCausalAttr(), - adaptor.getSplitKVAttr(), adaptor.getPreSoftmaxElemWiseInputs(), - op.getPreSoftmaxBody(), + adaptor.getSplitKVAttr(), adaptor.getSlidingWindowSizeAttr(), + adaptor.getPreSoftmaxElemWiseInputs(), op.getPreSoftmaxBody(), /*enableSoftmax=*/true, op.getSoftmaxTypeAttr(), adaptor.getNumHeadsQ(), adaptor.getNumHeadsKV(), /*bufferDeps=*/std::nullopt, @@ -1104,7 +1106,8 @@ LogicalResult GemmElementwiseGemmRewritePattern::matchAndRewrite( rw, op, adaptor.getA(), adaptor.getB(), adaptor.getC(), adaptor.getOut(), /*lse=*/nullptr, /*currentSeqLen=*/nullptr, /*prefixOffset=*/nullptr, /*causal=*/nullptr, - splitKV, adaptor.getElemwiseInputs(), op.getPreSecondGemmBody(), + splitKV, /*slidingWindowSize=*/nullptr, adaptor.getElemwiseInputs(), + op.getPreSecondGemmBody(), /*enableSoftmax=*/false, /*softmaxType=*/nullptr, /*numHeadsQ=*/1, /*numHeadsKV=*/1, std::cref(bufferDeps), /*preSoftmaxHasSplitKVTransforms=*/rw.getBoolAttr(false)); diff 
--git a/mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp b/mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp index a148f67fc52d..9d8ef68204cc 100644 --- a/mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp @@ -622,7 +622,8 @@ struct AttentionRewritePattern : public OpRewritePattern { op.getPrefixOffset(), op.getOut(), op.getLse(), op.getNumHeadsQAttr(), op.getNumHeadsKVAttr(), transposedQ, transposedK, transposedV, op.getOTransposedAttr(), op.getCausalAttr(), op.getSplitKVAttr(), - op.getFeaturesAttr(), op.getStoreMethodAttr(), op.getSoftmaxTypeAttr(), + op.getSlidingWindowSizeAttr(), op.getFeaturesAttr(), + op.getStoreMethodAttr(), op.getSoftmaxTypeAttr(), op.getParams0Attr(), op.getParams1Attr(), op.getFirstGemmIndicesAttr(), op.getPreSoftmaxHasSplitKVTransformsAttr()); diff --git a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp index 335d3f4c5ebb..66858eca4bcc 100644 --- a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp +++ b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp @@ -3379,6 +3379,7 @@ static func::FuncOp createGpuAttentionKernel(ModuleOp module, currentSeqLenTensor, prefixOffsetTensor, output, lse, numHeadsQ, numHeadsKV, transposeQ, transposeK, transposeV, transposeO, actualCausal, splitKV, + /*slidingWindowSize=*/nullptr, rock::GemmFeaturesAttr::get(builder.getContext(), params.features), storeMethod, softmaxType, /*params0=*/nullptr, /*params1=*/nullptr, From ff90aebc63f9ccdb4e4437f79ea687e00e933043 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Thu, 12 Feb 2026 22:14:31 +0000 Subject: [PATCH 02/12] Add GridwiseGemmToBlockwise masking support --- .../Transforms/GridwiseGemmToBlockwise.cpp | 73 ++++++++++++++++--- 1 file changed, 61 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 
8c91c4083bb2..2b8099256a92 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -1256,17 +1256,18 @@ struct GridwiseAttentionAccelRewritePattern } } - enum class OutOfScopeType { KVCache, Causal, PrefixCausal }; + enum class OutOfScopeType { KVCache, Causal, PrefixCausal, SlidingWindow }; void setGemm0OutputOutOfScope( PatternRewriter &rewriter, Location loc, OutOfScopeType outOfScopeType, layout::GridCoordinates gridCoords, Value gemm0OutBuffer, RegsAsMatrixSubTiles gemm0OutSubTileViews, bool enabled, Value mLoopIV, Value gemm0MBlocksLastIter, Value currentSeqLen, Value prefixOffset, - IntegerAttr numRepeatsGQA) const { + IntegerAttr numRepeatsGQA, + Value slidingWindowLowerBound) const { if (enabled) { // For KVCache, we only need to mask on the last iteration, but for causal - // masking we need to mask on every iteration. + // and sliding window masking we need to mask on every iteration. bool needsLastIterCheck = (outOfScopeType == OutOfScopeType::KVCache); // Use a lambda to generate the masking logic. @@ -1338,6 +1339,15 @@ struct GridwiseAttentionAccelRewritePattern mIndex, threshold); break; } + case OutOfScopeType::SlidingWindow: { + // Sliding window: mask when key_pos < max(0, currentSeqLen - + // windowSize). slidingWindowLowerBound is precomputed as + // max(0, currentSeqLen - windowSize). 
+ assert(slidingWindowLowerBound != nullptr); + isInvalid = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ult, + mIndex, slidingWindowLowerBound); + break; + } } scf::IfOp ifOp = scf::IfOp::create(b, loc, isInvalid, @@ -1738,17 +1748,19 @@ struct GridwiseAttentionAccelRewritePattern return viewBuilder.get(); } - std::tuple + std::tuple getMLoopInfo(PatternRewriter &rewriter, Location loc, layout::AttnGridCoordinates gridCoordsGemm0, Value currentSeqLenTensor, Value prefixOffsetTensor, int64_t gemm0M, int64_t gemm0N, int64_t gemm0MPerBlock, int64_t gemm0NPerBlock, int64_t splitKV, bool isCausal, bool isKVCache, bool isPrefixCausal, + int64_t slidingWindowSize, IntegerAttr numRepeatsGQA = nullptr) const { Value gemm0MBlocksLastIter; Value currentSeqLen; Value prefixOffset; + Value slidingWindowLowerBound; Value effectiveSeqLen; Value start, end; @@ -1792,13 +1804,27 @@ struct GridwiseAttentionAccelRewritePattern loc, rewriter.getIndexType(), loadedValue); }; - // This is needed for KV Cache/Causal/Prefix Causal masking support - if (isCausal || isKVCache || isPrefixCausal) { + // This is needed for KV Cache/Causal/Prefix Causal/Sliding Window masking + if (isCausal || isKVCache || isPrefixCausal || slidingWindowSize > 0) { if (isKVCache) { currentSeqLen = loadTensorValue(currentSeqLenTensor); effectiveSeqLen = currentSeqLen; } + // Compute sliding window lower bound: max(0, currentSeqLen - windowSize) + if (slidingWindowSize > 0) { + assert(currentSeqLen != nullptr && + "sliding window requires currentSeqLen (KV-cache)"); + Value constWindowSize = + rewriter.createOrFold(loc, + slidingWindowSize); + Value zero = rewriter.createOrFold(loc, 0); + Value lowerBound = + arith::SubIOp::create(rewriter, loc, currentSeqLen, constWindowSize); + slidingWindowLowerBound = + arith::MaxSIOp::create(rewriter, loc, lowerBound, zero); + } + if (isCausal || isPrefixCausal) { // Compute the last Q position in the block. // (nIndex + 1) * NPerBlock - 1. 
@@ -1907,7 +1933,7 @@ struct GridwiseAttentionAccelRewritePattern end = rewriter.createOrFold(loc, gemm0MBlocks); } return std::make_tuple(start, end, gemm0MBlocksLastIter, currentSeqLen, - prefixOffset); + prefixOffset, slidingWindowLowerBound); } // Helper function to determine if early exit optimization is possible. @@ -2043,6 +2069,10 @@ struct GridwiseAttentionAccelRewritePattern bool isKVCache = currentSeqLenTensor != nullptr; bool isCausal = op.getCausal(); bool isPrefixCausal = isCausal && prefixOffsetTensor; + int64_t slidingWindowSize = + op.getSlidingWindowSize().has_value() + ? static_cast(*op.getSlidingWindowSize()) + : 0; int64_t splitKV = op.getSplitKV(); // Gemm0 out is casted to be softmaxType (if null, it's casted to elemTypeV) @@ -2451,13 +2481,16 @@ struct GridwiseAttentionAccelRewritePattern Value gemm0MBlocksLastIter; Value currentSeqLen; Value prefixOffset; + Value slidingWindowLowerBound; Value start, end; // get mLoop - std::tie(start, end, gemm0MBlocksLastIter, currentSeqLen, prefixOffset) = + std::tie(start, end, gemm0MBlocksLastIter, currentSeqLen, prefixOffset, + slidingWindowLowerBound) = getMLoopInfo(rewriter, loc, gridCoordsGemm0mIter0, currentSeqLenTensor, prefixOffsetTensor, gemm0M, gemm0N, gemm0MPerBlock, gemm0NPerBlock, splitKV, isCausal, isKVCache, - isPrefixCausal, op.getNumRepeatsGQAAttr()); + isPrefixCausal, slidingWindowSize, + op.getNumRepeatsGQAAttr()); // Early exit: Skip all computation when there's no work but always write // output. 
@@ -2778,7 +2811,8 @@ struct GridwiseAttentionAccelRewritePattern gemm0OutSubTileViewsTr, isKVCache, mLoopIV, gemm0MBlocksLastIter, currentSeqLen, /*prefixOffset=*/nullptr, - /*numRepeatsGQA=*/nullptr); + /*numRepeatsGQA=*/nullptr, + /*slidingWindowLowerBound=*/nullptr); } // Causal masking: either prefix-causal or standard causal @@ -2790,7 +2824,8 @@ struct GridwiseAttentionAccelRewritePattern gemm0OutSubTileViewsTr, isPrefixCausal, mLoopIV, gemm0MBlocksLastIter, /*currentSeqLen=*/nullptr, prefixOffset, - op.getNumRepeatsGQAAttr()); + op.getNumRepeatsGQAAttr(), + /*slidingWindowLowerBound=*/nullptr); } else if (isCausal) { // Standard causal masking: mask when key > query setGemm0OutputOutOfScope( @@ -2798,7 +2833,21 @@ struct GridwiseAttentionAccelRewritePattern softmaxInputBuffer, gemm0OutSubTileViewsTr, isCausal, mLoopIV, gemm0MBlocksLastIter, /*currentSeqLen=*/nullptr, - /*prefixOffset=*/nullptr, op.getNumRepeatsGQAAttr()); + /*prefixOffset=*/nullptr, op.getNumRepeatsGQAAttr(), + /*slidingWindowLowerBound=*/nullptr); + } + + // Sliding window masking: mask when key_pos < max(0, currentSeqLen - + // windowSize). This is independent of causal masking and applies + // alongside KV-cache masking. 
+ if (slidingWindowSize > 0) { + setGemm0OutputOutOfScope( + rewriter, loc, OutOfScopeType::SlidingWindow, gridCoordsGemm0, + softmaxInputBuffer, gemm0OutSubTileViewsTr, slidingWindowSize > 0, + mLoopIV, gemm0MBlocksLastIter, + /*currentSeqLen=*/nullptr, + /*prefixOffset=*/nullptr, /*numRepeatsGQA=*/nullptr, + slidingWindowLowerBound); } APInt reductionAxis = APInt(64, 1); From 2eda9a71a3a1e83bf04e1483fb898216a2a9f7c8 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Thu, 12 Feb 2026 22:49:52 +0000 Subject: [PATCH 03/12] Small bug fixes --- mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | 159 ++++++++++++++++-- 1 file changed, 149 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index 83c312fef7e2..3eaf114c0de9 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -1626,6 +1626,8 @@ struct AttentionMatcherValues { bool isCausal; Value prefixOffset; std::optional slidingWindowSize; + std::optional seqLenClipMin; + std::optional seqLenClipMax; Type softmaxType; ElementwiseRegionFinder preSoftmaxElementwiseFinder; }; @@ -1975,6 +1977,10 @@ struct AttentionRewritePattern : public OpRewritePattern { Value seqLen; // The sequence length Value prefixOffset; // The prefix offset value std::optional slidingWindowSize; // The sliding window size + // Clip bounds detected on currentSeqLen during sliding window matching. + // When present, currentSeqLen should be clipped to [clipMin, clipMax]. 
+ std::optional seqLenClipMin; + std::optional seqLenClipMax; }; // Helper to try detecting prefix causal pattern: add(row_indices, offset) @@ -2066,9 +2072,83 @@ struct AttentionRewritePattern : public OpRewritePattern { return currentSeqLen; } + // Result of sliding window pattern detection + struct SlidingWindowResult { + int64_t windowSize; + std::optional clipMin; + std::optional clipMax; + }; + + // Struct for clip detection result + struct ClipBounds { + int32_t clipMin; + int32_t clipMax; + }; + + // Helper to detect a clip pattern on a value: + // tosa.minimum(tosa.maximum(x, constLo), constHi) + // This is how migraphx.clip(x, lo, hi) is lowered to TOSA. + // Returns the clip bounds if the pattern is detected. + FailureOr tryDetectClipPattern(Value input) const { + DenseSet expandAndCollapse{ + tensor::CollapseShapeOp::getOperationName(), + tensor::ExpandShapeOp::getOperationName()}; + + // Helper to extract a splat i32 constant from a value + auto extractI32Constant = [&](Value val) -> std::optional { + auto maybeSkipped = getValueSkipping(val, expandAndCollapse); + Value v = succeeded(maybeSkipped) ? 
maybeSkipped.value() : val; + DenseElementsAttr attr; + if (!matchPattern(v, m_Constant(&attr))) + return std::nullopt; + if (!attr.getElementType().isInteger(32) || !attr.isSplat()) + return std::nullopt; + return attr.getSplatValue(); + }; + + // Look for tosa.minimum (the outer clip op) + auto maybeMin = + getDefiningOpSkipping(input, expandAndCollapse); + if (failed(maybeMin)) + return failure(); + auto minOp = maybeMin.value(); + + // One input of minimum is a constant (clipMax), the other is maximum + Value maxCandidate; + std::optional clipMax; + clipMax = extractI32Constant(minOp.getInput2()); + if (clipMax) { + maxCandidate = minOp.getInput1(); + } else { + clipMax = extractI32Constant(minOp.getInput1()); + if (clipMax) + maxCandidate = minOp.getInput2(); + else + return failure(); + } + + // Look for tosa.maximum (the inner clip op) + auto maybeMax = + getDefiningOpSkipping(maxCandidate, expandAndCollapse); + if (failed(maybeMax)) + return failure(); + auto maxOp = maybeMax.value(); + + // One input of maximum is a constant (clipMin) + std::optional clipMin; + clipMin = extractI32Constant(maxOp.getInput2()); + if (!clipMin) + clipMin = extractI32Constant(maxOp.getInput1()); + if (!clipMin) + return failure(); + + return ClipBounds{*clipMin, *clipMax}; + } + // Helper to try detecting sliding window pattern: // greater(add(seqLen, negative_const_offset) * broadcast, col_indices) - FailureOr + // Also detects an optional clip (min(max(x, lo), hi)) on the seqLen operand. + FailureOr trySlidingWindowPattern(Value input, const DenseSet &seqLenSkip) const { DenseSet expandAndCollapse{ @@ -2090,9 +2170,10 @@ struct AttentionRewritePattern : public OpRewritePattern { // One operand of the add is currentSeqLen (already tracked by KV-cache), // the other is a negative constant (-windowSize). Try both operands. 
- auto tryExtractNegativeConst = [&](Value candidate) -> FailureOr { - auto maybeSkipped = - getValueSkipping(candidate, expandAndCollapse); + Value seqLenOperand; + auto tryExtractNegativeConst = [&](Value candidate, + Value other) -> FailureOr { + auto maybeSkipped = getValueSkipping(candidate, expandAndCollapse); Value constVal = succeeded(maybeSkipped) ? maybeSkipped.value() : candidate; @@ -2105,16 +2186,33 @@ struct AttentionRewritePattern : public OpRewritePattern { int32_t offset = constAttr.getSplatValue(); if (offset >= 0) return failure(); + seqLenOperand = other; return -static_cast(offset); }; - auto maybeWindowSize = tryExtractNegativeConst(add.getInput2()); + auto maybeWindowSize = + tryExtractNegativeConst(add.getInput2(), add.getInput1()); if (failed(maybeWindowSize)) - maybeWindowSize = tryExtractNegativeConst(add.getInput1()); + maybeWindowSize = + tryExtractNegativeConst(add.getInput1(), add.getInput2()); if (failed(maybeWindowSize)) return failure(); - return maybeWindowSize.value(); + SlidingWindowResult result; + result.windowSize = maybeWindowSize.value(); + + // Try to detect a clip on the seqLen operand of the add. + // If detected, store the clip bounds so the rewrite phase can apply + // them after broadcasting currentSeqLen to the correct batch dims. 
+ if (seqLenOperand) { + auto maybeClip = tryDetectClipPattern(seqLenOperand); + if (succeeded(maybeClip)) { + result.clipMin = maybeClip->clipMin; + result.clipMax = maybeClip->clipMax; + } + } + + return result; } /* @@ -2318,8 +2416,12 @@ struct AttentionRewritePattern : public OpRewritePattern { if (!result.slidingWindowSize) { auto maybeSlidingWindow = trySlidingWindowPattern(input1, seqLenSkip); - if (succeeded(maybeSlidingWindow)) - result.slidingWindowSize = maybeSlidingWindow.value(); + if (succeeded(maybeSlidingWindow)) { + auto swResult = maybeSlidingWindow.value(); + result.slidingWindowSize = swResult.windowSize; + result.seqLenClipMin = swResult.clipMin; + result.seqLenClipMax = swResult.clipMax; + } } return; } @@ -2347,7 +2449,7 @@ struct AttentionRewritePattern : public OpRewritePattern { Value inputToContinue = select.getInput3(); SeqLenMaskResult currentResult{inputToContinue, nullptr, nullptr, - std::nullopt}; + std::nullopt, std::nullopt, std::nullopt}; // Analyze the first (outer) select analyzeSelectForSeqLenMask(select, currentResult, opsToSkip, seqLenSkip); @@ -2975,6 +3077,7 @@ struct AttentionRewritePattern : public OpRewritePattern { // tosa.select so, if the checks fail, we just keep going Value kvCacheInput, currentSeqLen, prefixOffset; std::optional slidingWindowSize; + std::optional seqLenClipMin, seqLenClipMax; auto maybeSeqLenMask = getSeqLenMask(softmaxInput); if (succeeded(maybeSeqLenMask)) { auto result = maybeSeqLenMask.value(); @@ -2982,6 +3085,8 @@ struct AttentionRewritePattern : public OpRewritePattern { currentSeqLen = result.seqLen; prefixOffset = result.prefixOffset; slidingWindowSize = result.slidingWindowSize; + seqLenClipMin = result.seqLenClipMin; + seqLenClipMax = result.seqLenClipMax; } else { kvCacheInput = softmaxInput; } @@ -3061,6 +3166,8 @@ struct AttentionRewritePattern : public OpRewritePattern { attentionMatcherValues.isCausal = isCausal; attentionMatcherValues.prefixOffset = prefixOffset; 
attentionMatcherValues.slidingWindowSize = slidingWindowSize; + attentionMatcherValues.seqLenClipMin = seqLenClipMin; + attentionMatcherValues.seqLenClipMax = seqLenClipMax; attentionMatcherValues.softmaxType = softmaxType; attentionMatcherValues.softmaxValues = softmaxMatcherValues; attentionMatcherValues.lse = lse; @@ -3151,6 +3258,38 @@ struct AttentionRewritePattern : public OpRewritePattern { prepareBlockArgTensor(currentSeqLen); prepareBlockArgTensor(prefixOffset); + // Apply seqLen clip if detected during sliding window pattern matching. + // The original model may have clip(arg, lo, hi) on currentSeqLen which + // was traced through to reach the block argument (for correct broadcast). + // We re-apply the clip here so the GPU kernel uses the clipped value, + // matching the original model's behavior. + if (currentSeqLen && + (attentionMatcherValues.seqLenClipMin.has_value() || + attentionMatcherValues.seqLenClipMax.has_value())) { + auto seqLenType = cast(currentSeqLen.getType()); + auto elemTy = seqLenType.getElementType(); + if (attentionMatcherValues.seqLenClipMin.has_value()) { + auto minAttr = DenseElementsAttr::get( + seqLenType, + rewriter.getIntegerAttr(elemTy, + *attentionMatcherValues.seqLenClipMin)); + Value clipMinConst = + tosa::ConstOp::create(rewriter, loc, seqLenType, minAttr); + currentSeqLen = tosa::MaximumOp::create(rewriter, loc, seqLenType, + currentSeqLen, clipMinConst); + } + if (attentionMatcherValues.seqLenClipMax.has_value()) { + auto maxAttr = DenseElementsAttr::get( + seqLenType, + rewriter.getIntegerAttr(elemTy, + *attentionMatcherValues.seqLenClipMax)); + Value clipMaxConst = + tosa::ConstOp::create(rewriter, loc, seqLenType, maxAttr); + currentSeqLen = tosa::MinimumOp::create(rewriter, loc, seqLenType, + currentSeqLen, clipMaxConst); + } + } + UnitAttr causalAttr = isCausal ? 
rewriter.getUnitAttr() : nullptr; ElementwiseRegionFinder elemwiseRegion = attentionMatcherValues.preSoftmaxElementwiseFinder; From 2dd857fcb22ca2f467f692d957c7398ce5be7de1 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Tue, 17 Feb 2026 14:09:57 +0000 Subject: [PATCH 04/12] Small refactor --- mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | 62 +++++++++++-------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index 3eaf114c0de9..a5050c2cc0f1 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -2037,11 +2037,23 @@ struct AttentionRewritePattern : public OpRewritePattern { return unwrappedOffset; } - // Helper to try detecting KV-cache pattern - // Returns the seqLen value if successful - FailureOr + // Result of KV-cache pattern detection + struct KVCacheResult { + Value seqLen; + std::optional clipMin; + std::optional clipMax; + }; + + // Helper to try detecting KV-cache pattern. + // Also detects an optional clip (min(max(x, lo), hi)) on currentSeqLen. + // The clip is a property of currentSeqLen itself (applied before any mask + // uses it), so it is detected here rather than in mask-specific matchers. + FailureOr tryKVCachePattern(Value input, const DenseSet &seqLenSkip) const { DenseSet expandAndCollapse{ + tensor::CollapseShapeOp::getOperationName(), + tensor::ExpandShapeOp::getOperationName()}; + DenseSet expandCollapseMinMax{ tensor::CollapseShapeOp::getOperationName(), tensor::ExpandShapeOp::getOperationName(), tosa::MaximumOp::getOperationName(), @@ -2060,8 +2072,19 @@ struct AttentionRewritePattern : public OpRewritePattern { !llvm::all_of(shape.slice(2), [](int32_t v) { return v == 1; })) return failure(); + // Try to detect a clip pattern on currentSeqLen before skipping through + // min/max. 
The clip (min(max(x, lo), hi)) may wrap the block argument + // and applies to all masks that use currentSeqLen. + KVCacheResult result; + auto maybeClip = tryDetectClipPattern(maybeNonOne.value()); + if (succeeded(maybeClip)) { + result.clipMin = maybeClip->clipMin; + result.clipMax = maybeClip->clipMax; + } + + // Skip through expand/collapse/min/max to reach the block argument auto maybeCurrentSeqLen = - getValueSkipping(maybeNonOne.value(), expandAndCollapse); + getValueSkipping(maybeNonOne.value(), expandCollapseMinMax); assert(succeeded(maybeCurrentSeqLen) && "Must have non-reshape op"); Value currentSeqLen = maybeCurrentSeqLen.value(); @@ -2069,14 +2092,13 @@ struct AttentionRewritePattern : public OpRewritePattern { if (!isI32BlockArgument(currentSeqLen, seqLenSkip)) return failure(); - return currentSeqLen; + result.seqLen = currentSeqLen; + return result; } // Result of sliding window pattern detection struct SlidingWindowResult { int64_t windowSize; - std::optional clipMin; - std::optional clipMax; }; // Struct for clip detection result @@ -2201,17 +2223,6 @@ struct AttentionRewritePattern : public OpRewritePattern { SlidingWindowResult result; result.windowSize = maybeWindowSize.value(); - // Try to detect a clip on the seqLen operand of the add. - // If detected, store the clip bounds so the rewrite phase can apply - // them after broadcasting currentSeqLen to the correct batch dims. 
- if (seqLenOperand) { - auto maybeClip = tryDetectClipPattern(seqLenOperand); - if (succeeded(maybeClip)) { - result.clipMin = maybeClip->clipMin; - result.clipMax = maybeClip->clipMax; - } - } - return result; } @@ -2393,7 +2404,10 @@ struct AttentionRewritePattern : public OpRewritePattern { if (!result.seqLen) { auto maybeKVCache = tryKVCachePattern(input2, seqLenSkip); if (succeeded(maybeKVCache)) { - result.seqLen = maybeKVCache.value(); + auto kvResult = maybeKVCache.value(); + result.seqLen = kvResult.seqLen; + result.seqLenClipMin = kvResult.clipMin; + result.seqLenClipMax = kvResult.clipMax; } } @@ -2417,10 +2431,7 @@ struct AttentionRewritePattern : public OpRewritePattern { auto maybeSlidingWindow = trySlidingWindowPattern(input1, seqLenSkip); if (succeeded(maybeSlidingWindow)) { - auto swResult = maybeSlidingWindow.value(); - result.slidingWindowSize = swResult.windowSize; - result.seqLenClipMin = swResult.clipMin; - result.seqLenClipMax = swResult.clipMax; + result.slidingWindowSize = maybeSlidingWindow->windowSize; } } return; @@ -3258,9 +3269,10 @@ struct AttentionRewritePattern : public OpRewritePattern { prepareBlockArgTensor(currentSeqLen); prepareBlockArgTensor(prefixOffset); - // Apply seqLen clip if detected during sliding window pattern matching. + // Apply seqLen clip if detected during KV-cache pattern matching. // The original model may have clip(arg, lo, hi) on currentSeqLen which - // was traced through to reach the block argument (for correct broadcast). + // was traced through to reach the block argument. The clip is a property + // of currentSeqLen itself, used by all masks (KV-cache, sliding window). // We re-apply the clip here so the GPU kernel uses the clipped value, // matching the original model's behavior. 
if (currentSeqLen && From 04c00484e106e03db78ea8d9f043fad2799f07e4 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Tue, 17 Feb 2026 14:18:09 +0000 Subject: [PATCH 05/12] More refactoring --- mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | 29 +++++-------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index a5050c2cc0f1..dd2b43004a3b 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -1977,8 +1977,7 @@ struct AttentionRewritePattern : public OpRewritePattern { Value seqLen; // The sequence length Value prefixOffset; // The prefix offset value std::optional slidingWindowSize; // The sliding window size - // Clip bounds detected on currentSeqLen during sliding window matching. - // When present, currentSeqLen should be clipped to [clipMin, clipMax]. + // Clip bounds detected on currentSeqLen during KV-cache pattern matching. std::optional seqLenClipMin; std::optional seqLenClipMax; }; @@ -2046,8 +2045,6 @@ struct AttentionRewritePattern : public OpRewritePattern { // Helper to try detecting KV-cache pattern. // Also detects an optional clip (min(max(x, lo), hi)) on currentSeqLen. - // The clip is a property of currentSeqLen itself (applied before any mask - // uses it), so it is detected here rather than in mask-specific matchers. FailureOr tryKVCachePattern(Value input, const DenseSet &seqLenSkip) const { DenseSet expandAndCollapse{ @@ -2076,7 +2073,7 @@ struct AttentionRewritePattern : public OpRewritePattern { // min/max. The clip (min(max(x, lo), hi)) may wrap the block argument // and applies to all masks that use currentSeqLen. 
KVCacheResult result; - auto maybeClip = tryDetectClipPattern(maybeNonOne.value()); + auto maybeClip = tryClipPattern(maybeNonOne.value()); if (succeeded(maybeClip)) { result.clipMin = maybeClip->clipMin; result.clipMax = maybeClip->clipMax; @@ -2096,11 +2093,6 @@ struct AttentionRewritePattern : public OpRewritePattern { return result; } - // Result of sliding window pattern detection - struct SlidingWindowResult { - int64_t windowSize; - }; - // Struct for clip detection result struct ClipBounds { int32_t clipMin; @@ -2109,9 +2101,7 @@ struct AttentionRewritePattern : public OpRewritePattern { // Helper to detect a clip pattern on a value: // tosa.minimum(tosa.maximum(x, constLo), constHi) - // This is how migraphx.clip(x, lo, hi) is lowered to TOSA. - // Returns the clip bounds if the pattern is detected. - FailureOr tryDetectClipPattern(Value input) const { + FailureOr tryClipPattern(Value input) const { DenseSet expandAndCollapse{ tensor::CollapseShapeOp::getOperationName(), tensor::ExpandShapeOp::getOperationName()}; @@ -2169,8 +2159,8 @@ struct AttentionRewritePattern : public OpRewritePattern { // Helper to try detecting sliding window pattern: // greater(add(seqLen, negative_const_offset) * broadcast, col_indices) - // Also detects an optional clip (min(max(x, lo), hi)) on the seqLen operand. - FailureOr + // Returns the window size if successful. 
+ FailureOr trySlidingWindowPattern(Value input, const DenseSet &seqLenSkip) const { DenseSet expandAndCollapse{ @@ -2220,10 +2210,7 @@ struct AttentionRewritePattern : public OpRewritePattern { if (failed(maybeWindowSize)) return failure(); - SlidingWindowResult result; - result.windowSize = maybeWindowSize.value(); - - return result; + return maybeWindowSize.value(); } /* @@ -2431,7 +2418,7 @@ struct AttentionRewritePattern : public OpRewritePattern { auto maybeSlidingWindow = trySlidingWindowPattern(input1, seqLenSkip); if (succeeded(maybeSlidingWindow)) { - result.slidingWindowSize = maybeSlidingWindow->windowSize; + result.slidingWindowSize = maybeSlidingWindow.value(); } } return; @@ -3273,8 +3260,6 @@ struct AttentionRewritePattern : public OpRewritePattern { // The original model may have clip(arg, lo, hi) on currentSeqLen which // was traced through to reach the block argument. The clip is a property // of currentSeqLen itself, used by all masks (KV-cache, sliding window). - // We re-apply the clip here so the GPU kernel uses the clipped value, - // matching the original model's behavior. 
if (currentSeqLen && (attentionMatcherValues.seqLenClipMin.has_value() || attentionMatcherValues.seqLenClipMax.has_value())) { From c76bd001d0765fb51094cac405fd75f0ee3e4dec Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Tue, 17 Feb 2026 14:19:35 +0000 Subject: [PATCH 06/12] Clang-format --- mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | 35 ++++++++----------- mlir/lib/Dialect/Rock/IR/RockDialect.cpp | 6 ++-- .../Rock/Transforms/DetectFlashDecoding.cpp | 7 ++-- .../Rock/Transforms/GemmToGridwise.cpp | 5 ++- .../Transforms/GridwiseGemmToBlockwise.cpp | 15 ++++---- .../Transforms/SortDimensionsMemoryLayout.cpp | 4 +-- 6 files changed, 30 insertions(+), 42 deletions(-) diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index dd2b43004a3b..27c05c5ae48f 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -1974,8 +1974,8 @@ struct AttentionRewritePattern : public OpRewritePattern { // Result struct for sequence length mask detection struct SeqLenMaskResult { Value inputToContinue; // The value to continue pattern matching with - Value seqLen; // The sequence length - Value prefixOffset; // The prefix offset value + Value seqLen; // The sequence length + Value prefixOffset; // The prefix offset value std::optional slidingWindowSize; // The sliding window size // Clip bounds detected on currentSeqLen during KV-cache pattern matching. 
std::optional seqLenClipMin; @@ -2415,8 +2415,7 @@ struct AttentionRewritePattern : public OpRewritePattern { // Try sliding window pattern if not already found if (!result.slidingWindowSize) { - auto maybeSlidingWindow = - trySlidingWindowPattern(input1, seqLenSkip); + auto maybeSlidingWindow = trySlidingWindowPattern(input1, seqLenSkip); if (succeeded(maybeSlidingWindow)) { result.slidingWindowSize = maybeSlidingWindow.value(); } @@ -2446,8 +2445,8 @@ struct AttentionRewritePattern : public OpRewritePattern { tosa::MinimumOp::getOperationName()}; Value inputToContinue = select.getInput3(); - SeqLenMaskResult currentResult{inputToContinue, nullptr, nullptr, - std::nullopt, std::nullopt, std::nullopt}; + SeqLenMaskResult currentResult{inputToContinue, nullptr, nullptr, + std::nullopt, std::nullopt, std::nullopt}; // Analyze the first (outer) select analyzeSelectForSeqLenMask(select, currentResult, opsToSkip, seqLenSkip); @@ -3133,9 +3132,9 @@ struct AttentionRewritePattern : public OpRewritePattern { << "isPrefixCausal = " << (bool)prefixOffset << "\n"); LLVM_DEBUG(llvm::dbgs() << "isSlidingWindow = " << slidingWindowSize.has_value() - << (slidingWindowSize ? - " (size=" + std::to_string(*slidingWindowSize) + ")" : - "") + << (slidingWindowSize + ? " (size=" + std::to_string(*slidingWindowSize) + ")" + : "") << "\n"); if (isDotProduct && hasReduceOp) return failure(); @@ -3260,16 +3259,14 @@ struct AttentionRewritePattern : public OpRewritePattern { // The original model may have clip(arg, lo, hi) on currentSeqLen which // was traced through to reach the block argument. The clip is a property // of currentSeqLen itself, used by all masks (KV-cache, sliding window). 
- if (currentSeqLen && - (attentionMatcherValues.seqLenClipMin.has_value() || - attentionMatcherValues.seqLenClipMax.has_value())) { + if (currentSeqLen && (attentionMatcherValues.seqLenClipMin.has_value() || + attentionMatcherValues.seqLenClipMax.has_value())) { auto seqLenType = cast(currentSeqLen.getType()); auto elemTy = seqLenType.getElementType(); if (attentionMatcherValues.seqLenClipMin.has_value()) { auto minAttr = DenseElementsAttr::get( - seqLenType, - rewriter.getIntegerAttr(elemTy, - *attentionMatcherValues.seqLenClipMin)); + seqLenType, rewriter.getIntegerAttr( + elemTy, *attentionMatcherValues.seqLenClipMin)); Value clipMinConst = tosa::ConstOp::create(rewriter, loc, seqLenType, minAttr); currentSeqLen = tosa::MaximumOp::create(rewriter, loc, seqLenType, @@ -3277,9 +3274,8 @@ struct AttentionRewritePattern : public OpRewritePattern { } if (attentionMatcherValues.seqLenClipMax.has_value()) { auto maxAttr = DenseElementsAttr::get( - seqLenType, - rewriter.getIntegerAttr(elemTy, - *attentionMatcherValues.seqLenClipMax)); + seqLenType, rewriter.getIntegerAttr( + elemTy, *attentionMatcherValues.seqLenClipMax)); Value clipMaxConst = tosa::ConstOp::create(rewriter, loc, seqLenType, maxAttr); currentSeqLen = tosa::MinimumOp::create(rewriter, loc, seqLenType, @@ -3311,8 +3307,7 @@ struct AttentionRewritePattern : public OpRewritePattern { /*kTransposed=*/nullptr, /*vTransposed=*/nullptr, /*oTransposed=*/nullptr, causalAttr, - /*splitKV=*/rewriter.getI32IntegerAttr(1), - slidingWindowSizeAttr, + /*splitKV=*/rewriter.getI32IntegerAttr(1), slidingWindowSizeAttr, /*features=*/nullptr, rewriter.getAttr(rock::StoreMethod::Set), softmaxTypeAttr, diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index a9437f34eca5..c5a9e55c5a9b 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -2846,8 +2846,7 @@ LogicalResult GridwiseAttentionAccelOp::verify() { if (windowSize <= 0) 
return emitError("slidingWindowSize must be positive"); if (!getCurrentSeqLen()) - return emitError( - "slidingWindowSize requires currentSeqLen to be set"); + return emitError("slidingWindowSize requires currentSeqLen to be set"); } return success(); @@ -3398,8 +3397,7 @@ LogicalResult AttentionOp::verify() { if (windowSize <= 0) return emitError("slidingWindowSize must be positive"); if (!getCurrentSeqLen()) - return emitError( - "slidingWindowSize requires currentSeqLen to be set"); + return emitError("slidingWindowSize requires currentSeqLen to be set"); } return verifyGemmPlusGemmLikeOp(*this, getCurrentSeqLen(), getLse(), diff --git a/mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp b/mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp index 7414b5ef3369..f4c868bb1d7f 100644 --- a/mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp @@ -462,10 +462,9 @@ struct DetectFlashDecodingPattern : public OpRewritePattern { op.getNumHeadsKVAttr(), op.getQTransposedAttr(), op.getKTransposedAttr(), op.getVTransposedAttr(), op.getOTransposedAttr(), op.getCausalAttr(), - rewriter.getI32IntegerAttr(splitKVFromQ), - op.getSlidingWindowSizeAttr(), op.getFeaturesAttr(), - op.getStoreMethodAttr(), op.getSoftmaxTypeAttr(), op.getParams0Attr(), - op.getParams1Attr(), op.getFirstGemmIndicesAttr(), + rewriter.getI32IntegerAttr(splitKVFromQ), op.getSlidingWindowSizeAttr(), + op.getFeaturesAttr(), op.getStoreMethodAttr(), op.getSoftmaxTypeAttr(), + op.getParams0Attr(), op.getParams1Attr(), op.getFirstGemmIndicesAttr(), /*preSoftmaxHasSplitKVTransforms=*/rewriter.getBoolAttr(true)); // Copy the preSoftmax elementwise region if it exists diff --git a/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp index 506a0be67118..1ba83bdde8ff 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp 
@@ -612,9 +612,8 @@ static LogicalResult commonAttentionGemmElmtGemm( auto newOp = GridwiseAttentionAccelOp::create( rw, loc, a, b, c, elementwiseInputs, currentSeqLen, prefixOffset, out, - lse, causal, splitKV, slidingWindowSize, - op.getGemmFeaturesAttr(), op.getStoreMethodAttr(), - blockSizeAttr, gridSizeAttr, + lse, causal, splitKV, slidingWindowSize, op.getGemmFeaturesAttr(), + op.getStoreMethodAttr(), blockSizeAttr, gridSizeAttr, /*disableQBypassLDS=*/nullptr, prePadG0MAttr, prePadG0NAttr, numRepeatsGQA, softmaxType, params0, params1, rw.getDenseI64ArrayAttr(op.getFirstGemmIndices()), diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 2b8099256a92..6a1c9fbacee7 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -1263,8 +1263,7 @@ struct GridwiseAttentionAccelRewritePattern layout::GridCoordinates gridCoords, Value gemm0OutBuffer, RegsAsMatrixSubTiles gemm0OutSubTileViews, bool enabled, Value mLoopIV, Value gemm0MBlocksLastIter, Value currentSeqLen, Value prefixOffset, - IntegerAttr numRepeatsGQA, - Value slidingWindowLowerBound) const { + IntegerAttr numRepeatsGQA, Value slidingWindowLowerBound) const { if (enabled) { // For KVCache, we only need to mask on the last iteration, but for causal // and sliding window masking we need to mask on every iteration. 
@@ -1754,8 +1753,7 @@ struct GridwiseAttentionAccelRewritePattern Value currentSeqLenTensor, Value prefixOffsetTensor, int64_t gemm0M, int64_t gemm0N, int64_t gemm0MPerBlock, int64_t gemm0NPerBlock, int64_t splitKV, bool isCausal, - bool isKVCache, bool isPrefixCausal, - int64_t slidingWindowSize, + bool isKVCache, bool isPrefixCausal, int64_t slidingWindowSize, IntegerAttr numRepeatsGQA = nullptr) const { Value gemm0MBlocksLastIter; Value currentSeqLen; @@ -1815,12 +1813,11 @@ struct GridwiseAttentionAccelRewritePattern if (slidingWindowSize > 0) { assert(currentSeqLen != nullptr && "sliding window requires currentSeqLen (KV-cache)"); - Value constWindowSize = - rewriter.createOrFold(loc, - slidingWindowSize); + Value constWindowSize = rewriter.createOrFold( + loc, slidingWindowSize); Value zero = rewriter.createOrFold(loc, 0); - Value lowerBound = - arith::SubIOp::create(rewriter, loc, currentSeqLen, constWindowSize); + Value lowerBound = arith::SubIOp::create(rewriter, loc, currentSeqLen, + constWindowSize); slidingWindowLowerBound = arith::MaxSIOp::create(rewriter, loc, lowerBound, zero); } diff --git a/mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp b/mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp index 9d8ef68204cc..4e7c8173052d 100644 --- a/mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp @@ -623,8 +623,8 @@ struct AttentionRewritePattern : public OpRewritePattern { op.getNumHeadsKVAttr(), transposedQ, transposedK, transposedV, op.getOTransposedAttr(), op.getCausalAttr(), op.getSplitKVAttr(), op.getSlidingWindowSizeAttr(), op.getFeaturesAttr(), - op.getStoreMethodAttr(), op.getSoftmaxTypeAttr(), - op.getParams0Attr(), op.getParams1Attr(), op.getFirstGemmIndicesAttr(), + op.getStoreMethodAttr(), op.getSoftmaxTypeAttr(), op.getParams0Attr(), + op.getParams1Attr(), op.getFirstGemmIndicesAttr(), op.getPreSoftmaxHasSplitKVTransformsAttr()); // 
copy linalg::GenericOp if there's any From 3c5b79dd70e45ad0b806f6561abc9903aa2b7ca2 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Tue, 17 Feb 2026 14:45:48 +0000 Subject: [PATCH 07/12] Add LIT tests --- .../tosa-to-rock-attention-kvcache.mlir | 79 +++++++++++ .../gridwise_attention_accel_lowering.mlir | 128 ++++++++++++++++++ ...mixr-attention-sliding-window-kvcache.mlir | 57 ++++++++ 3 files changed, 264 insertions(+) create mode 100644 mlir/test/fusion/pr-e2e/attention/mixr-attention-sliding-window-kvcache.mlir diff --git a/mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-kvcache.mlir b/mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-kvcache.mlir index f2db6da1a630..ba24083d1923 100644 --- a/mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-kvcache.mlir +++ b/mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-kvcache.mlir @@ -410,3 +410,82 @@ func.func @mlir_causal_attention_nokvcache_wrongrange(%arg0: tensor<24576xf16>, %collapsed_8 = tensor.collapse_shape %29 [[0, 1, 2, 3]] : tensor<1x2x32x128xf16> into tensor<8192xf16> return %collapsed_8 : tensor<8192xf16> } + +// CHECK-LABEL:func @mlir_attention_kvcache_sliding_window +// CHECK: %[[MAX:.*]] = tosa.maximum +// CHECK: %[[CLIP:.*]] = tosa.minimum %[[MAX]], {{.*}} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: currentSeqLen = (%[[CLIP]] : tensor<2xi32>) +// CHECK: slidingWindowSize = 3 +func.func @mlir_attention_kvcache_sliding_window(%arg0: tensor<1xi32>, %arg1: tensor<12xf16>, %arg2: tensor<32xf16>, %arg3: tensor<32xf16>) -> tensor<4xf16> attributes {kernel = "mixr"} { + %0 = "tosa.const"() <{values = dense<4> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32> + %1 = tosa.const_shape {values = dense<4> : tensor<1xindex>} : () -> !tosa.shape<1> + %2 = tosa.const_shape {values = dense<[2, 8, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> + %3 = tosa.const_shape {values = dense<[2, 1, 8]> : tensor<3xindex>} : () -> !tosa.shape<3> + %4 = "tosa.const"() <{values = 
dense<1.000000e+00> : tensor<1x2x1x8xf32>}> : () -> tensor<1x2x1x8xf32> + %5 = "tosa.const"() <{values = dense<1> : tensor<1x2x1x8xi8>}> : () -> tensor<1x2x1x8xi8> + %6 = tosa.const_shape {values = dense<8> : tensor<1xindex>} : () -> !tosa.shape<1> + %7 = "tosa.const"() <{values = dense<1> : tensor<8x1x1x1xi32>}> : () -> tensor<8x1x1x1xi32> + %8 = "tosa.const"() <{values = dense<5.000000e-01> : tensor<1x2x1x8xf16>}> : () -> tensor<1x2x1x8xf16> + %9 = "tosa.const"() <{values = dense<0xFC00> : tensor<1x2x1x8xf16>}> : () -> tensor<1x2x1x8xf16> + %10 = tosa.const_shape {values = dense<[1, 2, 1, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> + %11 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16> + %12 = tosa.const_shape {values = dense<[2, 2, 8]> : tensor<3xindex>} : () -> !tosa.shape<3> + %13 = tosa.const_shape {values = dense<[2, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> + %14 = tosa.const_shape {values = dense<[1, 2, 1, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> + %15 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> + %16 = "tosa.const"() <{values = dense<-3> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32> + %17 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %18 = "tosa.const"() <{values = dense<1> : tensor<1x1x1x8xi32>}> : () -> tensor<1x1x1x8xi32> + %19 = tosa.const_shape {values = dense<[1, 1, 1, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> + %20 = "tosa.const"() <{values = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi32>}> : () -> tensor<8xi32> + %21 = tosa.const_shape {values = dense<[1, 6, 1, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> + %22 = tosa.const_shape {values = dense<[1, 2, 8, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> + %expanded = tensor.expand_shape %arg2 [[0, 1, 2, 3]] output_shape [1, 2, 8, 2] : tensor<32xf16> into tensor<1x2x8x2xf16> + %expanded_0 = tensor.expand_shape %arg1 [[0, 1, 2, 3]] output_shape [1, 6, 1, 2] : 
tensor<12xf16> into tensor<1x6x1x2xf16> + %cst = arith.constant dense<[[[[0, 1, 2, 3, 4, 5, 6, 7]]]]> : tensor<1x1x1x8xi32> + %expanded_1 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [1, 1, 1, 1] : tensor<1xi32> into tensor<1x1x1x1xi32> + %23 = tosa.maximum %expanded_1, %0 : (tensor<1x1x1x1xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x1x1xi32> + %24 = tosa.minimum %23, %0 : (tensor<1x1x1x1xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x1x1xi32> + %extracted_slice = tensor.extract_slice %expanded_0[0, 0, 0, 0] [1, 2, 1, 2] [1, 1, 1, 1] : tensor<1x6x1x2xf16> to tensor<1x2x1x2xf16> + %25 = tosa.transpose %expanded {perms = array} : (tensor<1x2x8x2xf16>) -> tensor<1x2x2x8xf16> + %collapsed = tensor.collapse_shape %extracted_slice [[0, 1], [2], [3]] : tensor<1x2x1x2xf16> into tensor<2x1x2xf16> + %collapsed_2 = tensor.collapse_shape %25 [[0, 1], [2], [3]] : tensor<1x2x2x8xf16> into tensor<2x2x8xf16> + %26 = tosa.matmul %collapsed, %collapsed_2, %11, %11 {acc_type = f32} : (tensor<2x1x2xf16>, tensor<2x2x8xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<2x1x8xf16> + %expanded_3 = tensor.expand_shape %26 [[0, 1], [2], [3]] output_shape [1, 2, 1, 8] : tensor<2x1x8xf16> into tensor<1x2x1x8xf16> + %27 = tosa.mul %expanded_3, %8, %17 : (tensor<1x2x1x8xf16>, tensor<1x2x1x8xf16>, tensor<1xi8>) -> tensor<1x2x1x8xf16> + %28 = tosa.add %24, %16 : (tensor<1x1x1x1xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x1x1xi32> + %29 = tosa.mul %28, %7, %17 : (tensor<1x1x1x1xi32>, tensor<8x1x1x1xi32>, tensor<1xi8>) -> tensor<8x1x1x1xi32> + %collapsed_4 = tensor.collapse_shape %29 [[0, 1, 2, 3]] : tensor<8x1x1x1xi32> into tensor<8xi32> + %30 = tosa.greater %collapsed_4, %20 : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi1> + %31 = tosa.cast %30 : (tensor<8xi1>) -> tensor<8xi32> + %32 = tosa.cast %31 : (tensor<8xi32>) -> tensor<8xi8> + %expanded_5 = tensor.expand_shape %32 [[0, 1, 2, 3]] output_shape [1, 1, 1, 8] : tensor<8xi8> into tensor<1x1x1x8xi8> + %33 = tosa.mul %expanded_5, %5, %17 : 
(tensor<1x1x1x8xi8>, tensor<1x2x1x8xi8>, tensor<1xi8>) -> tensor<1x2x1x8xi8> + %34 = tosa.cast %33 : (tensor<1x2x1x8xi8>) -> tensor<1x2x1x8xi1> + %35 = tosa.select %34, %9, %27 : (tensor<1x2x1x8xi1>, tensor<1x2x1x8xf16>, tensor<1x2x1x8xf16>) -> tensor<1x2x1x8xf16> + %36 = tosa.mul %24, %18, %17 : (tensor<1x1x1x1xi32>, tensor<1x1x1x8xi32>, tensor<1xi8>) -> tensor<1x1x1x8xi32> + %37 = tosa.greater %cst, %36 : (tensor<1x1x1x8xi32>, tensor<1x1x1x8xi32>) -> tensor<1x1x1x8xi1> + %38 = tosa.cast %37 : (tensor<1x1x1x8xi1>) -> tensor<1x1x1x8xi32> + %39 = tosa.cast %38 : (tensor<1x1x1x8xi32>) -> tensor<1x1x1x8xi8> + %40 = tosa.mul %39, %5, %17 : (tensor<1x1x1x8xi8>, tensor<1x2x1x8xi8>, tensor<1xi8>) -> tensor<1x2x1x8xi8> + %41 = tosa.cast %40 : (tensor<1x2x1x8xi8>) -> tensor<1x2x1x8xi1> + %42 = tosa.select %41, %9, %35 : (tensor<1x2x1x8xi1>, tensor<1x2x1x8xf16>, tensor<1x2x1x8xf16>) -> tensor<1x2x1x8xf16> + %43 = tosa.cast %42 : (tensor<1x2x1x8xf16>) -> tensor<1x2x1x8xf32> + %44 = tosa.reduce_max %43 {axis = 3 : i32} : (tensor<1x2x1x8xf32>) -> tensor<1x2x1x1xf32> + %45 = tosa.mul %44, %4, %17 : (tensor<1x2x1x1xf32>, tensor<1x2x1x8xf32>, tensor<1xi8>) -> tensor<1x2x1x8xf32> + %46 = tosa.sub %43, %45 : (tensor<1x2x1x8xf32>, tensor<1x2x1x8xf32>) -> tensor<1x2x1x8xf32> + %47 = tosa.exp %46 : (tensor<1x2x1x8xf32>) -> tensor<1x2x1x8xf32> + %48 = tosa.reduce_sum %47 {axis = 3 : i32} : (tensor<1x2x1x8xf32>) -> tensor<1x2x1x1xf32> + %49 = tosa.mul %48, %4, %17 : (tensor<1x2x1x1xf32>, tensor<1x2x1x8xf32>, tensor<1xi8>) -> tensor<1x2x1x8xf32> + %50 = tosa.reciprocal %49 : (tensor<1x2x1x8xf32>) -> tensor<1x2x1x8xf32> + %51 = tosa.mul %47, %50, %17 : (tensor<1x2x1x8xf32>, tensor<1x2x1x8xf32>, tensor<1xi8>) -> tensor<1x2x1x8xf32> + %52 = tosa.cast %51 : (tensor<1x2x1x8xf32>) -> tensor<1x2x1x8xf16> + %collapsed_6 = tensor.collapse_shape %52 [[0, 1], [2], [3]] : tensor<1x2x1x8xf16> into tensor<2x1x8xf16> + %expanded_7 = tensor.expand_shape %arg3 [[0, 1, 2]] output_shape [2, 8, 2] : 
tensor<32xf16> into tensor<2x8x2xf16> + %53 = tosa.matmul %collapsed_6, %expanded_7, %11, %11 {acc_type = f32} : (tensor<2x1x8xf16>, tensor<2x8x2xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<2x1x2xf16> + %expanded_8 = tensor.expand_shape %53 [[0, 1], [2], [3]] output_shape [1, 2, 1, 2] : tensor<2x1x2xf16> into tensor<1x2x1x2xf16> + %54 = tosa.transpose %expanded_8 {perms = array} : (tensor<1x2x1x2xf16>) -> tensor<1x1x2x2xf16> + %collapsed_9 = tensor.collapse_shape %54 [[0, 1, 2, 3]] : tensor<1x1x2x2xf16> into tensor<4xf16> + return %collapsed_9 : tensor<4xf16> +} + diff --git a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir index deaf0e606632..704e73d7af03 100644 --- a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir @@ -907,3 +907,131 @@ func.func @gridwise_attn_wavespereu_outputswizzle(%arg0: memref<1474560xf16>, %a } {blockSize = 128 : i32, enableSoftmax = false, firstGemmIndices = array, splitKV = 1 : i32, gridSize = 512 : i32, operandSegmentSizes = array, params0 = #rock.accel_gemm_params, params1 = #rock.accel_gemm_params, storeMethod = #rock} : memref<4x512x4096xf16>, memref<4x512x1024xf16>, memref<4x1024x384xf16>, memref<4x4096x384xf16> return } + +// ----- + +// CHECK-LABEL: @mlir_attention +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c32:.+]] = arith.constant 32 : index +// CHECK-DAG: %[[negInf:.+]] = arith.constant 0xFF800000 : f32 +// Load current sequence length +// CHECK: %[[registers:.+]] = rock.alloc() : memref<1xi32, #gpu.address_space> +// CHECK-NEXT: rock.threadwise_read_into {forceUnroll, useIndexDiffs} [](%{{.+}}) [%{{.+}}] -> %[[registers]] : memref<2x1xi32> -> memref<1xi32, #gpu.address_space>, vector<1xi1> +// CHECK-NEXT: %[[currSeqLen:.+]] = 
rock.in_bounds_load %[[registers]][%[[c0]]] : memref<1xi32, #gpu.address_space>, index -> i32 +// CHECK-NEXT: %[[currSeqLenIndex:.+]] = arith.index_cast %[[currSeqLen]] : i32 to index + +// Sliding window lower bound: max(seqLen - windowSize, 0) +// CHECK-NEXT: %[[seqLenMinusWindow:.+]] = arith.subi %[[currSeqLenIndex]], %[[c3]] : index +// CHECK-NEXT: %[[slidingWindowLB:.+]] = arith.maxsi %[[seqLenMinusWindow]], %[[c0]] : index + +// Dynamic loop bound: ceil(seqLen / tileSize) +// CHECK-NEXT: %[[num:.+]] = arith.addi %[[currSeqLenIndex]], %[[c32]] : index +// CHECK-NEXT: %[[numIter:.+]] = arith.divui %[[num]], %[[c32]] : index +// CHECK-NEXT: %[[lastIter:.+]] = arith.subi %[[numIter]], %[[c1]] : index + +// Outer N-tile loop +// CHECK: scf.for %[[iterIndex:.+]] = %{{.+}} to %[[numIter]] step %[[c1]] { +// KV-cache mask: on last iteration, mask positions > seqLen +// CHECK: %[[isLastIter:.+]] = arith.cmpi eq, %[[iterIndex]], %[[lastIter]] : index +// CHECK-NEXT: scf.if %[[isLastIter]] { +// CHECK: rock.transforming_for {forceUnroll, useIndexDiffs} (%[[dim0:.+]], %[[dim1:.+]], %[[dim2:.+]]) = [{{.*}}]({{.*}}), ({{.*}}) = [] + +// Sliding window mask: mask positions below the window +// CHECK: rock.transforming_for {forceUnroll, useIndexDiffs} (%[[swDim0:.+]], %[[swDim1:.+]], %[[swDim2:.+]]) = [{{.*}}]({{.*}}), ({{.*}}) = [] +// CHECK-NEXT: %[[slidingWindowCmp:.+]] = arith.cmpi ult, %[[swDim2]], %[[slidingWindowLB]] : index +// CHECK-NEXT: scf.if %[[slidingWindowCmp]] { +// CHECK-NEXT: rock.in_bounds_store %[[negInf]] + +#accel_gemm_params = #rock.accel_gemm_params +#map = affine_map<(d0, d1, d2, d3) -> ((d1 * 8 + d2) * 2 + d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d1 * 2 + d3)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)> +#map4 = affine_map<(d0, d1, d2) -> (0, d0, d1, d2)> +#map5 = affine_map<(d0, d1, d2) -> ((d0 * 8 + d1) * 2 + d2)> +#map6 = affine_map<(d0, d1, d2, d3) -> (((d0 + d1) * 2 + 
d2) * 2 + d3)> +#map7 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)> +#map8 = affine_map<(d0, d1) -> (d1)> +#map9 = affine_map<(d0, d1) -> (d0, 0)> +#map10 = affine_map<(d0) -> (0, d0)> +#map11 = affine_map<(d0) -> (d0)> +#map12 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> +#map13 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map14 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> +#map15 = affine_map<(d0, d1) -> (0, d0, 0, d1)> +#map16 = affine_map<(d0, d1) -> (d0, d1)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 2, 8, 2] -> [32]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>, [] at []>] bounds = [1, 6, 1, 2] -> [12]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim0", "dim1", "dim2", "dim3"] at [0, 1, 2, 3]>] bounds = [1, 2, 1, 2] -> [1, 6, 1, 2]> +#transform_map3 = #rock.transform_map<#map3 by [ ["dim0", "dim1", "dim3", "dim2"] at [0, 1, 3, 2]>] bounds = [1, 2, 2, 8] -> [1, 2, 8, 2]> +#transform_map4 = #rock.transform_map<#map4 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>, ["dim2"] at [3]>] bounds = [2, 1, 2] -> [1, 2, 1, 2]> +#transform_map5 = #rock.transform_map<#map4 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>, ["dim2"] at [3]>] bounds = [2, 2, 8] -> [1, 2, 2, 8]> +#transform_map6 = #rock.transform_map<#map5 by [ ["dim0"] at [0]>] bounds = [2, 8, 2] -> [32]> +#transform_map7 = #rock.transform_map<#map6 by [ ["dim0"] at [0]>] bounds = [1, 1, 2, 2] -> [4]> +#transform_map8 = #rock.transform_map<#map7 by [ ["dim0", "dim2", "dim1", "dim3"] at [0, 1, 2, 3]>] bounds = [1, 2, 1, 2] -> [1, 1, 2, 2]> +#transform_map9 = #rock.transform_map<#map4 by [ ["exp1"] at [1]>, ["dim1"] at [2]>, ["dim2"] at [3]>, ["unit0"] at [0]>] bounds = [2, 1, 2] -> [1, 2, 1, 2]> +#transform_map10 = #rock.transform_map<#map8 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 1] -> [1]> +#transform_map11 = #rock.transform_map<#map9 by [ ["dim0"] at [0]>, ["dim1"] at [1]>] bounds = [1, 2] -> [1, 1]> 
+#transform_map12 = #rock.transform_map<#map10 by [ ["col0", "col1"] at [0, 1]>] bounds = [2] -> [1, 2]> +#transform_map13 = #rock.transform_map<#map12 by [ ["dim0", "dim2", "dim1"] at [0, 2, 1]>] bounds = [2, 8, 2] -> [2, 2, 8]> +#transform_map14 = #rock.transform_map<#map12 by [ ["gemmG"] at [0]>, ["gemm0K", "gemm0M"] at [2, 1]>] bounds = [2, 2, 1] -> [2, 1, 2]> +#transform_map15 = #rock.transform_map<#map12 by [ ["gemmG"] at [0]>, ["gemm0K", "gemm0N"] at [2, 1]>] bounds = [2, 2, 8] -> [2, 8, 2]> +#transform_map16 = #rock.transform_map<#map13 by [ ["gemmG"] at [0]>, ["gemm0K"] at [1]>, ["gemm0N"] at [2]>] bounds = [2, 32, 32] -> [2, 2, 1]> +#transform_map17 = #rock.transform_map<#map13 by [ ["gemmG"] at [0]>, ["gemm0K"] at [1]>, ["gemm0M"] at [2]>] bounds = [2, 32, 32] -> [2, 2, 8]> +#transform_map18 = #rock.transform_map<#map13 by [ ["gemmG"] at [0]>, ["gemm1K"] at [1]>, ["gemm1M"] at [2]>] bounds = [2, 32, 32] -> [2, 8, 2]> +#transform_map19 = #rock.transform_map<#map13 by [ ["gemmG"] at [0]>, ["gemm1N"] at [1]>, ["gemm1M"] at [2]>] bounds = [2, 32, 32] -> [2, 1, 2]> +#transform_map20 = #rock.transform_map<#map14 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>, [] at []>] bounds = [1, 2, 1, 8] -> [2, 1, 8]> +#transform_map21 = #rock.transform_map<#map15 by [ ["col0", "col1", "col2"] at [0, 1, 2]>, ["dim1"] at [3]>] bounds = [2, 8] -> [1, 2, 1, 8]> +#transform_map22 = #rock.transform_map<#map15 by [ ["exp1"] at [1]>, ["dim1"] at [3]>, ["unit0"] at [0]>, ["unit2"] at [2]>] bounds = [2, 8] -> [1, 2, 1, 8]> +module { + func.func @mlir_attention(%arg0: memref<1xi32>, %arg1: memref<12xf16>, %arg2: memref<32xf16>, %arg3: memref<32xf16>, %arg4: memref<4xf16>) attributes {arch = "gfx950", block_size = 64 : i32, features = #rock, grid_size = 2 : i32, kernel = "mixr"} { + %cst = arith.constant 5.000000e-01 : f16 + %c4_i32 = arith.constant 4 : i32 + %0 = rock.transform %arg2 by #transform_map : memref<32xf16> to memref<1x2x8x2xf16> + %1 = rock.transform %arg1 by 
#transform_map1 : memref<12xf16> to memref<1x6x1x2xf16> + %2 = rock.transform %1 by #transform_map2 : memref<1x6x1x2xf16> to memref<1x2x1x2xf16> + %3 = rock.transform %0 by #transform_map3 : memref<1x2x8x2xf16> to memref<1x2x2x8xf16> + %4 = rock.transform %2 by #transform_map4 : memref<1x2x1x2xf16> to memref<2x1x2xf16> + %5 = rock.transform %3 by #transform_map5 : memref<1x2x2x8xf16> to memref<2x2x8xf16> + %6 = rock.transform %arg3 by #transform_map6 : memref<32xf16> to memref<2x8x2xf16> + %alloc = memref.alloc() : memref<4xf16> + %7 = rock.transform %alloc by #transform_map7 : memref<4xf16> to memref<1x1x2x2xf16> + %8 = rock.transform %7 by #transform_map8 : memref<1x1x2x2xf16> to memref<1x2x1x2xf16> + %9 = rock.transform %8 by #transform_map9 : memref<1x2x1x2xf16> to memref<2x1x2xf16> + %10 = rock.transform %arg0 by #transform_map10 : memref<1xi32> to memref<1x1xi32> + %11 = rock.transform %10 by #transform_map11 : memref<1x1xi32> to memref<1x2xi32> + %12 = rock.transform %11 by #transform_map12 : memref<1x2xi32> to memref<2xi32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2xi32> + linalg.generic {indexing_maps = [#map11, #map11], iterator_types = ["parallel"]} ins(%12 : memref<2xi32>) outs(%alloc_0 : memref<2xi32>) attrs = {rock.majorTensorNumber = 0 : index} { + ^bb0(%in: i32, %out: i32): + %20 = arith.maxsi %in, %c4_i32 : i32 + %21 = arith.minsi %20, %c4_i32 : i32 + linalg.yield %21 : i32 + } + %13 = rock.transform %5 by #transform_map13 : memref<2x2x8xf16> to memref<2x8x2xf16> + %14 = rock.transform %4 by #transform_map14 : memref<2x1x2xf16> to memref<2x2x1xf16> + %15 = rock.transform %13 by #transform_map15 : memref<2x8x2xf16> to memref<2x2x8xf16> + %16 = rock.transform %14 by #transform_map16 : memref<2x2x1xf16> to memref<2x32x32xf16> + %17 = rock.transform %15 by #transform_map17 : memref<2x2x8xf16> to memref<2x32x32xf16> + %18 = rock.transform %6 by #transform_map18 : memref<2x8x2xf16> to memref<2x32x32xf16> + %19 = rock.transform %9 by 
#transform_map19 : memref<2x1x2xf16> to memref<2x32x32xf16> + rock.gridwise_attention_accel(%16, %17, %18, %alloc_0, %19) preSoftmaxOps = { + ^bb0(%arg5: memref<2x1x8xf16>, %arg6: memref<1x2x1x8xf16>): + %20 = rock.transform %arg5 by #transform_map20 : memref<2x1x8xf16> to memref<1x2x1x8xf16> + %21 = rock.transform %20 by #transform_map21 : memref<1x2x1x8xf16> to memref<2x8xf16> + %alloc_1 = memref.alloc() : memref<1x2x1x8xf16> + %22 = rock.transform %alloc_1 by #transform_map22 : memref<1x2x1x8xf16> to memref<2x8xf16> + linalg.generic {indexing_maps = [#map16, #map16], iterator_types = ["parallel", "parallel"]} ins(%21 : memref<2x8xf16>) outs(%22 : memref<2x8xf16>) attrs = {rock.majorTensorNumber = 0 : index} { + ^bb0(%in: f16, %out: f16): + %23 = arith.mulf %in, %cst : f16 + linalg.yield %23 : f16 + } + memref.copy %alloc_1, %arg6 : memref<1x2x1x8xf16> to memref<1x2x1x8xf16> + rock.yield + } {blockSize = 64 : i32, firstGemmIndices = array, gridSize = 2 : i32, operandSegmentSizes = array, params0 = #accel_gemm_params, params1 = #accel_gemm_params, prePadG0M = 8 : index, prePadG0N = 1 : index, slidingWindowSize = 3 : i32, softmaxType = f32, splitKV = 1 : i32, storeMethod = #rock} : memref<2x32x32xf16>, memref<2x32x32xf16>, memref<2x32x32xf16>, memref<2xi32>, memref<2x32x32xf16> + memref.copy %alloc, %arg4 : memref<4xf16> to memref<4xf16> + return + } +} \ No newline at end of file diff --git a/mlir/test/fusion/pr-e2e/attention/mixr-attention-sliding-window-kvcache.mlir b/mlir/test/fusion/pr-e2e/attention/mixr-attention-sliding-window-kvcache.mlir new file mode 100644 index 000000000000..e7fe25eb8dbb --- /dev/null +++ b/mlir/test/fusion/pr-e2e/attention/mixr-attention-sliding-window-kvcache.mlir @@ -0,0 +1,57 @@ +// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_attention_wrapper --verifier clone - | 
rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// CHECK: [1 1 1] + +module { + func.func @mlir_attention(%arg0: !migraphx.shaped<1x1xsi32, 1x1>, + %arg1: !migraphx.shaped<1x6x1x2xf16, 12x2x2x1>, + %arg2: !migraphx.shaped<1x2x8x2xf16, 32x16x2x1>, + %arg3: !migraphx.shaped<1x2x8x2xf16, 32x16x2x1>) -> !migraphx.shaped<1x1x4xf16, 4x4x1> attributes {kernel = "mixr"} { + %0 = migraphx.literal(dense<-3> : tensor<1xsi32>) : <1xsi32, 1> + %1 = migraphx.literal(dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xsi32>) : <8xsi32, 1> + %2 = migraphx.literal(dense<4> : tensor<1x1xsi32>) : <1x1xsi32, 1x1> + %3 = migraphx.literal(dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xsi32>) : <8xsi32, 1> + %4 = migraphx.literal(dense<0xFC00> : tensor<1xf16>) : <1xf16, 1> + %5 = migraphx.literal(dense<5.000000e-01> : tensor<1xf16>) : <1xf16, 1> + %6 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [1, 1, 1, 8]} : <8xsi32, 1> -> <1x1x1x8xsi32, 0x0x0x1> + %7 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 1, 1, 1]} : <1xsi32, 1> -> <1x1x1x1xsi32, 0x0x0x1> + %8 = migraphx.reshape %arg0 {dims = [1, 1, 1, 1]} : <1x1xsi32, 1x1> -> <1x1x1x1xsi32, 1x1x1x1> + %9 = migraphx.reshape %2 {dims = [1, 1, 1, 1]} : <1x1xsi32, 1x1> -> <1x1x1x1xsi32, 1x1x1x1> + %10 = migraphx.reshape %2 {dims = [1, 1, 1, 1]} : <1x1xsi32, 1x1> -> <1x1x1x1xsi32, 1x1x1x1> + %11 = migraphx.clip %8, %9, %10 : <1x1x1x1xsi32, 1x1x1x1>, <1x1x1x1xsi32, 1x1x1x1>, <1x1x1x1xsi32, 1x1x1x1> -> <1x1x1x1xsi32, 1x1x1x1> + %12 = migraphx.slice %arg1 {axes = [1], ends = [2], starts = [0]} : <1x6x1x2xf16, 12x2x2x1> -> 
<1x2x1x2xf16, 12x2x2x1> + %13 = migraphx.transpose %arg2 {permutation = [0, 1, 3, 2]} : <1x2x8x2xf16, 32x16x2x1> -> <1x2x2x8xf16, 32x16x1x2> + %14 = migraphx.dot %12, %13 : <1x2x1x2xf16, 12x2x2x1>, <1x2x2x8xf16, 32x16x1x2> -> <1x2x1x8xf16, 16x8x8x1> + %15 = migraphx.multibroadcast %4 {out_dyn_dims = [], out_lens = [1, 2, 1, 8]} : <1xf16, 1> -> <1x2x1x8xf16, 0x0x0x0> + %16 = migraphx.multibroadcast %5 {out_dyn_dims = [], out_lens = [1, 2, 1, 8]} : <1xf16, 1> -> <1x2x1x8xf16, 0x0x0x0> + %17 = migraphx.mul %14, %16 : <1x2x1x8xf16, 16x8x8x1>, <1x2x1x8xf16, 0x0x0x0> -> <1x2x1x8xf16, 16x8x8x1> + %18 = migraphx.add %11, %7 : <1x1x1x1xsi32, 1x1x1x1>, <1x1x1x1xsi32, 0x0x0x1> -> <1x1x1x1xsi32, 1x1x1x1> + %19 = migraphx.multibroadcast %18 {out_dyn_dims = [], out_lens = [8, 1, 1, 1]} : <1x1x1x1xsi32, 1x1x1x1> -> <8x1x1x1xsi32, 0x1x1x1> + %20 = migraphx.reshape %19 {dims = [8]} : <8x1x1x1xsi32, 0x1x1x1> -> <8xsi32, 0> + %21 = migraphx.greater %20, %3 : <8xsi32, 0>, <8xsi32, 1> -> <8xsi32, 1> + %22 = migraphx.convert %21 {target_type = 0 : i64} : <8xsi32, 1> to <8xsi8, 1> + %23 = migraphx.broadcast %22 {axis = 3 : i64, out_lens = [1, 2, 1, 8]} : <8xsi8, 1> -> <1x2x1x8xsi8, 0x0x0x1> + %24 = migraphx.where %23, %15, %17 : <1x2x1x8xsi8, 0x0x0x1>, <1x2x1x8xf16, 0x0x0x0>, <1x2x1x8xf16, 16x8x8x1> -> <1x2x1x8xf16, 16x8x8x1> + %25 = migraphx.multibroadcast %11 {out_dyn_dims = [], out_lens = [1, 1, 1, 8]} : <1x1x1x1xsi32, 1x1x1x1> -> <1x1x1x8xsi32, 1x1x1x0> + %26 = migraphx.greater %6, %25 : <1x1x1x8xsi32, 0x0x0x1>, <1x1x1x8xsi32, 1x1x1x0> -> <1x1x1x8xsi32, 0x0x0x1> + %27 = migraphx.convert %26 {target_type = 0 : i64} : <1x1x1x8xsi32, 0x0x0x1> to <1x1x1x8xsi8, 0x0x0x1> + %28 = migraphx.multibroadcast %27 {out_dyn_dims = [], out_lens = [1, 2, 1, 8]} : <1x1x1x8xsi8, 0x0x0x1> -> <1x2x1x8xsi8, 0x0x0x1> + %29 = migraphx.where %28, %15, %24 : <1x2x1x8xsi8, 0x0x0x1>, <1x2x1x8xf16, 0x0x0x0>, <1x2x1x8xf16, 16x8x8x1> -> <1x2x1x8xf16, 16x8x8x1> + %30 = migraphx.convert %29 {target_type = 2 : i64} : 
<1x2x1x8xf16, 16x8x8x1> to <1x2x1x8xf32, 16x8x8x1> + %31 = migraphx.reshape %30 {dims = [1, 2, 1, 8]} : <1x2x1x8xf32, 16x8x8x1> -> <1x2x1x8xf32, 16x8x8x1> + %32 = migraphx.reduce_max %31 {axes = [3]} : <1x2x1x8xf32, 16x8x8x1> -> <1x2x1x1xf32, 2x1x1x1> + %33 = migraphx.reshape %32 {dims = [1, 2, 1, 1]} : <1x2x1x1xf32, 2x1x1x1> -> <1x2x1x1xf32, 2x1x1x1> + %34 = migraphx.multibroadcast %33 {out_dyn_dims = [], out_lens = [1, 2, 1, 8]} : <1x2x1x1xf32, 2x1x1x1> -> <1x2x1x8xf32, 2x1x1x0> + %35 = migraphx.sub %30, %34 : <1x2x1x8xf32, 16x8x8x1>, <1x2x1x8xf32, 2x1x1x0> -> <1x2x1x8xf32, 16x8x8x1> + %36 = migraphx.exp %35 : <1x2x1x8xf32, 16x8x8x1> -> <1x2x1x8xf32, 16x8x8x1> + %37 = migraphx.reshape %36 {dims = [1, 2, 1, 8]} : <1x2x1x8xf32, 16x8x8x1> -> <1x2x1x8xf32, 16x8x8x1> + %38 = migraphx.reduce_sum %37 {axes = [3]} : <1x2x1x8xf32, 16x8x8x1> -> <1x2x1x1xf32, 2x1x1x1> + %39 = migraphx.reshape %38 {dims = [1, 2, 1, 1]} : <1x2x1x1xf32, 2x1x1x1> -> <1x2x1x1xf32, 2x1x1x1> + %40 = migraphx.multibroadcast %39 {out_dyn_dims = [], out_lens = [1, 2, 1, 8]} : <1x2x1x1xf32, 2x1x1x1> -> <1x2x1x8xf32, 2x1x1x0> + %41 = migraphx.div %36, %40 : <1x2x1x8xf32, 16x8x8x1>, <1x2x1x8xf32, 2x1x1x0> -> <1x2x1x8xf32, 16x8x8x1> + %42 = migraphx.convert %41 {target_type = 1 : i64} : <1x2x1x8xf32, 16x8x8x1> to <1x2x1x8xf16, 16x8x8x1> + %43 = migraphx.dot %42, %arg3 : <1x2x1x8xf16, 16x8x8x1>, <1x2x8x2xf16, 32x16x2x1> -> <1x2x1x2xf16, 4x2x2x1> + %44 = migraphx.transpose %43 {permutation = [0, 2, 1, 3]} : <1x2x1x2xf16, 4x2x2x1> -> <1x1x2x2xf16, 4x2x2x1> + %45 = migraphx.reshape %44 {dims = [1, 1, 4]} : <1x1x2x2xf16, 4x2x2x1> -> <1x1x4xf16, 4x4x1> + return %45 : !migraphx.shaped<1x1x4xf16, 4x4x1> + } +} From dd4d52582e71f3a596fda6ff797c2de12f09c47c Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Wed, 18 Feb 2026 20:36:35 +0000 Subject: [PATCH 08/12] Attend to review comments --- mlir/lib/Dialect/Rock/IR/RockDialect.cpp | 61 ++++++++++++++----- .../Transforms/GridwiseGemmToBlockwise.cpp | 15 ++++- 
.../gridwise_attention_accel_lowering.mlir | 2 +- 3 files changed, 58 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index c5a9e55c5a9b..52914cb74829 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -2799,6 +2799,28 @@ void ThreadwiseGemmAccelOp::getEffects( getGemmMatrixEffects(*this, effects); } +// Validate sliding window constraints common to attention-like ops. +static LogicalResult +verifySlidingWindowConstraints(Operation *op, + std::optional slidingWindowSize, + Value currentSeqLen, int64_t maxSeqLen) { + if (!slidingWindowSize) + return success(); + int32_t windowSize = static_cast(*slidingWindowSize); + + if (windowSize <= 0) + return op->emitError("slidingWindowSize must be positive"); + + if (!currentSeqLen) + return op->emitError("slidingWindowSize requires currentSeqLen to be set"); + + if (windowSize > maxSeqLen) + return op->emitError( + "slidingWindowSize must not exceed max sequence length"); + + return success(); +} + //===----------------------------------------------------------------------===// // GridwiseAttentionAccelOp //===----------------------------------------------------------------------===// @@ -2832,6 +2854,9 @@ LogicalResult GridwiseAttentionAccelOp::verify() { if (!getEnableSoftmax() && getCausal()) return emitError("causal only works for attention."); + if (!getEnableSoftmax() && getSlidingWindowSize()) + return emitError("slidingWindowSize only works for attention."); + // Validate prefix offset constraints // prefixOffset requires causal to be enabled (prefix causal = causal + // prefixOffset) @@ -2840,14 +2865,17 @@ LogicalResult GridwiseAttentionAccelOp::verify() { "prefixOffset requires causal to be enabled. 
" "Prefix causal attention is causal masking with an offset."); - // Validate sliding window constraints - if (getSlidingWindowSize()) { - int32_t windowSize = static_cast(*getSlidingWindowSize()); - if (windowSize <= 0) - return emitError("slidingWindowSize must be positive"); - if (!getCurrentSeqLen()) - return emitError("slidingWindowSize requires currentSeqLen to be set"); - } + // Keys are normalized to [G, K, M] where M = key seq len (max seq len). + // Use pre-padding value if available, otherwise use the current shape. + ShapedType kType = cast(getKeys().getType()); + int64_t maxSeqLen = + getPrePadG0M().value_or(APInt(64, kType.getShape()[2])).getSExtValue(); + + // Validate sliding window constraints. + if (failed(verifySlidingWindowConstraints(getOperation(), + getSlidingWindowSize(), + getCurrentSeqLen(), maxSeqLen))) + return failure(); return success(); } @@ -3391,14 +3419,15 @@ LogicalResult AttentionOp::verify() { "prefixOffset requires causal to be enabled. " "Prefix causal attention is causal masking with an offset."); - // Validate sliding window constraints - if (getSlidingWindowSize()) { - int32_t windowSize = static_cast(*getSlidingWindowSize()); - if (windowSize <= 0) - return emitError("slidingWindowSize must be positive"); - if (!getCurrentSeqLen()) - return emitError("slidingWindowSize requires currentSeqLen to be set"); - } + // Validate sliding window constraints. + // Max seq len is the key N dimension. + ShapedType kType = cast(getKeys().getType()); + ArrayRef kLastDims = kType.getShape().slice(kType.getRank() - 2); + int64_t maxSeqLen = getKTransposed() ? 
kLastDims[0] : kLastDims[1]; + if (failed(verifySlidingWindowConstraints(getOperation(), + getSlidingWindowSize(), + getCurrentSeqLen(), maxSeqLen))) + return failure(); return verifyGemmPlusGemmLikeOp(*this, getCurrentSeqLen(), getLse(), getNumHeadsQ(), getNumHeadsKV()); diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 6a1c9fbacee7..8018ce9ddcf0 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -1906,6 +1906,17 @@ struct GridwiseAttentionAccelRewritePattern gemm0MIterations); end = arith::MinUIOp::create(rewriter, loc, end, endSplitKV); } + + // Adjust start for sliding window: skip M-blocks that are entirely + // below the window. All positions in those blocks would be masked to + // -inf anyway, so we can avoid the loads and GEMMs altogether. + if (slidingWindowSize > 0) { + Value slidingWindowStart = rewriter.createOrFold( + loc, slidingWindowLowerBound, constGemm0MPerBlock); + start = arith::MaxSIOp::create(rewriter, loc, start, + slidingWindowStart); + } + // compute last iteration of the block, this will be used later in // setGemm0OutputOutOfScope() gemm0MBlocksLastIter = @@ -2067,9 +2078,7 @@ struct GridwiseAttentionAccelRewritePattern bool isCausal = op.getCausal(); bool isPrefixCausal = isCausal && prefixOffsetTensor; int64_t slidingWindowSize = - op.getSlidingWindowSize().has_value() - ? 
static_cast(*op.getSlidingWindowSize()) - : 0; + static_cast(op.getSlidingWindowSize().value_or(0)); int64_t splitKV = op.getSplitKV(); // Gemm0 out is casted to be softmaxType (if null, it's casted to elemTypeV) diff --git a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir index 704e73d7af03..243994126cdc 100644 --- a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir @@ -1034,4 +1034,4 @@ module { memref.copy %alloc, %arg4 : memref<4xf16> to memref<4xf16> return } -} \ No newline at end of file +} From 52a7058838883c93c2e7169f2ec6aef3f3b1ed10 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Wed, 18 Feb 2026 23:03:51 +0000 Subject: [PATCH 09/12] Add rocmlir-gen and PR LIT tests --- mlir/test/e2e/PrAttentionBF16.toml | 4 + mlir/test/e2e/PrAttentionDirectToLDS.toml | 4 + mlir/test/e2e/PrAttentionF16.toml | 4 + mlir/test/e2e/PrAttentionF32.toml | 4 + mlir/test/e2e/PrAttentionI8.toml | 4 + mlir/test/e2e/PrAttentionSchedule.toml | 4 + .../rocmlir-gen/attention-sliding-window.mlir | 40 +++++++++ mlir/tools/rocmlir-gen/rocmlir-gen.cpp | 81 ++++++++++++++++++- 8 files changed, 144 insertions(+), 1 deletion(-) create mode 100644 mlir/test/rocmlir-gen/attention-sliding-window.mlir diff --git a/mlir/test/e2e/PrAttentionBF16.toml b/mlir/test/e2e/PrAttentionBF16.toml index 6de1b8f50afc..ffeb0be47b96 100644 --- a/mlir/test/e2e/PrAttentionBF16.toml +++ b/mlir/test/e2e/PrAttentionBF16.toml @@ -116,6 +116,10 @@ config = "-rand 1 -return_lse -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_ [[suite.test]] config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv + sliding window +[[suite.test]] +config = "-rand 1 
-return_lse -split_kv 4 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" + # GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/PrAttentionDirectToLDS.toml b/mlir/test/e2e/PrAttentionDirectToLDS.toml index 58b8861f4c76..2717c39024a4 100644 --- a/mlir/test/e2e/PrAttentionDirectToLDS.toml +++ b/mlir/test/e2e/PrAttentionDirectToLDS.toml @@ -26,3 +26,7 @@ config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_hea # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 --prefix_offset=16,14,12 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" + +# GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) + sliding window +[[suite.test]] +config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 --prefix_offset=16,14,12 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/PrAttentionF16.toml b/mlir/test/e2e/PrAttentionF16.toml index 1ceae277283a..67569039708b 100644 --- a/mlir/test/e2e/PrAttentionF16.toml +++ b/mlir/test/e2e/PrAttentionF16.toml @@ -116,6 +116,10 @@ config = "-rand 1 -return_lse -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_ [[suite.test]] config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 
-head_dim_v 64 -causal --with-attn-scale --with-attn-bias" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv + sliding window +[[suite.test]] +config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" + # GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/PrAttentionF32.toml b/mlir/test/e2e/PrAttentionF32.toml index b9e051119733..cc7de3c4bfa3 100644 --- a/mlir/test/e2e/PrAttentionF32.toml +++ b/mlir/test/e2e/PrAttentionF32.toml @@ -88,6 +88,10 @@ config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_hea [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + sliding window +#[[suite.test]] +#config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" + # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/PrAttentionI8.toml b/mlir/test/e2e/PrAttentionI8.toml index eb1b7c251544..6bb6895f3361 100644 --- 
a/mlir/test/e2e/PrAttentionI8.toml +++ b/mlir/test/e2e/PrAttentionI8.toml @@ -93,6 +93,10 @@ config = "-rand 1 -return_lse -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_ [[suite.test]] config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv + sliding window +[[suite.test]] +config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" + # GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/PrAttentionSchedule.toml b/mlir/test/e2e/PrAttentionSchedule.toml index 2465abecaa4e..54caf88dbe4c 100644 --- a/mlir/test/e2e/PrAttentionSchedule.toml +++ b/mlir/test/e2e/PrAttentionSchedule.toml @@ -25,6 +25,10 @@ config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --schedul [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --schedule_version 2" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + sliding window +[[suite.test]] +config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --schedule_version 2" + # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv 
(padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --schedule_version 2" diff --git a/mlir/test/rocmlir-gen/attention-sliding-window.mlir b/mlir/test/rocmlir-gen/attention-sliding-window.mlir new file mode 100644 index 000000000000..f9a5063e4b86 --- /dev/null +++ b/mlir/test/rocmlir-gen/attention-sliding-window.mlir @@ -0,0 +1,40 @@ +// RUN: rocmlir-gen --arch gfx90a:sramecc+:xnack- --operation attention -current_seq_len=33 -sliding_window_size=16 -seq_len_q 1 -seq_len_k 64 -head_dim_qk 32 -head_dim_v 32 -t f32 -pv --apply-bufferization-pipeline=false | rocmlir-opt | FileCheck %s --enable-var-scope + +// CHECK: module attributes {mhal.arch = "[[$ARCH:.*]]"} + +// CHECK-LABEL: func.func @rock_attention +// CHECK-SAME: (%[[queriesRaw:.*0]]: memref<32xf32>, +// CHECK-SAME: %[[keysRaw:.*1]]: memref<2048xf32>, +// CHECK-SAME: %[[valuesRaw:.*2]]: memref<2048xf32>, +// CHECK-SAME: %[[currentSeqLenRaw:.*3]]: memref<1xi32>, +// CHECK-SAME: %[[outputRaw:.*4]]: memref<32xf32>) +// CHECK-SAME: attributes {kernel, mhal.arch = "[[$ARCH]]"} + +// CHECK: rock.attention +// CHECK-NEXT: qk = %{{.*}} * %{{.*}} +// CHECK-NEXT: currentSeqLen = (%{{.*}} : memref<1xi32>) +// CHECK-NEXT: slidingWindowSize = 16 +// CHECK: softmax(qk) * %{{.*}} +// CHECK: return + +// CHECK-LABEL: func.func @host_naive_attention +// Verify KV-cache masking is applied +// CHECK: tosa.matmul +// CHECK: tosa.greater +// CHECK: tosa.select + +// Verify sliding window masking is applied in the CPU verifier: +// The sliding window masking computes lowerBound = max(0, currentSeqLen - windowSize), +// then masks positions where col < lowerBound with -inf. 
+// CHECK: tosa.sub %{{.*}}, %{{.*}} : (tensor<1x1x1x64xi32>, tensor<1x1x1x64xi32>) -> tensor<1x1x1x64xi32> +// CHECK: tosa.maximum %{{.*}}, %{{.*}} : (tensor<1x1x1x64xi32>, tensor<1x1x1x64xi32>) -> tensor<1x1x1x64xi32> +// CHECK: tosa.greater %{{.*}}, %{{.*}} : (tensor<1x1x1x64xi32>, tensor<1x1x1x64xi32>) -> tensor<1x1x1x64xi1> +// CHECK: tosa.select %{{.*}}, %{{.*}}, %{{.*}} : (tensor<1x1x1x64xi1>, tensor<1x1x1x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x1x1x64xf32> + +// Verify softmax follows +// CHECK-DAG: tosa.reduce_max +// CHECK-DAG: tosa.exp +// CHECK-DAG: tosa.reduce_sum +// CHECK-DAG: tosa.reciprocal +// CHECK: tosa.matmul +// CHECK: return diff --git a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp index 66858eca4bcc..d35009a4b625 100644 --- a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp +++ b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp @@ -691,6 +691,13 @@ static llvm::cl::opt splitKV( "the number of blocks in the sequenceLengthK dimension."), llvm::cl::value_desc("positive integer"), llvm::cl::init(1)); +static llvm::cl::opt slidingWindowSize( + "sliding_window_size", + llvm::cl::desc("Sliding window attention size. Only the last " + "slidingWindowSize key positions (relative to " + "currentSeqLen) are attended to. Requires current_seq_len."), + llvm::cl::value_desc("positive integer"), llvm::cl::init(0)); + static llvm::cl::opt returnLSE( "return_lse", llvm::cl::desc("whether the attention kernel returns LSE (log-sum-exp)"), @@ -3079,6 +3086,68 @@ static Value maskKVCacheTosa(OpBuilder builder, Location loc, Value inputTensor, return resultReshaped; } +// Sliding window masking: mask positions where col < max(0, currentSeqLen - +// slidingWindowSize). The inputTensor shape is [b*num_heads_q, seq_len_q, +// seq_len_kv]. currentSeqLenVal has shape [1x1x1x1xi32] (already reshaped). 
+static Value slidingWindowMaskingTosa(OpBuilder builder, Location loc,
+                                      Value inputTensor,
+                                      Value currentSeqLenVal,
+                                      int64_t windowSize, float initValue) {
+  auto origType = cast<ShapedType>(inputTensor.getType());
+  ArrayRef<int64_t> origShape = origType.getShape();
+  SmallVector<int64_t> newShape = {origShape[0] / numHeadsQ, numHeadsQ,
+                                   origShape[1], origShape[2]};
+  ImplicitLocOpBuilder implicitBuilder(loc, builder);
+  auto newShapeValue = tosa::getTosaConstShape(implicitBuilder, newShape);
+  inputTensor = rock::tosa::createOpAndInfer<tosa::ReshapeOp>(
+      builder, loc, origType.getElementType(), inputTensor, newShapeValue);
+
+  auto inpType = cast<ShapedType>(inputTensor.getType());
+  ArrayRef<int64_t> inpShape = inpType.getShape();
+
+  // Create column range [0, 1, ..., seq_len_kv-1]
+  Value colRange = createRange(builder, loc, 3, inpShape);
+
+  // Broadcast currentSeqLen to full shape
+  auto outType = RankedTensorType::get(inpShape, builder.getI32Type());
+  auto currentSeqLenBroadcast = rock::tosa::getMulOp(
+      builder, loc, currentSeqLenVal,
+      rock::tosa::getOneTensor(builder, loc, outType), builder.getI32Type());
+
+  // Compute lowerBound = max(0, currentSeqLen - slidingWindowSize)
+  DenseElementsAttr windowSizeAttr = DenseIntElementsAttr::get(
+      RankedTensorType::get(inpShape, builder.getI32Type()),
+      static_cast<int32_t>(windowSize));
+  Value windowSizeConst = tosa::ConstOp::create(builder, loc,
+                                                windowSizeAttr.getType(),
+                                                windowSizeAttr);
+  Value lowerBound = rock::tosa::createOpAndInfer<tosa::SubOp>(
+      builder, loc, builder.getI32Type(), currentSeqLenBroadcast,
+      windowSizeConst);
+
+  // Clamp lower bound to >= 0
+  DenseElementsAttr zeroAttr = DenseIntElementsAttr::get(
+      RankedTensorType::get(inpShape, builder.getI32Type()),
+      static_cast<int32_t>(0));
+  Value zeroConst =
+      tosa::ConstOp::create(builder, loc, zeroAttr.getType(), zeroAttr);
+  lowerBound = rock::tosa::createOpAndInfer<tosa::MaximumOp>(
+      builder, loc, builder.getI32Type(), lowerBound, zeroConst);
+
+  // Create mask: col < lowerBound (i.e., lowerBound > col)
+  auto mask =
rock::tosa::createOpAndInfer<tosa::GreaterOp>(
+      builder, loc, builder.getIntegerType(1), lowerBound, colRange);
+
+  Value result = applyMask(builder, loc, inputTensor, mask, initValue);
+
+  // Reshape result back to [batch_size*num_heads_q, seq_len_q, seq_len_kv]
+  auto origShapeValue = tosa::getTosaConstShape(implicitBuilder, origShape);
+  auto resultReshaped = rock::tosa::createOpAndInfer<tosa::ReshapeOp>(
+      builder, loc, inpType.getElementType(), result, origShapeValue);
+
+  return resultReshaped;
+}
+
 static Value broadcastBatchTosa(OpBuilder builder, Location loc,
                                 Value inputTensor, int64_t numRepeat) {
   if (numRepeat == 1)
@@ -3379,7 +3448,9 @@ static func::FuncOp createGpuAttentionKernel(ModuleOp module,
       currentSeqLenTensor, prefixOffsetTensor, output, lse, numHeadsQ,
       numHeadsKV, transposeQ, transposeK, transposeV, transposeO, actualCausal,
       splitKV,
-      /*slidingWindowSize=*/nullptr,
+      /*slidingWindowSize=*/
+      slidingWindowSize > 0 ? builder.getI32IntegerAttr(slidingWindowSize)
+                            : nullptr,
       rock::GemmFeaturesAttr::get(builder.getContext(), params.features),
       storeMethod, softmaxType,
       /*params0=*/nullptr, /*params1=*/nullptr,
@@ -4322,6 +4393,14 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
     qkTensor = causalMaskingTosa(builder, loc, qkTensor,
                                  -std::numeric_limits<float>::infinity());
 
+  // Apply sliding window masking if slidingWindowSize is set.
+  // Masks positions where col < max(0, currentSeqLen - slidingWindowSize).
+  if (slidingWindowSize > 0 && currentSeqLenTensor) {
+    qkTensor = slidingWindowMaskingTosa(
+        builder, loc, qkTensor, currentSeqLenTensor, slidingWindowSize,
+        -std::numeric_limits<float>::infinity());
+  }
+
   constexpr int64_t reductionAxis = 2;
   auto qkMaxs = rock::tosa::createOpAndInfer<tosa::ReduceMaxOp>(
       builder, loc, softmaxType, qkTensor, reductionAxis);

From fa10ee5e22b677c54277aa214298dc80ff41ab0b Mon Sep 17 00:00:00 2001
From: Justin Rosner
Date: Wed, 18 Feb 2026 23:06:01 +0000
Subject: [PATCH 10/12] Add E2E tests

---
 mlir/test/e2e/AttentionDirectToLDS.toml           | 4 ++++
 mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml | 4 ++++
 mlir/test/e2e/AttentionSchedule.toml              | 4 ++++
 3 files changed, 12 insertions(+)

diff --git a/mlir/test/e2e/AttentionDirectToLDS.toml b/mlir/test/e2e/AttentionDirectToLDS.toml
index 3bf93205011d..b9d2f0ec09b2 100644
--- a/mlir/test/e2e/AttentionDirectToLDS.toml
+++ b/mlir/test/e2e/AttentionDirectToLDS.toml
@@ -109,6 +109,10 @@ config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_hea
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
 
+# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + sliding window
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
+
 # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
diff --git a/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml
b/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml index 527666fff756..2b995ae7f171 100644 --- a/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml +++ b/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml @@ -78,6 +78,10 @@ config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_hea [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + sliding window +[[suite.test]] +config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" + # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/AttentionSchedule.toml b/mlir/test/e2e/AttentionSchedule.toml index 5b685218e659..1134812a58ae 100644 --- a/mlir/test/e2e/AttentionSchedule.toml +++ b/mlir/test/e2e/AttentionSchedule.toml @@ -113,6 +113,10 @@ config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_hea [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias -schedule_version 2" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + sliding window +[[suite.test]] +config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 
-head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias -schedule_version 2"
+
 # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias -schedule_version 2"

From 873edc23af562e626eaba4d24ba1c6756f86a83a Mon Sep 17 00:00:00 2001
From: Justin Rosner
Date: Wed, 18 Feb 2026 23:06:43 +0000
Subject: [PATCH 11/12] Clang-format

---
 .../Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | 4 ++--
 mlir/tools/rocmlir-gen/rocmlir-gen.cpp                  | 8 +++-----
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
index 8018ce9ddcf0..13162bc540e0 100644
--- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
+++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
@@ -1913,8 +1913,8 @@ struct GridwiseAttentionAccelRewritePattern
       if (slidingWindowSize > 0) {
         Value slidingWindowStart = rewriter.createOrFold<arith::DivUIOp>(
             loc, slidingWindowLowerBound, constGemm0MPerBlock);
-        start = arith::MaxSIOp::create(rewriter, loc, start,
-                                       slidingWindowStart);
+        start =
+            arith::MaxSIOp::create(rewriter, loc, start, slidingWindowStart);
       }
 
       // compute last iteration of the block, this will be used later in
diff --git a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp
index d35009a4b625..0688daf7d71e 100644
--- a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp
+++ b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp
@@ -3090,8 +3090,7 @@ static Value maskKVCacheTosa(OpBuilder builder, Location loc, Value inputTensor,
 // slidingWindowSize). The inputTensor shape is [b*num_heads_q, seq_len_q,
 // seq_len_kv]. currentSeqLenVal has shape [1x1x1x1xi32] (already reshaped).
 static Value slidingWindowMaskingTosa(OpBuilder builder, Location loc,
-                                      Value inputTensor,
-                                      Value currentSeqLenVal,
+                                      Value inputTensor, Value currentSeqLenVal,
                                       int64_t windowSize, float initValue) {
   auto origType = cast<ShapedType>(inputTensor.getType());
   ArrayRef<int64_t> origShape = origType.getShape();
@@ -3118,9 +3117,8 @@ static Value slidingWindowMaskingTosa(OpBuilder builder, Location loc,
   DenseElementsAttr windowSizeAttr = DenseIntElementsAttr::get(
       RankedTensorType::get(inpShape, builder.getI32Type()),
       static_cast<int32_t>(windowSize));
-  Value windowSizeConst = tosa::ConstOp::create(builder, loc,
-                                                windowSizeAttr.getType(),
-                                                windowSizeAttr);
+  Value windowSizeConst = tosa::ConstOp::create(
+      builder, loc, windowSizeAttr.getType(), windowSizeAttr);
   Value lowerBound = rock::tosa::createOpAndInfer<tosa::SubOp>(
       builder, loc, builder.getI32Type(), currentSeqLenBroadcast,
       windowSizeConst);

From 10cd27ed79df6ca94e6cc6329a05f9f00b0fcd3b Mon Sep 17 00:00:00 2001
From: Justin Rosner
Date: Tue, 24 Feb 2026 14:36:54 +0000
Subject: [PATCH 12/12] Attend to more review comments

---
 .../Transforms/GridwiseGemmToBlockwise.cpp    | 32 ++++++++-----------
 .../gridwise_attention_accel_lowering.mlir    |  5 +++
 2 files changed, 19 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
index 13162bc540e0..92c28e0d4b44 100644
--- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
+++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
@@ -2811,15 +2811,13 @@ struct GridwiseAttentionAccelRewritePattern
         // KV cache masking is independent of causal masking - it masks out
         // positions beyond currentSeqLen (padding). Apply it whenever KV
        // cache is enabled, regardless of causal/prefix-causal mode.
- if (isKVCache) { - setGemm0OutputOutOfScope(rewriter, loc, OutOfScopeType::KVCache, - gridCoordsGemm0, softmaxInputBuffer, - gemm0OutSubTileViewsTr, isKVCache, mLoopIV, - gemm0MBlocksLastIter, currentSeqLen, - /*prefixOffset=*/nullptr, - /*numRepeatsGQA=*/nullptr, - /*slidingWindowLowerBound=*/nullptr); - } + setGemm0OutputOutOfScope(rewriter, loc, OutOfScopeType::KVCache, + gridCoordsGemm0, softmaxInputBuffer, + gemm0OutSubTileViewsTr, isKVCache, mLoopIV, + gemm0MBlocksLastIter, currentSeqLen, + /*prefixOffset=*/nullptr, + /*numRepeatsGQA=*/nullptr, + /*slidingWindowLowerBound=*/nullptr); // Causal masking: either prefix-causal or standard causal if (isPrefixCausal) { @@ -2846,15 +2844,13 @@ struct GridwiseAttentionAccelRewritePattern // Sliding window masking: mask when key_pos < max(0, currentSeqLen - // windowSize). This is independent of causal masking and applies // alongside KV-cache masking. - if (slidingWindowSize > 0) { - setGemm0OutputOutOfScope( - rewriter, loc, OutOfScopeType::SlidingWindow, gridCoordsGemm0, - softmaxInputBuffer, gemm0OutSubTileViewsTr, slidingWindowSize > 0, - mLoopIV, gemm0MBlocksLastIter, - /*currentSeqLen=*/nullptr, - /*prefixOffset=*/nullptr, /*numRepeatsGQA=*/nullptr, - slidingWindowLowerBound); - } + setGemm0OutputOutOfScope( + rewriter, loc, OutOfScopeType::SlidingWindow, gridCoordsGemm0, + softmaxInputBuffer, gemm0OutSubTileViewsTr, slidingWindowSize > 0, + mLoopIV, gemm0MBlocksLastIter, + /*currentSeqLen=*/nullptr, + /*prefixOffset=*/nullptr, /*numRepeatsGQA=*/nullptr, + slidingWindowLowerBound); APInt reductionAxis = APInt(64, 1); // Softmax max reduction diff --git a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir index 243994126cdc..b5c7adf7bee4 100644 --- a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir @@ -929,6 +929,11 @@ func.func 
@gridwise_attn_wavespereu_outputswizzle(%arg0: memref<1474560xf16>, %a // Dynamic loop bound: ceil(seqLen / tileSize) // CHECK-NEXT: %[[num:.+]] = arith.addi %[[currSeqLenIndex]], %[[c32]] : index // CHECK-NEXT: %[[numIter:.+]] = arith.divui %[[num]], %[[c32]] : index + +// Sliding window start iteration: floor(slidingWindowLB / tileSize) +// CHECK-NEXT: %[[swStartIter:.+]] = arith.divui %[[slidingWindowLB]], %[[c32]] : index +// CHECK-NEXT: %[[swStartIterClamped:.+]] = arith.maxsi %[[swStartIter]], %[[c0]] : index + // CHECK-NEXT: %[[lastIter:.+]] = arith.subi %[[numIter]], %[[c1]] : index // Outer N-tile loop