diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index 8d70078b06f0..e4b0ee7496d6 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -220,7 +220,8 @@ def Rock_AttentionOp Optional>:$lse, I32Attr:$numHeadsQ, I32Attr:$numHeadsKV, UnitAttr:$qTransposed, UnitAttr:$kTransposed, UnitAttr:$vTransposed, UnitAttr:$oTransposed, UnitAttr:$causal, - I32Attr:$splitKV, OptionalAttr:$features, + I32Attr:$splitKV, OptionalAttr:$slidingWindowSize, + OptionalAttr:$features, StoreMethodAttr:$storeMethod, OptionalAttr:$softmaxType, OptionalAttr:$params0, OptionalAttr:$params1, @@ -253,6 +254,11 @@ def Rock_AttentionOp - A tensor of shape [G]: per-group/batch offsets, allowing different prefix lengths for each sequence in the batch + If slidingWindowSize is set, we implement sliding window attention where + only the last `slidingWindowSize` key positions (relative to currentSeqLen) + are attended to. Positions before `max(0, currentSeqLen - slidingWindowSize)` + are masked with -inf. This requires currentSeqLen to be set. + LSE (log-sum-exp) is an optional output typically used for flash decoding. For flash decoding, you can pass splitKV > 1, the default value is 1, which means flash decoding is disabled. Flash decoding multiplies the number of blocks by splitKV. Note that "lse" has to be non-null for splitKV > 1. @@ -278,6 +284,7 @@ def Rock_AttentionOp ` ` `qk` `=` (`tr` $qTransposed^)? $queries `*` (`tr` $kTransposed^)? $keys `:` type($queries) `,` type($keys) `\n` (`currentSeqLen` `=` `(` $currentSeqLen^ `:` type($currentSeqLen) `)` `\n`)? (`prefixOffset` `=` `(` $prefixOffset^ `:` type($prefixOffset) `)` `\n`)? + (`slidingWindowSize` `=` $slidingWindowSize^ `\n`)? (`causal` `\n` $causal^)? (`lse` `=` $lse^ `:` type($lse) `\n`)? (`qk` `=` `elementwise` (`otherIns` `(` $preSoftmaxElemWiseInputs^ `:` type($preSoftmaxElemWiseInputs) `)`)? $preSoftmaxBody^ `\n`)? 
@@ -583,7 +590,8 @@ def Rock_GridwiseAttentionAccelOp Optional>:$prefixOffset, MemRefRankOf<[F32, F16, BF16], [3]>:$out, Optional>:$lse, UnitAttr:$causal, - I32Attr:$splitKV, OptionalAttr:$features, + I32Attr:$splitKV, OptionalAttr:$slidingWindowSize, + OptionalAttr:$features, StoreMethodAttr:$storeMethod, I32Attr:$blockSize, I32Attr:$gridSize, UnitAttr:$disableQBypassLDS, OptionalAttr:$prePadG0M, OptionalAttr:$prePadG0N, diff --git a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp index 2b074fe65d62..27c05c5ae48f 100644 --- a/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp +++ b/mlir/lib/Conversion/TosaToRock/TosaToRock.cpp @@ -1625,6 +1625,9 @@ struct AttentionMatcherValues { Value currentSeqLen; bool isCausal; Value prefixOffset; + std::optional slidingWindowSize; + std::optional seqLenClipMin; + std::optional seqLenClipMax; Type softmaxType; ElementwiseRegionFinder preSoftmaxElementwiseFinder; }; @@ -1973,6 +1976,10 @@ struct AttentionRewritePattern : public OpRewritePattern { Value inputToContinue; // The value to continue pattern matching with Value seqLen; // The sequence length Value prefixOffset; // The prefix offset value + std::optional slidingWindowSize; // The sliding window size + // Clip bounds detected on currentSeqLen during KV-cache pattern matching. + std::optional seqLenClipMin; + std::optional seqLenClipMax; }; // Helper to try detecting prefix causal pattern: add(row_indices, offset) @@ -2029,13 +2036,25 @@ struct AttentionRewritePattern : public OpRewritePattern { return unwrappedOffset; } - // Helper to try detecting KV-cache pattern - // Returns the seqLen value if successful - FailureOr + // Result of KV-cache pattern detection + struct KVCacheResult { + Value seqLen; + std::optional clipMin; + std::optional clipMax; + }; + + // Helper to try detecting KV-cache pattern. + // Also detects an optional clip (min(max(x, lo), hi)) on currentSeqLen. 
+ FailureOr tryKVCachePattern(Value input, const DenseSet &seqLenSkip) const { DenseSet expandAndCollapse{ tensor::CollapseShapeOp::getOperationName(), tensor::ExpandShapeOp::getOperationName()}; + DenseSet expandCollapseMinMax{ + tensor::CollapseShapeOp::getOperationName(), + tensor::ExpandShapeOp::getOperationName(), + tosa::MaximumOp::getOperationName(), + tosa::MinimumOp::getOperationName()}; FailureOr maybeNonOne = mulBroadcast(input); if (failed(maybeNonOne)) return failure(); @@ -2050,8 +2069,19 @@ struct AttentionRewritePattern : public OpRewritePattern { !llvm::all_of(shape.slice(2), [](int32_t v) { return v == 1; })) return failure(); + // Try to detect a clip pattern on currentSeqLen before skipping through + // min/max. The clip (min(max(x, lo), hi)) may wrap the block argument + // and applies to all masks that use currentSeqLen. + KVCacheResult result; + auto maybeClip = tryClipPattern(maybeNonOne.value()); + if (succeeded(maybeClip)) { + result.clipMin = maybeClip->clipMin; + result.clipMax = maybeClip->clipMax; + } + + // Skip through expand/collapse/min/max to reach the block argument auto maybeCurrentSeqLen = - getValueSkipping(maybeNonOne.value(), expandAndCollapse); + getValueSkipping(maybeNonOne.value(), expandCollapseMinMax); assert(succeeded(maybeCurrentSeqLen) && "Must have non-reshape op"); Value currentSeqLen = maybeCurrentSeqLen.value(); @@ -2059,7 +2089,128 @@ struct AttentionRewritePattern : public OpRewritePattern { if (!isI32BlockArgument(currentSeqLen, seqLenSkip)) return failure(); - return currentSeqLen; + result.seqLen = currentSeqLen; + return result; + } + + // Struct for clip detection result + struct ClipBounds { + int32_t clipMin; + int32_t clipMax; + }; + + // Helper to detect a clip pattern on a value: + // tosa.minimum(tosa.maximum(x, constLo), constHi) + FailureOr tryClipPattern(Value input) const { + DenseSet expandAndCollapse{ + tensor::CollapseShapeOp::getOperationName(), + tensor::ExpandShapeOp::getOperationName()}; + 
+ // Helper to extract a splat i32 constant from a value + auto extractI32Constant = [&](Value val) -> std::optional { + auto maybeSkipped = getValueSkipping(val, expandAndCollapse); + Value v = succeeded(maybeSkipped) ? maybeSkipped.value() : val; + DenseElementsAttr attr; + if (!matchPattern(v, m_Constant(&attr))) + return std::nullopt; + if (!attr.getElementType().isInteger(32) || !attr.isSplat()) + return std::nullopt; + return attr.getSplatValue(); + }; + + // Look for tosa.minimum (the outer clip op) + auto maybeMin = + getDefiningOpSkipping(input, expandAndCollapse); + if (failed(maybeMin)) + return failure(); + auto minOp = maybeMin.value(); + + // One input of minimum is a constant (clipMax), the other is maximum + Value maxCandidate; + std::optional clipMax; + clipMax = extractI32Constant(minOp.getInput2()); + if (clipMax) { + maxCandidate = minOp.getInput1(); + } else { + clipMax = extractI32Constant(minOp.getInput1()); + if (clipMax) + maxCandidate = minOp.getInput2(); + else + return failure(); + } + + // Look for tosa.maximum (the inner clip op) + auto maybeMax = + getDefiningOpSkipping(maxCandidate, expandAndCollapse); + if (failed(maybeMax)) + return failure(); + auto maxOp = maybeMax.value(); + + // One input of maximum is a constant (clipMin) + std::optional clipMin; + clipMin = extractI32Constant(maxOp.getInput2()); + if (!clipMin) + clipMin = extractI32Constant(maxOp.getInput1()); + if (!clipMin) + return failure(); + + return ClipBounds{*clipMin, *clipMax}; + } + + // Helper to try detecting sliding window pattern: + // greater(add(seqLen, negative_const_offset) * broadcast, col_indices) + // Returns the window size if successful. 
+ FailureOr + trySlidingWindowPattern(Value input, + const DenseSet &seqLenSkip) const { + DenseSet expandAndCollapse{ + tensor::CollapseShapeOp::getOperationName(), + tensor::ExpandShapeOp::getOperationName()}; + + // Trace through broadcast multiplication (mul by 1) + FailureOr maybeNonOne = mulBroadcast(input); + if (failed(maybeNonOne)) + maybeNonOne = input; + + // Look for add(seqLen, constant_offset) + auto maybeAdd = getDefiningOpSkipping(maybeNonOne.value(), + expandAndCollapse); + if (failed(maybeAdd)) + return failure(); + + auto add = maybeAdd.value(); + + // One operand of the add is currentSeqLen (already tracked by KV-cache), + // the other is a negative constant (-windowSize). Try both operands. + Value seqLenOperand; + auto tryExtractNegativeConst = [&](Value candidate, + Value other) -> FailureOr { + auto maybeSkipped = getValueSkipping(candidate, expandAndCollapse); + Value constVal = + succeeded(maybeSkipped) ? maybeSkipped.value() : candidate; + + DenseElementsAttr constAttr; + if (!matchPattern(constVal, m_Constant(&constAttr))) + return failure(); + if (!constAttr.getElementType().isInteger(32) || !constAttr.isSplat()) + return failure(); + + int32_t offset = constAttr.getSplatValue(); + if (offset >= 0) + return failure(); + seqLenOperand = other; + return -static_cast(offset); + }; + + auto maybeWindowSize = + tryExtractNegativeConst(add.getInput2(), add.getInput1()); + if (failed(maybeWindowSize)) + maybeWindowSize = + tryExtractNegativeConst(add.getInput1(), add.getInput2()); + if (failed(maybeWindowSize)) + return failure(); + + return maybeWindowSize.value(); } /* @@ -2231,26 +2382,45 @@ struct AttentionRewritePattern : public OpRewritePattern { auto greater = maybeGreater.value(); - // input1 must be column indices (constant range from 0) - if (failed(isConstantRange(greater.getInput1(), 0))) - return; - - Value input2 = greater.getInput2(); + // Standard direction: greater(col_indices, value) + // Used for KV-cache and prefix-causal 
masks + if (succeeded(isConstantRange(greater.getInput1(), 0))) { + Value input2 = greater.getInput2(); + + // Try KV-cache pattern (scalar seqLen) if not already found + if (!result.seqLen) { + auto maybeKVCache = tryKVCachePattern(input2, seqLenSkip); + if (succeeded(maybeKVCache)) { + auto kvResult = maybeKVCache.value(); + result.seqLen = kvResult.seqLen; + result.seqLenClipMin = kvResult.clipMin; + result.seqLenClipMax = kvResult.clipMax; + } + } - // Try KV-cache pattern (scalar seqLen) if not already found - if (!result.seqLen) { - auto maybeKVCache = tryKVCachePattern(input2, seqLenSkip); - if (succeeded(maybeKVCache)) { - result.seqLen = maybeKVCache.value(); + // Try prefix causal pattern (row_indices + offset) if not already found + if (!result.prefixOffset) { + auto maybePrefixCausal = tryPrefixCausalPattern(input2, seqLenSkip); + if (succeeded(maybePrefixCausal)) { + result.prefixOffset = maybePrefixCausal.value(); + } } + return; } - // Try prefix causal pattern (row_indices + offset) if not already found - if (!result.prefixOffset) { - auto maybePrefixCausal = tryPrefixCausalPattern(input2, seqLenSkip); - if (succeeded(maybePrefixCausal)) { - result.prefixOffset = maybePrefixCausal.value(); + // Reversed direction: greater(value, col_indices) + // Used for sliding window mask where value = seqLen + negative_offset + if (succeeded(isConstantRange(greater.getInput2(), 0))) { + Value input1 = greater.getInput1(); + + // Try sliding window pattern if not already found + if (!result.slidingWindowSize) { + auto maybeSlidingWindow = trySlidingWindowPattern(input1, seqLenSkip); + if (succeeded(maybeSlidingWindow)) { + result.slidingWindowSize = maybeSlidingWindow.value(); + } } + return; } } @@ -2270,31 +2440,39 @@ struct AttentionRewritePattern : public OpRewritePattern { DenseSet seqLenSkip{tensor::CollapseShapeOp::getOperationName(), tensor::ExpandShapeOp::getOperationName(), tosa::TransposeOp::getOperationName(), - tosa::MulOp::getOperationName()}; + 
tosa::MulOp::getOperationName(), + tosa::MaximumOp::getOperationName(), + tosa::MinimumOp::getOperationName()}; Value inputToContinue = select.getInput3(); - SeqLenMaskResult currentResult{inputToContinue, nullptr, nullptr}; + SeqLenMaskResult currentResult{inputToContinue, nullptr, nullptr, + std::nullopt, std::nullopt, std::nullopt}; // Analyze the first (outer) select analyzeSelectForSeqLenMask(select, currentResult, opsToSkip, seqLenSkip); - // Check if the inputToContinue (input3) is another chained select with -inf - // This handles the case where KVCache and prefix causal use separate - // selects. + // Check if the inputToContinue (input3) is another chained select with + // -inf. This handles cases where multiple mask patterns (KVCache, prefix + // causal, sliding window) use separate selects. bool haveSeqLen = currentResult.seqLen != nullptr; bool havePrefixOffset = currentResult.prefixOffset != nullptr; + bool haveSlidingWindow = currentResult.slidingWindowSize.has_value(); - if (haveSeqLen != havePrefixOffset) { + // Try chaining if we found at least one pattern but not all + bool foundAny = haveSeqLen || havePrefixOffset || haveSlidingWindow; + bool foundAll = haveSeqLen && havePrefixOffset && haveSlidingWindow; + if (foundAny && !foundAll) { auto maybeChainedSelect = getSelectWithNegInf(inputToContinue); if (succeeded(maybeChainedSelect)) { auto chainedSelect = maybeChainedSelect.value(); // Try to analyze the chained select for the missing pattern analyzeSelectForSeqLenMask(chainedSelect, currentResult, opsToSkip, seqLenSkip); - // Only update inputToContinue if we found the complementary pattern + // Only update inputToContinue if we found a complementary pattern bool foundComplementary = (!haveSeqLen && currentResult.seqLen) || - (!havePrefixOffset && currentResult.prefixOffset); + (!havePrefixOffset && currentResult.prefixOffset) || + (!haveSlidingWindow && currentResult.slidingWindowSize.has_value()); if (foundComplementary) { 
currentResult.inputToContinue = chainedSelect.getInput3(); } @@ -2302,7 +2480,8 @@ struct AttentionRewritePattern : public OpRewritePattern { } // We need at least one pattern to be detected - if (!currentResult.seqLen && !currentResult.prefixOffset) + if (!currentResult.seqLen && !currentResult.prefixOffset && + !currentResult.slidingWindowSize) return failure(); return currentResult; @@ -2890,16 +3069,21 @@ struct AttentionRewritePattern : public OpRewritePattern { return failure(); } - // Detect sequence length masking patterns (KV-cache or prefix causal) - // Note that non KV-Cache fusions might have tosa.select - // so, if the checks fail, we just keep going + // Detect sequence length masking patterns (KV-cache, prefix causal, + // or sliding window). Note that non KV-Cache fusions might have + // tosa.select so, if the checks fail, we just keep going Value kvCacheInput, currentSeqLen, prefixOffset; + std::optional slidingWindowSize; + std::optional seqLenClipMin, seqLenClipMax; auto maybeSeqLenMask = getSeqLenMask(softmaxInput); if (succeeded(maybeSeqLenMask)) { auto result = maybeSeqLenMask.value(); kvCacheInput = result.inputToContinue; currentSeqLen = result.seqLen; prefixOffset = result.prefixOffset; + slidingWindowSize = result.slidingWindowSize; + seqLenClipMin = result.seqLenClipMin; + seqLenClipMax = result.seqLenClipMax; } else { kvCacheInput = softmaxInput; } @@ -2946,6 +3130,12 @@ struct AttentionRewritePattern : public OpRewritePattern { LLVM_DEBUG(llvm::dbgs() << "isCausal = " << isCausal << "\n"); LLVM_DEBUG(llvm::dbgs() << "isPrefixCausal = " << (bool)prefixOffset << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "isSlidingWindow = " << slidingWindowSize.has_value() + << (slidingWindowSize + ? 
" (size=" + std::to_string(*slidingWindowSize) + ")" + : "") + << "\n"); if (isDotProduct && hasReduceOp) return failure(); if (!isDotProduct && !hasReduceOp) @@ -2972,6 +3162,9 @@ struct AttentionRewritePattern : public OpRewritePattern { AttentionMatcherValues attentionMatcherValues; attentionMatcherValues.isCausal = isCausal; attentionMatcherValues.prefixOffset = prefixOffset; + attentionMatcherValues.slidingWindowSize = slidingWindowSize; + attentionMatcherValues.seqLenClipMin = seqLenClipMin; + attentionMatcherValues.seqLenClipMax = seqLenClipMax; attentionMatcherValues.softmaxType = softmaxType; attentionMatcherValues.softmaxValues = softmaxMatcherValues; attentionMatcherValues.lse = lse; @@ -3062,6 +3255,34 @@ struct AttentionRewritePattern : public OpRewritePattern { prepareBlockArgTensor(currentSeqLen); prepareBlockArgTensor(prefixOffset); + // Apply seqLen clip if detected during KV-cache pattern matching. + // The original model may have clip(arg, lo, hi) on currentSeqLen which + // was traced through to reach the block argument. The clip is a property + // of currentSeqLen itself, used by all masks (KV-cache, sliding window). 
+ if (currentSeqLen && (attentionMatcherValues.seqLenClipMin.has_value() || + attentionMatcherValues.seqLenClipMax.has_value())) { + auto seqLenType = cast(currentSeqLen.getType()); + auto elemTy = seqLenType.getElementType(); + if (attentionMatcherValues.seqLenClipMin.has_value()) { + auto minAttr = DenseElementsAttr::get( + seqLenType, rewriter.getIntegerAttr( + elemTy, *attentionMatcherValues.seqLenClipMin)); + Value clipMinConst = + tosa::ConstOp::create(rewriter, loc, seqLenType, minAttr); + currentSeqLen = tosa::MaximumOp::create(rewriter, loc, seqLenType, + currentSeqLen, clipMinConst); + } + if (attentionMatcherValues.seqLenClipMax.has_value()) { + auto maxAttr = DenseElementsAttr::get( + seqLenType, rewriter.getIntegerAttr( + elemTy, *attentionMatcherValues.seqLenClipMax)); + Value clipMaxConst = + tosa::ConstOp::create(rewriter, loc, seqLenType, maxAttr); + currentSeqLen = tosa::MinimumOp::create(rewriter, loc, seqLenType, + currentSeqLen, clipMaxConst); + } + } + UnitAttr causalAttr = isCausal ? 
rewriter.getUnitAttr() : nullptr; ElementwiseRegionFinder elemwiseRegion = attentionMatcherValues.preSoftmaxElementwiseFinder; @@ -3072,6 +3293,11 @@ struct AttentionRewritePattern : public OpRewritePattern { std::tie(queries, keys, values, numHeadsQ, numHeadsKV) = getGQAValues( rewriter, firstMatMulOp.getA(), firstMatMulOp.getB(), op.getB()); + IntegerAttr slidingWindowSizeAttr; + if (attentionMatcherValues.slidingWindowSize.has_value()) + slidingWindowSizeAttr = rewriter.getI32IntegerAttr( + attentionMatcherValues.slidingWindowSize.value()); + rock::AttentionOp attnOp = rock::AttentionOp::create( rewriter, loc, outputType, lseType, queries, keys, values, elementwiseOtherArgs, currentSeqLen, prefixOffset, output, lseOut, @@ -3081,7 +3307,7 @@ struct AttentionRewritePattern : public OpRewritePattern { /*kTransposed=*/nullptr, /*vTransposed=*/nullptr, /*oTransposed=*/nullptr, causalAttr, - /*splitKV=*/rewriter.getI32IntegerAttr(1), + /*splitKV=*/rewriter.getI32IntegerAttr(1), slidingWindowSizeAttr, /*features=*/nullptr, rewriter.getAttr(rock::StoreMethod::Set), softmaxTypeAttr, diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 5b7c3131817e..52914cb74829 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -2799,6 +2799,28 @@ void ThreadwiseGemmAccelOp::getEffects( getGemmMatrixEffects(*this, effects); } +// Validate sliding window constraints common to attention-like ops. 
+static LogicalResult +verifySlidingWindowConstraints(Operation *op, + std::optional slidingWindowSize, + Value currentSeqLen, int64_t maxSeqLen) { + if (!slidingWindowSize) + return success(); + int32_t windowSize = static_cast(*slidingWindowSize); + + if (windowSize <= 0) + return op->emitError("slidingWindowSize must be positive"); + + if (!currentSeqLen) + return op->emitError("slidingWindowSize requires currentSeqLen to be set"); + + if (windowSize > maxSeqLen) + return op->emitError( + "slidingWindowSize must not exceed max sequence length"); + + return success(); +} + //===----------------------------------------------------------------------===// // GridwiseAttentionAccelOp //===----------------------------------------------------------------------===// @@ -2832,6 +2854,9 @@ LogicalResult GridwiseAttentionAccelOp::verify() { if (!getEnableSoftmax() && getCausal()) return emitError("causal only works for attention."); + if (!getEnableSoftmax() && getSlidingWindowSize()) + return emitError("slidingWindowSize only works for attention."); + // Validate prefix offset constraints // prefixOffset requires causal to be enabled (prefix causal = causal + // prefixOffset) @@ -2840,6 +2865,18 @@ LogicalResult GridwiseAttentionAccelOp::verify() { "prefixOffset requires causal to be enabled. " "Prefix causal attention is causal masking with an offset."); + // Keys are normalized to [G, K, M] where M = key seq len (max seq len). + // Use pre-padding value if available, otherwise use the current shape. + ShapedType kType = cast(getKeys().getType()); + int64_t maxSeqLen = + getPrePadG0M().value_or(APInt(64, kType.getShape()[2])).getSExtValue(); + + // Validate sliding window constraints. + if (failed(verifySlidingWindowConstraints(getOperation(), + getSlidingWindowSize(), + getCurrentSeqLen(), maxSeqLen))) + return failure(); + return success(); } @@ -3382,6 +3419,16 @@ LogicalResult AttentionOp::verify() { "prefixOffset requires causal to be enabled. 
" "Prefix causal attention is causal masking with an offset."); + // Validate sliding window constraints. + // Max seq len is the key N dimension. + ShapedType kType = cast(getKeys().getType()); + ArrayRef kLastDims = kType.getShape().slice(kType.getRank() - 2); + int64_t maxSeqLen = getKTransposed() ? kLastDims[0] : kLastDims[1]; + if (failed(verifySlidingWindowConstraints(getOperation(), + getSlidingWindowSize(), + getCurrentSeqLen(), maxSeqLen))) + return failure(); + return verifyGemmPlusGemmLikeOp(*this, getCurrentSeqLen(), getLse(), getNumHeadsQ(), getNumHeadsKV()); } diff --git a/mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp b/mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp index 1067f4abe6ef..f4c868bb1d7f 100644 --- a/mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp @@ -462,9 +462,9 @@ struct DetectFlashDecodingPattern : public OpRewritePattern { op.getNumHeadsKVAttr(), op.getQTransposedAttr(), op.getKTransposedAttr(), op.getVTransposedAttr(), op.getOTransposedAttr(), op.getCausalAttr(), - rewriter.getI32IntegerAttr(splitKVFromQ), op.getFeaturesAttr(), - op.getStoreMethodAttr(), op.getSoftmaxTypeAttr(), op.getParams0Attr(), - op.getParams1Attr(), op.getFirstGemmIndicesAttr(), + rewriter.getI32IntegerAttr(splitKVFromQ), op.getSlidingWindowSizeAttr(), + op.getFeaturesAttr(), op.getStoreMethodAttr(), op.getSoftmaxTypeAttr(), + op.getParams0Attr(), op.getParams1Attr(), op.getFirstGemmIndicesAttr(), /*preSoftmaxHasSplitKVTransforms=*/rewriter.getBoolAttr(true)); // Copy the preSoftmax elementwise region if it exists diff --git a/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp index 52420e10f2bd..1ba83bdde8ff 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp @@ -488,8 +488,9 @@ static LogicalResult commonAttentionGemmElmtGemm( 
ConversionPatternRewriter &rw, RockGemmGemmWrapperInterface op, Value a, Value b, Value c, Value out, Value lse, Value currentSeqLen, Value prefixOffset, UnitAttr causal, IntegerAttr splitKV, - ValueRange elementwiseInputs, Region &preSecondOpRegion, bool enableSoftmax, - TypeAttr softmaxType, int64_t numHeadsQ, int64_t numHeadsKV, + IntegerAttr slidingWindowSize, ValueRange elementwiseInputs, + Region &preSecondOpRegion, bool enableSoftmax, TypeAttr softmaxType, + int64_t numHeadsQ, int64_t numHeadsKV, std::optional> bufferDeps, BoolAttr preSoftmaxHasSplitKVTransforms) { @@ -611,8 +612,8 @@ static LogicalResult commonAttentionGemmElmtGemm( auto newOp = GridwiseAttentionAccelOp::create( rw, loc, a, b, c, elementwiseInputs, currentSeqLen, prefixOffset, out, - lse, causal, splitKV, op.getGemmFeaturesAttr(), op.getStoreMethodAttr(), - blockSizeAttr, gridSizeAttr, + lse, causal, splitKV, slidingWindowSize, op.getGemmFeaturesAttr(), + op.getStoreMethodAttr(), blockSizeAttr, gridSizeAttr, /*disableQBypassLDS=*/nullptr, prePadG0MAttr, prePadG0NAttr, numRepeatsGQA, softmaxType, params0, params1, rw.getDenseI64ArrayAttr(op.getFirstGemmIndices()), @@ -1088,8 +1089,8 @@ AttentionRewritePattern::matchAndRewrite(AttentionOp op, rw, op, adaptor.getQueries(), adaptor.getKeys(), adaptor.getValues(), adaptor.getOut(), adaptor.getLse(), adaptor.getCurrentSeqLen(), adaptor.getPrefixOffset(), adaptor.getCausalAttr(), - adaptor.getSplitKVAttr(), adaptor.getPreSoftmaxElemWiseInputs(), - op.getPreSoftmaxBody(), + adaptor.getSplitKVAttr(), adaptor.getSlidingWindowSizeAttr(), + adaptor.getPreSoftmaxElemWiseInputs(), op.getPreSoftmaxBody(), /*enableSoftmax=*/true, op.getSoftmaxTypeAttr(), adaptor.getNumHeadsQ(), adaptor.getNumHeadsKV(), /*bufferDeps=*/std::nullopt, @@ -1104,7 +1105,8 @@ LogicalResult GemmElementwiseGemmRewritePattern::matchAndRewrite( rw, op, adaptor.getA(), adaptor.getB(), adaptor.getC(), adaptor.getOut(), /*lse=*/nullptr, /*currentSeqLen=*/nullptr, 
/*prefixOffset=*/nullptr, /*causal=*/nullptr, - splitKV, adaptor.getElemwiseInputs(), op.getPreSecondGemmBody(), + splitKV, /*slidingWindowSize=*/nullptr, adaptor.getElemwiseInputs(), + op.getPreSecondGemmBody(), /*enableSoftmax=*/false, /*softmaxType=*/nullptr, /*numHeadsQ=*/1, /*numHeadsKV=*/1, std::cref(bufferDeps), /*preSoftmaxHasSplitKVTransforms=*/rw.getBoolAttr(false)); diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 8c91c4083bb2..92c28e0d4b44 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -1256,17 +1256,17 @@ struct GridwiseAttentionAccelRewritePattern } } - enum class OutOfScopeType { KVCache, Causal, PrefixCausal }; + enum class OutOfScopeType { KVCache, Causal, PrefixCausal, SlidingWindow }; void setGemm0OutputOutOfScope( PatternRewriter &rewriter, Location loc, OutOfScopeType outOfScopeType, layout::GridCoordinates gridCoords, Value gemm0OutBuffer, RegsAsMatrixSubTiles gemm0OutSubTileViews, bool enabled, Value mLoopIV, Value gemm0MBlocksLastIter, Value currentSeqLen, Value prefixOffset, - IntegerAttr numRepeatsGQA) const { + IntegerAttr numRepeatsGQA, Value slidingWindowLowerBound) const { if (enabled) { // For KVCache, we only need to mask on the last iteration, but for causal - // masking we need to mask on every iteration. + // and sliding window masking we need to mask on every iteration. bool needsLastIterCheck = (outOfScopeType == OutOfScopeType::KVCache); // Use a lambda to generate the masking logic. @@ -1338,6 +1338,15 @@ struct GridwiseAttentionAccelRewritePattern mIndex, threshold); break; } + case OutOfScopeType::SlidingWindow: { + // Sliding window: mask when key_pos < max(0, currentSeqLen - + // windowSize). slidingWindowLowerBound is precomputed as + // max(0, currentSeqLen - windowSize). 
+ assert(slidingWindowLowerBound != nullptr); + isInvalid = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ult, + mIndex, slidingWindowLowerBound); + break; + } } scf::IfOp ifOp = scf::IfOp::create(b, loc, isInvalid, @@ -1738,17 +1747,18 @@ struct GridwiseAttentionAccelRewritePattern return viewBuilder.get(); } - std::tuple + std::tuple getMLoopInfo(PatternRewriter &rewriter, Location loc, layout::AttnGridCoordinates gridCoordsGemm0, Value currentSeqLenTensor, Value prefixOffsetTensor, int64_t gemm0M, int64_t gemm0N, int64_t gemm0MPerBlock, int64_t gemm0NPerBlock, int64_t splitKV, bool isCausal, - bool isKVCache, bool isPrefixCausal, + bool isKVCache, bool isPrefixCausal, int64_t slidingWindowSize, IntegerAttr numRepeatsGQA = nullptr) const { Value gemm0MBlocksLastIter; Value currentSeqLen; Value prefixOffset; + Value slidingWindowLowerBound; Value effectiveSeqLen; Value start, end; @@ -1792,13 +1802,26 @@ struct GridwiseAttentionAccelRewritePattern loc, rewriter.getIndexType(), loadedValue); }; - // This is needed for KV Cache/Causal/Prefix Causal masking support - if (isCausal || isKVCache || isPrefixCausal) { + // This is needed for KV Cache/Causal/Prefix Causal/Sliding Window masking + if (isCausal || isKVCache || isPrefixCausal || slidingWindowSize > 0) { if (isKVCache) { currentSeqLen = loadTensorValue(currentSeqLenTensor); effectiveSeqLen = currentSeqLen; } + // Compute sliding window lower bound: max(0, currentSeqLen - windowSize) + if (slidingWindowSize > 0) { + assert(currentSeqLen != nullptr && + "sliding window requires currentSeqLen (KV-cache)"); + Value constWindowSize = rewriter.createOrFold( + loc, slidingWindowSize); + Value zero = rewriter.createOrFold(loc, 0); + Value lowerBound = arith::SubIOp::create(rewriter, loc, currentSeqLen, + constWindowSize); + slidingWindowLowerBound = + arith::MaxSIOp::create(rewriter, loc, lowerBound, zero); + } + if (isCausal || isPrefixCausal) { // Compute the last Q position in the block. 
// (nIndex + 1) * NPerBlock - 1. @@ -1883,6 +1906,17 @@ struct GridwiseAttentionAccelRewritePattern gemm0MIterations); end = arith::MinUIOp::create(rewriter, loc, end, endSplitKV); } + + // Adjust start for sliding window: skip M-blocks that are entirely + // below the window. All positions in those blocks would be masked to + // -inf anyway, so we can avoid the loads and GEMMs altogether. + if (slidingWindowSize > 0) { + Value slidingWindowStart = rewriter.createOrFold( + loc, slidingWindowLowerBound, constGemm0MPerBlock); + start = + arith::MaxSIOp::create(rewriter, loc, start, slidingWindowStart); + } + // compute last iteration of the block, this will be used later in // setGemm0OutputOutOfScope() gemm0MBlocksLastIter = @@ -1907,7 +1941,7 @@ struct GridwiseAttentionAccelRewritePattern end = rewriter.createOrFold(loc, gemm0MBlocks); } return std::make_tuple(start, end, gemm0MBlocksLastIter, currentSeqLen, - prefixOffset); + prefixOffset, slidingWindowLowerBound); } // Helper function to determine if early exit optimization is possible. 
@@ -2043,6 +2077,8 @@ struct GridwiseAttentionAccelRewritePattern bool isKVCache = currentSeqLenTensor != nullptr; bool isCausal = op.getCausal(); bool isPrefixCausal = isCausal && prefixOffsetTensor; + int64_t slidingWindowSize = + static_cast(op.getSlidingWindowSize().value_or(0)); int64_t splitKV = op.getSplitKV(); // Gemm0 out is casted to be softmaxType (if null, it's casted to elemTypeV) @@ -2451,13 +2487,16 @@ struct GridwiseAttentionAccelRewritePattern Value gemm0MBlocksLastIter; Value currentSeqLen; Value prefixOffset; + Value slidingWindowLowerBound; Value start, end; // get mLoop - std::tie(start, end, gemm0MBlocksLastIter, currentSeqLen, prefixOffset) = + std::tie(start, end, gemm0MBlocksLastIter, currentSeqLen, prefixOffset, + slidingWindowLowerBound) = getMLoopInfo(rewriter, loc, gridCoordsGemm0mIter0, currentSeqLenTensor, prefixOffsetTensor, gemm0M, gemm0N, gemm0MPerBlock, gemm0NPerBlock, splitKV, isCausal, isKVCache, - isPrefixCausal, op.getNumRepeatsGQAAttr()); + isPrefixCausal, slidingWindowSize, + op.getNumRepeatsGQAAttr()); // Early exit: Skip all computation when there's no work but always write // output. @@ -2772,14 +2811,13 @@ struct GridwiseAttentionAccelRewritePattern // KV cache masking is independent of causal masking - it masks out // positions beyond currentSeqLen (padding). Apply it whenever KV // cache is enabled, regardless of causal/prefix-causal mode. 
- if (isKVCache) { - setGemm0OutputOutOfScope(rewriter, loc, OutOfScopeType::KVCache, - gridCoordsGemm0, softmaxInputBuffer, - gemm0OutSubTileViewsTr, isKVCache, mLoopIV, - gemm0MBlocksLastIter, currentSeqLen, - /*prefixOffset=*/nullptr, - /*numRepeatsGQA=*/nullptr); - } + setGemm0OutputOutOfScope(rewriter, loc, OutOfScopeType::KVCache, + gridCoordsGemm0, softmaxInputBuffer, + gemm0OutSubTileViewsTr, isKVCache, mLoopIV, + gemm0MBlocksLastIter, currentSeqLen, + /*prefixOffset=*/nullptr, + /*numRepeatsGQA=*/nullptr, + /*slidingWindowLowerBound=*/nullptr); // Causal masking: either prefix-causal or standard causal if (isPrefixCausal) { @@ -2790,7 +2828,8 @@ struct GridwiseAttentionAccelRewritePattern gemm0OutSubTileViewsTr, isPrefixCausal, mLoopIV, gemm0MBlocksLastIter, /*currentSeqLen=*/nullptr, prefixOffset, - op.getNumRepeatsGQAAttr()); + op.getNumRepeatsGQAAttr(), + /*slidingWindowLowerBound=*/nullptr); } else if (isCausal) { // Standard causal masking: mask when key > query setGemm0OutputOutOfScope( @@ -2798,9 +2837,21 @@ struct GridwiseAttentionAccelRewritePattern softmaxInputBuffer, gemm0OutSubTileViewsTr, isCausal, mLoopIV, gemm0MBlocksLastIter, /*currentSeqLen=*/nullptr, - /*prefixOffset=*/nullptr, op.getNumRepeatsGQAAttr()); + /*prefixOffset=*/nullptr, op.getNumRepeatsGQAAttr(), + /*slidingWindowLowerBound=*/nullptr); } + // Sliding window masking: mask when key_pos < max(0, currentSeqLen - + // windowSize). This is independent of causal masking and applies + // alongside KV-cache masking. 
+ setGemm0OutputOutOfScope( + rewriter, loc, OutOfScopeType::SlidingWindow, gridCoordsGemm0, + softmaxInputBuffer, gemm0OutSubTileViewsTr, slidingWindowSize > 0, + mLoopIV, gemm0MBlocksLastIter, + /*currentSeqLen=*/nullptr, + /*prefixOffset=*/nullptr, /*numRepeatsGQA=*/nullptr, + slidingWindowLowerBound); + APInt reductionAxis = APInt(64, 1); // Softmax max reduction Value ldsReductionWorkspaceByteBuffer = createLDSByteBuffer( diff --git a/mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp b/mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp index a148f67fc52d..4e7c8173052d 100644 --- a/mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp @@ -622,8 +622,9 @@ struct AttentionRewritePattern : public OpRewritePattern { op.getPrefixOffset(), op.getOut(), op.getLse(), op.getNumHeadsQAttr(), op.getNumHeadsKVAttr(), transposedQ, transposedK, transposedV, op.getOTransposedAttr(), op.getCausalAttr(), op.getSplitKVAttr(), - op.getFeaturesAttr(), op.getStoreMethodAttr(), op.getSoftmaxTypeAttr(), - op.getParams0Attr(), op.getParams1Attr(), op.getFirstGemmIndicesAttr(), + op.getSlidingWindowSizeAttr(), op.getFeaturesAttr(), + op.getStoreMethodAttr(), op.getSoftmaxTypeAttr(), op.getParams0Attr(), + op.getParams1Attr(), op.getFirstGemmIndicesAttr(), op.getPreSoftmaxHasSplitKVTransformsAttr()); // copy linalg::GenericOp if there's any diff --git a/mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-kvcache.mlir b/mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-kvcache.mlir index f2db6da1a630..ba24083d1923 100644 --- a/mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-kvcache.mlir +++ b/mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-kvcache.mlir @@ -410,3 +410,82 @@ func.func @mlir_causal_attention_nokvcache_wrongrange(%arg0: tensor<24576xf16>, %collapsed_8 = tensor.collapse_shape %29 [[0, 1, 2, 3]] : tensor<1x2x32x128xf16> into tensor<8192xf16> return 
%collapsed_8 : tensor<8192xf16> } + +// CHECK-LABEL:func @mlir_attention_kvcache_sliding_window +// CHECK: %[[MAX:.*]] = tosa.maximum +// CHECK: %[[CLIP:.*]] = tosa.minimum %[[MAX]], {{.*}} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: currentSeqLen = (%[[CLIP]] : tensor<2xi32>) +// CHECK: slidingWindowSize = 3 +func.func @mlir_attention_kvcache_sliding_window(%arg0: tensor<1xi32>, %arg1: tensor<12xf16>, %arg2: tensor<32xf16>, %arg3: tensor<32xf16>) -> tensor<4xf16> attributes {kernel = "mixr"} { + %0 = "tosa.const"() <{values = dense<4> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32> + %1 = tosa.const_shape {values = dense<4> : tensor<1xindex>} : () -> !tosa.shape<1> + %2 = tosa.const_shape {values = dense<[2, 8, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> + %3 = tosa.const_shape {values = dense<[2, 1, 8]> : tensor<3xindex>} : () -> !tosa.shape<3> + %4 = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x2x1x8xf32>}> : () -> tensor<1x2x1x8xf32> + %5 = "tosa.const"() <{values = dense<1> : tensor<1x2x1x8xi8>}> : () -> tensor<1x2x1x8xi8> + %6 = tosa.const_shape {values = dense<8> : tensor<1xindex>} : () -> !tosa.shape<1> + %7 = "tosa.const"() <{values = dense<1> : tensor<8x1x1x1xi32>}> : () -> tensor<8x1x1x1xi32> + %8 = "tosa.const"() <{values = dense<5.000000e-01> : tensor<1x2x1x8xf16>}> : () -> tensor<1x2x1x8xf16> + %9 = "tosa.const"() <{values = dense<0xFC00> : tensor<1x2x1x8xf16>}> : () -> tensor<1x2x1x8xf16> + %10 = tosa.const_shape {values = dense<[1, 2, 1, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> + %11 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16> + %12 = tosa.const_shape {values = dense<[2, 2, 8]> : tensor<3xindex>} : () -> !tosa.shape<3> + %13 = tosa.const_shape {values = dense<[2, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> + %14 = tosa.const_shape {values = dense<[1, 2, 1, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> + %15 = tosa.const_shape {values = dense<1> : 
tensor<4xindex>} : () -> !tosa.shape<4> + %16 = "tosa.const"() <{values = dense<-3> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32> + %17 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %18 = "tosa.const"() <{values = dense<1> : tensor<1x1x1x8xi32>}> : () -> tensor<1x1x1x8xi32> + %19 = tosa.const_shape {values = dense<[1, 1, 1, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> + %20 = "tosa.const"() <{values = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi32>}> : () -> tensor<8xi32> + %21 = tosa.const_shape {values = dense<[1, 6, 1, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> + %22 = tosa.const_shape {values = dense<[1, 2, 8, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> + %expanded = tensor.expand_shape %arg2 [[0, 1, 2, 3]] output_shape [1, 2, 8, 2] : tensor<32xf16> into tensor<1x2x8x2xf16> + %expanded_0 = tensor.expand_shape %arg1 [[0, 1, 2, 3]] output_shape [1, 6, 1, 2] : tensor<12xf16> into tensor<1x6x1x2xf16> + %cst = arith.constant dense<[[[[0, 1, 2, 3, 4, 5, 6, 7]]]]> : tensor<1x1x1x8xi32> + %expanded_1 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [1, 1, 1, 1] : tensor<1xi32> into tensor<1x1x1x1xi32> + %23 = tosa.maximum %expanded_1, %0 : (tensor<1x1x1x1xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x1x1xi32> + %24 = tosa.minimum %23, %0 : (tensor<1x1x1x1xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x1x1xi32> + %extracted_slice = tensor.extract_slice %expanded_0[0, 0, 0, 0] [1, 2, 1, 2] [1, 1, 1, 1] : tensor<1x6x1x2xf16> to tensor<1x2x1x2xf16> + %25 = tosa.transpose %expanded {perms = array} : (tensor<1x2x8x2xf16>) -> tensor<1x2x2x8xf16> + %collapsed = tensor.collapse_shape %extracted_slice [[0, 1], [2], [3]] : tensor<1x2x1x2xf16> into tensor<2x1x2xf16> + %collapsed_2 = tensor.collapse_shape %25 [[0, 1], [2], [3]] : tensor<1x2x2x8xf16> into tensor<2x2x8xf16> + %26 = tosa.matmul %collapsed, %collapsed_2, %11, %11 {acc_type = f32} : (tensor<2x1x2xf16>, tensor<2x2x8xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<2x1x8xf16> + 
%expanded_3 = tensor.expand_shape %26 [[0, 1], [2], [3]] output_shape [1, 2, 1, 8] : tensor<2x1x8xf16> into tensor<1x2x1x8xf16> + %27 = tosa.mul %expanded_3, %8, %17 : (tensor<1x2x1x8xf16>, tensor<1x2x1x8xf16>, tensor<1xi8>) -> tensor<1x2x1x8xf16> + %28 = tosa.add %24, %16 : (tensor<1x1x1x1xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x1x1xi32> + %29 = tosa.mul %28, %7, %17 : (tensor<1x1x1x1xi32>, tensor<8x1x1x1xi32>, tensor<1xi8>) -> tensor<8x1x1x1xi32> + %collapsed_4 = tensor.collapse_shape %29 [[0, 1, 2, 3]] : tensor<8x1x1x1xi32> into tensor<8xi32> + %30 = tosa.greater %collapsed_4, %20 : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi1> + %31 = tosa.cast %30 : (tensor<8xi1>) -> tensor<8xi32> + %32 = tosa.cast %31 : (tensor<8xi32>) -> tensor<8xi8> + %expanded_5 = tensor.expand_shape %32 [[0, 1, 2, 3]] output_shape [1, 1, 1, 8] : tensor<8xi8> into tensor<1x1x1x8xi8> + %33 = tosa.mul %expanded_5, %5, %17 : (tensor<1x1x1x8xi8>, tensor<1x2x1x8xi8>, tensor<1xi8>) -> tensor<1x2x1x8xi8> + %34 = tosa.cast %33 : (tensor<1x2x1x8xi8>) -> tensor<1x2x1x8xi1> + %35 = tosa.select %34, %9, %27 : (tensor<1x2x1x8xi1>, tensor<1x2x1x8xf16>, tensor<1x2x1x8xf16>) -> tensor<1x2x1x8xf16> + %36 = tosa.mul %24, %18, %17 : (tensor<1x1x1x1xi32>, tensor<1x1x1x8xi32>, tensor<1xi8>) -> tensor<1x1x1x8xi32> + %37 = tosa.greater %cst, %36 : (tensor<1x1x1x8xi32>, tensor<1x1x1x8xi32>) -> tensor<1x1x1x8xi1> + %38 = tosa.cast %37 : (tensor<1x1x1x8xi1>) -> tensor<1x1x1x8xi32> + %39 = tosa.cast %38 : (tensor<1x1x1x8xi32>) -> tensor<1x1x1x8xi8> + %40 = tosa.mul %39, %5, %17 : (tensor<1x1x1x8xi8>, tensor<1x2x1x8xi8>, tensor<1xi8>) -> tensor<1x2x1x8xi8> + %41 = tosa.cast %40 : (tensor<1x2x1x8xi8>) -> tensor<1x2x1x8xi1> + %42 = tosa.select %41, %9, %35 : (tensor<1x2x1x8xi1>, tensor<1x2x1x8xf16>, tensor<1x2x1x8xf16>) -> tensor<1x2x1x8xf16> + %43 = tosa.cast %42 : (tensor<1x2x1x8xf16>) -> tensor<1x2x1x8xf32> + %44 = tosa.reduce_max %43 {axis = 3 : i32} : (tensor<1x2x1x8xf32>) -> tensor<1x2x1x1xf32> + %45 = tosa.mul 
%44, %4, %17 : (tensor<1x2x1x1xf32>, tensor<1x2x1x8xf32>, tensor<1xi8>) -> tensor<1x2x1x8xf32> + %46 = tosa.sub %43, %45 : (tensor<1x2x1x8xf32>, tensor<1x2x1x8xf32>) -> tensor<1x2x1x8xf32> + %47 = tosa.exp %46 : (tensor<1x2x1x8xf32>) -> tensor<1x2x1x8xf32> + %48 = tosa.reduce_sum %47 {axis = 3 : i32} : (tensor<1x2x1x8xf32>) -> tensor<1x2x1x1xf32> + %49 = tosa.mul %48, %4, %17 : (tensor<1x2x1x1xf32>, tensor<1x2x1x8xf32>, tensor<1xi8>) -> tensor<1x2x1x8xf32> + %50 = tosa.reciprocal %49 : (tensor<1x2x1x8xf32>) -> tensor<1x2x1x8xf32> + %51 = tosa.mul %47, %50, %17 : (tensor<1x2x1x8xf32>, tensor<1x2x1x8xf32>, tensor<1xi8>) -> tensor<1x2x1x8xf32> + %52 = tosa.cast %51 : (tensor<1x2x1x8xf32>) -> tensor<1x2x1x8xf16> + %collapsed_6 = tensor.collapse_shape %52 [[0, 1], [2], [3]] : tensor<1x2x1x8xf16> into tensor<2x1x8xf16> + %expanded_7 = tensor.expand_shape %arg3 [[0, 1, 2]] output_shape [2, 8, 2] : tensor<32xf16> into tensor<2x8x2xf16> + %53 = tosa.matmul %collapsed_6, %expanded_7, %11, %11 {acc_type = f32} : (tensor<2x1x8xf16>, tensor<2x8x2xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<2x1x2xf16> + %expanded_8 = tensor.expand_shape %53 [[0, 1], [2], [3]] output_shape [1, 2, 1, 2] : tensor<2x1x2xf16> into tensor<1x2x1x2xf16> + %54 = tosa.transpose %expanded_8 {perms = array} : (tensor<1x2x1x2xf16>) -> tensor<1x1x2x2xf16> + %collapsed_9 = tensor.collapse_shape %54 [[0, 1, 2, 3]] : tensor<1x1x2x2xf16> into tensor<4xf16> + return %collapsed_9 : tensor<4xf16> +} + diff --git a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir index deaf0e606632..b5c7adf7bee4 100644 --- a/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir +++ b/mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir @@ -907,3 +907,136 @@ func.func @gridwise_attn_wavespereu_outputswizzle(%arg0: memref<1474560xf16>, %a } {blockSize = 128 : i32, enableSoftmax = false, firstGemmIndices = array, splitKV = 1 : i32, gridSize = 
512 : i32, operandSegmentSizes = array, params0 = #rock.accel_gemm_params, params1 = #rock.accel_gemm_params, storeMethod = #rock} : memref<4x512x4096xf16>, memref<4x512x1024xf16>, memref<4x1024x384xf16>, memref<4x4096x384xf16> return } + +// ----- + +// CHECK-LABEL: @mlir_attention +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c32:.+]] = arith.constant 32 : index +// CHECK-DAG: %[[negInf:.+]] = arith.constant 0xFF800000 : f32 +// Load current sequence length +// CHECK: %[[registers:.+]] = rock.alloc() : memref<1xi32, #gpu.address_space> +// CHECK-NEXT: rock.threadwise_read_into {forceUnroll, useIndexDiffs} [](%{{.+}}) [%{{.+}}] -> %[[registers]] : memref<2x1xi32> -> memref<1xi32, #gpu.address_space>, vector<1xi1> +// CHECK-NEXT: %[[currSeqLen:.+]] = rock.in_bounds_load %[[registers]][%[[c0]]] : memref<1xi32, #gpu.address_space>, index -> i32 +// CHECK-NEXT: %[[currSeqLenIndex:.+]] = arith.index_cast %[[currSeqLen]] : i32 to index + +// Sliding window lower bound: max(seqLen - windowSize, 0) +// CHECK-NEXT: %[[seqLenMinusWindow:.+]] = arith.subi %[[currSeqLenIndex]], %[[c3]] : index +// CHECK-NEXT: %[[slidingWindowLB:.+]] = arith.maxsi %[[seqLenMinusWindow]], %[[c0]] : index + +// Dynamic loop bound: ceil(seqLen / tileSize) +// CHECK-NEXT: %[[num:.+]] = arith.addi %[[currSeqLenIndex]], %[[c32]] : index +// CHECK-NEXT: %[[numIter:.+]] = arith.divui %[[num]], %[[c32]] : index + +// Sliding window start iteration: floor(slidingWindowLB / tileSize) +// CHECK-NEXT: %[[swStartIter:.+]] = arith.divui %[[slidingWindowLB]], %[[c32]] : index +// CHECK-NEXT: %[[swStartIterClamped:.+]] = arith.maxsi %[[swStartIter]], %[[c0]] : index + +// CHECK-NEXT: %[[lastIter:.+]] = arith.subi %[[numIter]], %[[c1]] : index + +// Outer N-tile loop +// CHECK: scf.for %[[iterIndex:.+]] = %{{.+}} to %[[numIter]] step %[[c1]] { +// KV-cache mask: on last iteration, 
mask positions > seqLen +// CHECK: %[[isLastIter:.+]] = arith.cmpi eq, %[[iterIndex]], %[[lastIter]] : index +// CHECK-NEXT: scf.if %[[isLastIter]] { +// CHECK: rock.transforming_for {forceUnroll, useIndexDiffs} (%[[dim0:.+]], %[[dim1:.+]], %[[dim2:.+]]) = [{{.*}}]({{.*}}), ({{.*}}) = [] + +// Sliding window mask: mask positions below the window +// CHECK: rock.transforming_for {forceUnroll, useIndexDiffs} (%[[swDim0:.+]], %[[swDim1:.+]], %[[swDim2:.+]]) = [{{.*}}]({{.*}}), ({{.*}}) = [] +// CHECK-NEXT: %[[slidingWindowCmp:.+]] = arith.cmpi ult, %[[swDim2]], %[[slidingWindowLB]] : index +// CHECK-NEXT: scf.if %[[slidingWindowCmp]] { +// CHECK-NEXT: rock.in_bounds_store %[[negInf]] + +#accel_gemm_params = #rock.accel_gemm_params +#map = affine_map<(d0, d1, d2, d3) -> ((d1 * 8 + d2) * 2 + d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d1 * 2 + d3)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)> +#map4 = affine_map<(d0, d1, d2) -> (0, d0, d1, d2)> +#map5 = affine_map<(d0, d1, d2) -> ((d0 * 8 + d1) * 2 + d2)> +#map6 = affine_map<(d0, d1, d2, d3) -> (((d0 + d1) * 2 + d2) * 2 + d3)> +#map7 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)> +#map8 = affine_map<(d0, d1) -> (d1)> +#map9 = affine_map<(d0, d1) -> (d0, 0)> +#map10 = affine_map<(d0) -> (0, d0)> +#map11 = affine_map<(d0) -> (d0)> +#map12 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> +#map13 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map14 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> +#map15 = affine_map<(d0, d1) -> (0, d0, 0, d1)> +#map16 = affine_map<(d0, d1) -> (d0, d1)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 2, 8, 2] -> [32]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>, [] at []>] bounds = [1, 6, 1, 2] -> [12]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim0", "dim1", "dim2", "dim3"] at [0, 1, 2, 3]>] bounds = [1, 2, 1, 2] -> [1, 6, 1, 2]> 
+#transform_map3 = #rock.transform_map<#map3 by [ ["dim0", "dim1", "dim3", "dim2"] at [0, 1, 3, 2]>] bounds = [1, 2, 2, 8] -> [1, 2, 8, 2]> +#transform_map4 = #rock.transform_map<#map4 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>, ["dim2"] at [3]>] bounds = [2, 1, 2] -> [1, 2, 1, 2]> +#transform_map5 = #rock.transform_map<#map4 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>, ["dim2"] at [3]>] bounds = [2, 2, 8] -> [1, 2, 2, 8]> +#transform_map6 = #rock.transform_map<#map5 by [ ["dim0"] at [0]>] bounds = [2, 8, 2] -> [32]> +#transform_map7 = #rock.transform_map<#map6 by [ ["dim0"] at [0]>] bounds = [1, 1, 2, 2] -> [4]> +#transform_map8 = #rock.transform_map<#map7 by [ ["dim0", "dim2", "dim1", "dim3"] at [0, 1, 2, 3]>] bounds = [1, 2, 1, 2] -> [1, 1, 2, 2]> +#transform_map9 = #rock.transform_map<#map4 by [ ["exp1"] at [1]>, ["dim1"] at [2]>, ["dim2"] at [3]>, ["unit0"] at [0]>] bounds = [2, 1, 2] -> [1, 2, 1, 2]> +#transform_map10 = #rock.transform_map<#map8 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 1] -> [1]> +#transform_map11 = #rock.transform_map<#map9 by [ ["dim0"] at [0]>, ["dim1"] at [1]>] bounds = [1, 2] -> [1, 1]> +#transform_map12 = #rock.transform_map<#map10 by [ ["col0", "col1"] at [0, 1]>] bounds = [2] -> [1, 2]> +#transform_map13 = #rock.transform_map<#map12 by [ ["dim0", "dim2", "dim1"] at [0, 2, 1]>] bounds = [2, 8, 2] -> [2, 2, 8]> +#transform_map14 = #rock.transform_map<#map12 by [ ["gemmG"] at [0]>, ["gemm0K", "gemm0M"] at [2, 1]>] bounds = [2, 2, 1] -> [2, 1, 2]> +#transform_map15 = #rock.transform_map<#map12 by [ ["gemmG"] at [0]>, ["gemm0K", "gemm0N"] at [2, 1]>] bounds = [2, 2, 8] -> [2, 8, 2]> +#transform_map16 = #rock.transform_map<#map13 by [ ["gemmG"] at [0]>, ["gemm0K"] at [1]>, ["gemm0N"] at [2]>] bounds = [2, 32, 32] -> [2, 2, 1]> +#transform_map17 = #rock.transform_map<#map13 by [ ["gemmG"] at [0]>, ["gemm0K"] at [1]>, ["gemm0M"] at [2]>] bounds = [2, 32, 32] -> [2, 2, 8]> +#transform_map18 = #rock.transform_map<#map13 by 
[ ["gemmG"] at [0]>, ["gemm1K"] at [1]>, ["gemm1M"] at [2]>] bounds = [2, 32, 32] -> [2, 8, 2]> +#transform_map19 = #rock.transform_map<#map13 by [ ["gemmG"] at [0]>, ["gemm1N"] at [1]>, ["gemm1M"] at [2]>] bounds = [2, 32, 32] -> [2, 1, 2]> +#transform_map20 = #rock.transform_map<#map14 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>, [] at []>] bounds = [1, 2, 1, 8] -> [2, 1, 8]> +#transform_map21 = #rock.transform_map<#map15 by [ ["col0", "col1", "col2"] at [0, 1, 2]>, ["dim1"] at [3]>] bounds = [2, 8] -> [1, 2, 1, 8]> +#transform_map22 = #rock.transform_map<#map15 by [ ["exp1"] at [1]>, ["dim1"] at [3]>, ["unit0"] at [0]>, ["unit2"] at [2]>] bounds = [2, 8] -> [1, 2, 1, 8]> +module { + func.func @mlir_attention(%arg0: memref<1xi32>, %arg1: memref<12xf16>, %arg2: memref<32xf16>, %arg3: memref<32xf16>, %arg4: memref<4xf16>) attributes {arch = "gfx950", block_size = 64 : i32, features = #rock, grid_size = 2 : i32, kernel = "mixr"} { + %cst = arith.constant 5.000000e-01 : f16 + %c4_i32 = arith.constant 4 : i32 + %0 = rock.transform %arg2 by #transform_map : memref<32xf16> to memref<1x2x8x2xf16> + %1 = rock.transform %arg1 by #transform_map1 : memref<12xf16> to memref<1x6x1x2xf16> + %2 = rock.transform %1 by #transform_map2 : memref<1x6x1x2xf16> to memref<1x2x1x2xf16> + %3 = rock.transform %0 by #transform_map3 : memref<1x2x8x2xf16> to memref<1x2x2x8xf16> + %4 = rock.transform %2 by #transform_map4 : memref<1x2x1x2xf16> to memref<2x1x2xf16> + %5 = rock.transform %3 by #transform_map5 : memref<1x2x2x8xf16> to memref<2x2x8xf16> + %6 = rock.transform %arg3 by #transform_map6 : memref<32xf16> to memref<2x8x2xf16> + %alloc = memref.alloc() : memref<4xf16> + %7 = rock.transform %alloc by #transform_map7 : memref<4xf16> to memref<1x1x2x2xf16> + %8 = rock.transform %7 by #transform_map8 : memref<1x1x2x2xf16> to memref<1x2x1x2xf16> + %9 = rock.transform %8 by #transform_map9 : memref<1x2x1x2xf16> to memref<2x1x2xf16> + %10 = rock.transform %arg0 by #transform_map10 
: memref<1xi32> to memref<1x1xi32> + %11 = rock.transform %10 by #transform_map11 : memref<1x1xi32> to memref<1x2xi32> + %12 = rock.transform %11 by #transform_map12 : memref<1x2xi32> to memref<2xi32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2xi32> + linalg.generic {indexing_maps = [#map11, #map11], iterator_types = ["parallel"]} ins(%12 : memref<2xi32>) outs(%alloc_0 : memref<2xi32>) attrs = {rock.majorTensorNumber = 0 : index} { + ^bb0(%in: i32, %out: i32): + %20 = arith.maxsi %in, %c4_i32 : i32 + %21 = arith.minsi %20, %c4_i32 : i32 + linalg.yield %21 : i32 + } + %13 = rock.transform %5 by #transform_map13 : memref<2x2x8xf16> to memref<2x8x2xf16> + %14 = rock.transform %4 by #transform_map14 : memref<2x1x2xf16> to memref<2x2x1xf16> + %15 = rock.transform %13 by #transform_map15 : memref<2x8x2xf16> to memref<2x2x8xf16> + %16 = rock.transform %14 by #transform_map16 : memref<2x2x1xf16> to memref<2x32x32xf16> + %17 = rock.transform %15 by #transform_map17 : memref<2x2x8xf16> to memref<2x32x32xf16> + %18 = rock.transform %6 by #transform_map18 : memref<2x8x2xf16> to memref<2x32x32xf16> + %19 = rock.transform %9 by #transform_map19 : memref<2x1x2xf16> to memref<2x32x32xf16> + rock.gridwise_attention_accel(%16, %17, %18, %alloc_0, %19) preSoftmaxOps = { + ^bb0(%arg5: memref<2x1x8xf16>, %arg6: memref<1x2x1x8xf16>): + %20 = rock.transform %arg5 by #transform_map20 : memref<2x1x8xf16> to memref<1x2x1x8xf16> + %21 = rock.transform %20 by #transform_map21 : memref<1x2x1x8xf16> to memref<2x8xf16> + %alloc_1 = memref.alloc() : memref<1x2x1x8xf16> + %22 = rock.transform %alloc_1 by #transform_map22 : memref<1x2x1x8xf16> to memref<2x8xf16> + linalg.generic {indexing_maps = [#map16, #map16], iterator_types = ["parallel", "parallel"]} ins(%21 : memref<2x8xf16>) outs(%22 : memref<2x8xf16>) attrs = {rock.majorTensorNumber = 0 : index} { + ^bb0(%in: f16, %out: f16): + %23 = arith.mulf %in, %cst : f16 + linalg.yield %23 : f16 + } + memref.copy %alloc_1, %arg6 : 
memref<1x2x1x8xf16> to memref<1x2x1x8xf16> + rock.yield + } {blockSize = 64 : i32, firstGemmIndices = array, gridSize = 2 : i32, operandSegmentSizes = array, params0 = #accel_gemm_params, params1 = #accel_gemm_params, prePadG0M = 8 : index, prePadG0N = 1 : index, slidingWindowSize = 3 : i32, softmaxType = f32, splitKV = 1 : i32, storeMethod = #rock} : memref<2x32x32xf16>, memref<2x32x32xf16>, memref<2x32x32xf16>, memref<2xi32>, memref<2x32x32xf16> + memref.copy %alloc, %arg4 : memref<4xf16> to memref<4xf16> + return + } +} diff --git a/mlir/test/e2e/AttentionDirectToLDS.toml b/mlir/test/e2e/AttentionDirectToLDS.toml index 3bf93205011d..b9d2f0ec09b2 100644 --- a/mlir/test/e2e/AttentionDirectToLDS.toml +++ b/mlir/test/e2e/AttentionDirectToLDS.toml @@ -109,6 +109,10 @@ config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_hea [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + sliding window +[[suite.test]] +config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" + # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml b/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml index 527666fff756..2b995ae7f171 100644 --- a/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml +++ b/mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml 
@@ -78,6 +78,10 @@ config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_hea [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + sliding window +[[suite.test]] +config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" + # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/AttentionSchedule.toml b/mlir/test/e2e/AttentionSchedule.toml index 5b685218e659..1134812a58ae 100644 --- a/mlir/test/e2e/AttentionSchedule.toml +++ b/mlir/test/e2e/AttentionSchedule.toml @@ -113,6 +113,10 @@ config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_hea [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias -schedule_version 2" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + sliding window +[[suite.test]] +config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias -schedule_version 2" + # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 
-return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias -schedule_version 2" diff --git a/mlir/test/e2e/PrAttentionBF16.toml b/mlir/test/e2e/PrAttentionBF16.toml index 6de1b8f50afc..ffeb0be47b96 100644 --- a/mlir/test/e2e/PrAttentionBF16.toml +++ b/mlir/test/e2e/PrAttentionBF16.toml @@ -116,6 +116,10 @@ config = "-rand 1 -return_lse -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_ [[suite.test]] config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv + sliding window +[[suite.test]] +config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" + # GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/PrAttentionDirectToLDS.toml b/mlir/test/e2e/PrAttentionDirectToLDS.toml index 58b8861f4c76..2717c39024a4 100644 --- a/mlir/test/e2e/PrAttentionDirectToLDS.toml +++ b/mlir/test/e2e/PrAttentionDirectToLDS.toml @@ -26,3 +26,7 @@ config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_hea # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 --prefix_offset=16,14,12 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal 
--with-attn-scale --with-attn-bias" + +# GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) + sliding window +[[suite.test]] +config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 --prefix_offset=16,14,12 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/PrAttentionF16.toml b/mlir/test/e2e/PrAttentionF16.toml index 1ceae277283a..67569039708b 100644 --- a/mlir/test/e2e/PrAttentionF16.toml +++ b/mlir/test/e2e/PrAttentionF16.toml @@ -116,6 +116,10 @@ config = "-rand 1 -return_lse -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_ [[suite.test]] config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv + sliding window +[[suite.test]] +config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" + # GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/PrAttentionF32.toml b/mlir/test/e2e/PrAttentionF32.toml index b9e051119733..cc7de3c4bfa3 100644 --- a/mlir/test/e2e/PrAttentionF32.toml +++ b/mlir/test/e2e/PrAttentionF32.toml @@ -88,6 +88,10 @@ config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_hea [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 
-head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + sliding window (currently disabled — config commented out; confirm why F32 is excluded) +#[[suite.test]] +#config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" + # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/PrAttentionI8.toml b/mlir/test/e2e/PrAttentionI8.toml index eb1b7c251544..6bb6895f3361 100644 --- a/mlir/test/e2e/PrAttentionI8.toml +++ b/mlir/test/e2e/PrAttentionI8.toml @@ -93,6 +93,10 @@ config = "-rand 1 -return_lse -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_ [[suite.test]] config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv + sliding window +[[suite.test]] +config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" + # GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias" diff --git a/mlir/test/e2e/PrAttentionSchedule.toml b/mlir/test/e2e/PrAttentionSchedule.toml index 2465abecaa4e..54caf88dbe4c 100644
--- a/mlir/test/e2e/PrAttentionSchedule.toml +++ b/mlir/test/e2e/PrAttentionSchedule.toml @@ -25,6 +25,10 @@ config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --schedul [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --schedule_version 2" +# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + sliding window +[[suite.test]] +config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -sliding_window_size=64 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --schedule_version 2" + # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) [[suite.test]] config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --schedule_version 2" diff --git a/mlir/test/fusion/pr-e2e/attention/mixr-attention-sliding-window-kvcache.mlir b/mlir/test/fusion/pr-e2e/attention/mixr-attention-sliding-window-kvcache.mlir new file mode 100644 index 000000000000..e7fe25eb8dbb --- /dev/null +++ b/mlir/test/fusion/pr-e2e/attention/mixr-attention-sliding-window-kvcache.mlir @@ -0,0 +1,57 @@ +// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_attention_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner 
--shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// CHECK: [1 1 1] + +module { + func.func @mlir_attention(%arg0: !migraphx.shaped<1x1xsi32, 1x1>, + %arg1: !migraphx.shaped<1x6x1x2xf16, 12x2x2x1>, + %arg2: !migraphx.shaped<1x2x8x2xf16, 32x16x2x1>, + %arg3: !migraphx.shaped<1x2x8x2xf16, 32x16x2x1>) -> !migraphx.shaped<1x1x4xf16, 4x4x1> attributes {kernel = "mixr"} { + %0 = migraphx.literal(dense<-3> : tensor<1xsi32>) : <1xsi32, 1> + %1 = migraphx.literal(dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xsi32>) : <8xsi32, 1> + %2 = migraphx.literal(dense<4> : tensor<1x1xsi32>) : <1x1xsi32, 1x1> + %3 = migraphx.literal(dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xsi32>) : <8xsi32, 1> + %4 = migraphx.literal(dense<0xFC00> : tensor<1xf16>) : <1xf16, 1> + %5 = migraphx.literal(dense<5.000000e-01> : tensor<1xf16>) : <1xf16, 1> + %6 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [1, 1, 1, 8]} : <8xsi32, 1> -> <1x1x1x8xsi32, 0x0x0x1> + %7 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 1, 1, 1]} : <1xsi32, 1> -> <1x1x1x1xsi32, 0x0x0x1> + %8 = migraphx.reshape %arg0 {dims = [1, 1, 1, 1]} : <1x1xsi32, 1x1> -> <1x1x1x1xsi32, 1x1x1x1> + %9 = migraphx.reshape %2 {dims = [1, 1, 1, 1]} : <1x1xsi32, 1x1> -> <1x1x1x1xsi32, 1x1x1x1> + %10 = migraphx.reshape %2 {dims = [1, 1, 1, 1]} : <1x1xsi32, 1x1> -> <1x1x1x1xsi32, 1x1x1x1> + %11 = migraphx.clip %8, %9, %10 : <1x1x1x1xsi32, 1x1x1x1>, <1x1x1x1xsi32, 1x1x1x1>, <1x1x1x1xsi32, 1x1x1x1> -> <1x1x1x1xsi32, 1x1x1x1> + %12 = migraphx.slice %arg1 {axes = [1], ends = [2], starts = [0]} : <1x6x1x2xf16, 12x2x2x1> -> <1x2x1x2xf16, 12x2x2x1> + %13 = migraphx.transpose %arg2 {permutation = [0, 
1, 3, 2]} : <1x2x8x2xf16, 32x16x2x1> -> <1x2x2x8xf16, 32x16x1x2> + %14 = migraphx.dot %12, %13 : <1x2x1x2xf16, 12x2x2x1>, <1x2x2x8xf16, 32x16x1x2> -> <1x2x1x8xf16, 16x8x8x1> + %15 = migraphx.multibroadcast %4 {out_dyn_dims = [], out_lens = [1, 2, 1, 8]} : <1xf16, 1> -> <1x2x1x8xf16, 0x0x0x0> + %16 = migraphx.multibroadcast %5 {out_dyn_dims = [], out_lens = [1, 2, 1, 8]} : <1xf16, 1> -> <1x2x1x8xf16, 0x0x0x0> + %17 = migraphx.mul %14, %16 : <1x2x1x8xf16, 16x8x8x1>, <1x2x1x8xf16, 0x0x0x0> -> <1x2x1x8xf16, 16x8x8x1> + %18 = migraphx.add %11, %7 : <1x1x1x1xsi32, 1x1x1x1>, <1x1x1x1xsi32, 0x0x0x1> -> <1x1x1x1xsi32, 1x1x1x1> + %19 = migraphx.multibroadcast %18 {out_dyn_dims = [], out_lens = [8, 1, 1, 1]} : <1x1x1x1xsi32, 1x1x1x1> -> <8x1x1x1xsi32, 0x1x1x1> + %20 = migraphx.reshape %19 {dims = [8]} : <8x1x1x1xsi32, 0x1x1x1> -> <8xsi32, 0> + %21 = migraphx.greater %20, %3 : <8xsi32, 0>, <8xsi32, 1> -> <8xsi32, 1> + %22 = migraphx.convert %21 {target_type = 0 : i64} : <8xsi32, 1> to <8xsi8, 1> + %23 = migraphx.broadcast %22 {axis = 3 : i64, out_lens = [1, 2, 1, 8]} : <8xsi8, 1> -> <1x2x1x8xsi8, 0x0x0x1> + %24 = migraphx.where %23, %15, %17 : <1x2x1x8xsi8, 0x0x0x1>, <1x2x1x8xf16, 0x0x0x0>, <1x2x1x8xf16, 16x8x8x1> -> <1x2x1x8xf16, 16x8x8x1> + %25 = migraphx.multibroadcast %11 {out_dyn_dims = [], out_lens = [1, 1, 1, 8]} : <1x1x1x1xsi32, 1x1x1x1> -> <1x1x1x8xsi32, 1x1x1x0> + %26 = migraphx.greater %6, %25 : <1x1x1x8xsi32, 0x0x0x1>, <1x1x1x8xsi32, 1x1x1x0> -> <1x1x1x8xsi32, 0x0x0x1> + %27 = migraphx.convert %26 {target_type = 0 : i64} : <1x1x1x8xsi32, 0x0x0x1> to <1x1x1x8xsi8, 0x0x0x1> + %28 = migraphx.multibroadcast %27 {out_dyn_dims = [], out_lens = [1, 2, 1, 8]} : <1x1x1x8xsi8, 0x0x0x1> -> <1x2x1x8xsi8, 0x0x0x1> + %29 = migraphx.where %28, %15, %24 : <1x2x1x8xsi8, 0x0x0x1>, <1x2x1x8xf16, 0x0x0x0>, <1x2x1x8xf16, 16x8x8x1> -> <1x2x1x8xf16, 16x8x8x1> + %30 = migraphx.convert %29 {target_type = 2 : i64} : <1x2x1x8xf16, 16x8x8x1> to <1x2x1x8xf32, 16x8x8x1> + %31 = migraphx.reshape 
%30 {dims = [1, 2, 1, 8]} : <1x2x1x8xf32, 16x8x8x1> -> <1x2x1x8xf32, 16x8x8x1> + %32 = migraphx.reduce_max %31 {axes = [3]} : <1x2x1x8xf32, 16x8x8x1> -> <1x2x1x1xf32, 2x1x1x1> + %33 = migraphx.reshape %32 {dims = [1, 2, 1, 1]} : <1x2x1x1xf32, 2x1x1x1> -> <1x2x1x1xf32, 2x1x1x1> + %34 = migraphx.multibroadcast %33 {out_dyn_dims = [], out_lens = [1, 2, 1, 8]} : <1x2x1x1xf32, 2x1x1x1> -> <1x2x1x8xf32, 2x1x1x0> + %35 = migraphx.sub %30, %34 : <1x2x1x8xf32, 16x8x8x1>, <1x2x1x8xf32, 2x1x1x0> -> <1x2x1x8xf32, 16x8x8x1> + %36 = migraphx.exp %35 : <1x2x1x8xf32, 16x8x8x1> -> <1x2x1x8xf32, 16x8x8x1> + %37 = migraphx.reshape %36 {dims = [1, 2, 1, 8]} : <1x2x1x8xf32, 16x8x8x1> -> <1x2x1x8xf32, 16x8x8x1> + %38 = migraphx.reduce_sum %37 {axes = [3]} : <1x2x1x8xf32, 16x8x8x1> -> <1x2x1x1xf32, 2x1x1x1> + %39 = migraphx.reshape %38 {dims = [1, 2, 1, 1]} : <1x2x1x1xf32, 2x1x1x1> -> <1x2x1x1xf32, 2x1x1x1> + %40 = migraphx.multibroadcast %39 {out_dyn_dims = [], out_lens = [1, 2, 1, 8]} : <1x2x1x1xf32, 2x1x1x1> -> <1x2x1x8xf32, 2x1x1x0> + %41 = migraphx.div %36, %40 : <1x2x1x8xf32, 16x8x8x1>, <1x2x1x8xf32, 2x1x1x0> -> <1x2x1x8xf32, 16x8x8x1> + %42 = migraphx.convert %41 {target_type = 1 : i64} : <1x2x1x8xf32, 16x8x8x1> to <1x2x1x8xf16, 16x8x8x1> + %43 = migraphx.dot %42, %arg3 : <1x2x1x8xf16, 16x8x8x1>, <1x2x8x2xf16, 32x16x2x1> -> <1x2x1x2xf16, 4x2x2x1> + %44 = migraphx.transpose %43 {permutation = [0, 2, 1, 3]} : <1x2x1x2xf16, 4x2x2x1> -> <1x1x2x2xf16, 4x2x2x1> + %45 = migraphx.reshape %44 {dims = [1, 1, 4]} : <1x1x2x2xf16, 4x2x2x1> -> <1x1x4xf16, 4x4x1> + return %45 : !migraphx.shaped<1x1x4xf16, 4x4x1> + } +} diff --git a/mlir/test/rocmlir-gen/attention-sliding-window.mlir b/mlir/test/rocmlir-gen/attention-sliding-window.mlir new file mode 100644 index 000000000000..f9a5063e4b86 --- /dev/null +++ b/mlir/test/rocmlir-gen/attention-sliding-window.mlir @@ -0,0 +1,40 @@ +// RUN: rocmlir-gen --arch gfx90a:sramecc+:xnack- --operation attention -current_seq_len=33 -sliding_window_size=16 
-seq_len_q 1 -seq_len_k 64 -head_dim_qk 32 -head_dim_v 32 -t f32 -pv --apply-bufferization-pipeline=false | rocmlir-opt | FileCheck %s --enable-var-scope + +// CHECK: module attributes {mhal.arch = "[[$ARCH:.*]]"} + +// CHECK-LABEL: func.func @rock_attention +// CHECK-SAME: (%[[queriesRaw:.*0]]: memref<32xf32>, +// CHECK-SAME: %[[keysRaw:.*1]]: memref<2048xf32>, +// CHECK-SAME: %[[valuesRaw:.*2]]: memref<2048xf32>, +// CHECK-SAME: %[[currentSeqLenRaw:.*3]]: memref<1xi32>, +// CHECK-SAME: %[[outputRaw:.*4]]: memref<32xf32>) +// CHECK-SAME: attributes {kernel, mhal.arch = "[[$ARCH]]"} + +// CHECK: rock.attention +// CHECK-NEXT: qk = %{{.*}} * %{{.*}} +// CHECK-NEXT: currentSeqLen = (%{{.*}} : memref<1xi32>) +// CHECK-NEXT: slidingWindowSize = 16 +// CHECK: softmax(qk) * %{{.*}} +// CHECK: return + +// CHECK-LABEL: func.func @host_naive_attention +// Verify KV-cache masking is applied +// CHECK: tosa.matmul +// CHECK: tosa.greater +// CHECK: tosa.select + +// Verify sliding window masking is applied in the CPU verifier: +// The sliding window masking computes lowerBound = max(0, currentSeqLen - windowSize), +// then masks positions where col < lowerBound with -inf. 
+// CHECK: tosa.sub %{{.*}}, %{{.*}} : (tensor<1x1x1x64xi32>, tensor<1x1x1x64xi32>) -> tensor<1x1x1x64xi32>
+// CHECK: tosa.maximum %{{.*}}, %{{.*}} : (tensor<1x1x1x64xi32>, tensor<1x1x1x64xi32>) -> tensor<1x1x1x64xi32>
+// CHECK: tosa.greater %{{.*}}, %{{.*}} : (tensor<1x1x1x64xi32>, tensor<1x1x1x64xi32>) -> tensor<1x1x1x64xi1>
+// CHECK: tosa.select %{{.*}}, %{{.*}}, %{{.*}} : (tensor<1x1x1x64xi1>, tensor<1x1x1x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x1x1x64xf32>
+
+// Verify softmax follows
+// CHECK-DAG: tosa.reduce_max
+// CHECK-DAG: tosa.exp
+// CHECK-DAG: tosa.reduce_sum
+// CHECK-DAG: tosa.reciprocal
+// CHECK: tosa.matmul
+// CHECK: return
diff --git a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp
index 335d3f4c5ebb..0688daf7d71e 100644
--- a/mlir/tools/rocmlir-gen/rocmlir-gen.cpp
+++ b/mlir/tools/rocmlir-gen/rocmlir-gen.cpp
@@ -691,6 +691,13 @@ static llvm::cl::opt<int64_t> splitKV(
                    "the number of blocks in the sequenceLengthK dimension."),
     llvm::cl::value_desc("positive integer"), llvm::cl::init(1));
 
+static llvm::cl::opt<int64_t> slidingWindowSize(
+    "sliding_window_size",
+    llvm::cl::desc("Sliding window attention size. Only the last "
+                   "slidingWindowSize key positions (relative to "
+                   "currentSeqLen) are attended to. Requires current_seq_len."),
+    llvm::cl::value_desc("positive integer"), llvm::cl::init(0));
+
 static llvm::cl::opt<bool> returnLSE(
     "return_lse",
     llvm::cl::desc("whether the attention kernel returns LSE (log-sum-exp)"),
@@ -3079,6 +3086,66 @@ static Value maskKVCacheTosa(OpBuilder builder, Location loc, Value inputTensor,
   return resultReshaped;
 }
 
+// Sliding window masking: mask positions where col < max(0, currentSeqLen -
+// slidingWindowSize). The inputTensor shape is [b*num_heads_q, seq_len_q,
+// seq_len_kv]. currentSeqLenVal has shape [1x1x1x1xi32] (already reshaped).
+static Value slidingWindowMaskingTosa(OpBuilder builder, Location loc,
+                                      Value inputTensor, Value currentSeqLenVal,
+                                      int64_t windowSize, float initValue) {
+  auto origType = cast<ShapedType>(inputTensor.getType());
+  ArrayRef<int64_t> origShape = origType.getShape();
+  SmallVector<int64_t> newShape = {origShape[0] / numHeadsQ, numHeadsQ,
+                                   origShape[1], origShape[2]};
+  ImplicitLocOpBuilder implicitBuilder(loc, builder);
+  auto newShapeValue = tosa::getTosaConstShape(implicitBuilder, newShape);
+  inputTensor = rock::tosa::createOpAndInfer<tosa::ReshapeOp>(
+      builder, loc, origType.getElementType(), inputTensor, newShapeValue);
+
+  auto inpType = cast<ShapedType>(inputTensor.getType());
+  ArrayRef<int64_t> inpShape = inpType.getShape();
+
+  // Create column range [0, 1, ..., seq_len_kv-1]
+  Value colRange = createRange(builder, loc, 3, inpShape);
+
+  // Broadcast currentSeqLen to full shape
+  auto outType = RankedTensorType::get(inpShape, builder.getI32Type());
+  auto currentSeqLenBroadcast = rock::tosa::getMulOp(
+      builder, loc, currentSeqLenVal,
+      rock::tosa::getOneTensor(builder, loc, outType), builder.getI32Type());
+
+  // Compute lowerBound = max(0, currentSeqLen - slidingWindowSize)
+  DenseElementsAttr windowSizeAttr = DenseIntElementsAttr::get(
+      RankedTensorType::get(inpShape, builder.getI32Type()),
+      static_cast<int32_t>(windowSize));
+  Value windowSizeConst = tosa::ConstOp::create(
+      builder, loc, windowSizeAttr.getType(), windowSizeAttr);
+  Value lowerBound = rock::tosa::createOpAndInfer<tosa::SubOp>(
+      builder, loc, builder.getI32Type(), currentSeqLenBroadcast,
+      windowSizeConst);
+
+  // Clamp lower bound to >= 0
+  DenseElementsAttr zeroAttr = DenseIntElementsAttr::get(
+      RankedTensorType::get(inpShape, builder.getI32Type()),
+      static_cast<int32_t>(0));
+  Value zeroConst =
+      tosa::ConstOp::create(builder, loc, zeroAttr.getType(), zeroAttr);
+  lowerBound = rock::tosa::createOpAndInfer<tosa::MaximumOp>(
+      builder, loc, builder.getI32Type(), lowerBound, zeroConst);
+
+  // Create mask: col < lowerBound (i.e., lowerBound > col)
+  auto mask = rock::tosa::createOpAndInfer<tosa::GreaterOp>(
+      builder, loc, builder.getIntegerType(1), lowerBound, colRange);
+
+  Value result = applyMask(builder, loc, inputTensor, mask, initValue);
+
+  // Reshape result back to [batch_size*num_heads_q, seq_len_q, seq_len_kv]
+  auto origShapeValue = tosa::getTosaConstShape(implicitBuilder, origShape);
+  auto resultReshaped = rock::tosa::createOpAndInfer<tosa::ReshapeOp>(
+      builder, loc, inpType.getElementType(), result, origShapeValue);
+
+  return resultReshaped;
+}
+
 static Value broadcastBatchTosa(OpBuilder builder, Location loc,
                                 Value inputTensor, int64_t numRepeat) {
   if (numRepeat == 1)
@@ -3379,6 +3446,9 @@ static func::FuncOp createGpuAttentionKernel(ModuleOp module,
       currentSeqLenTensor, prefixOffsetTensor, output, lse, numHeadsQ,
       numHeadsKV, transposeQ, transposeK, transposeV, transposeO, actualCausal,
       splitKV,
+      /*slidingWindowSize=*/
+      slidingWindowSize > 0 ? builder.getI32IntegerAttr(slidingWindowSize)
+                            : nullptr,
       rock::GemmFeaturesAttr::get(builder.getContext(), params.features),
       storeMethod, softmaxType,
       /*params0=*/nullptr, /*params1=*/nullptr,
@@ -4321,6 +4391,14 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
     qkTensor = causalMaskingTosa(builder, loc, qkTensor,
                                  -std::numeric_limits<float>::infinity());
 
+    // Apply sliding window masking if slidingWindowSize is set.
+    // Masks positions where col < max(0, currentSeqLen - slidingWindowSize).
+    if (slidingWindowSize > 0 && currentSeqLenTensor) {
+      qkTensor = slidingWindowMaskingTosa(
+          builder, loc, qkTensor, currentSeqLenTensor, slidingWindowSize,
+          -std::numeric_limits<float>::infinity());
+    }
+
     constexpr int64_t reductionAxis = 2;
     auto qkMaxs = rock::tosa::createOpAndInfer<tosa::ReduceMaxOp>(
         builder, loc, softmaxType, qkTensor, reductionAxis);