From bd7ce285029fec4fb32eb7691df5de6aa507f867 Mon Sep 17 00:00:00 2001 From: stefankoncarevic Date: Wed, 18 Feb 2026 06:11:58 -0600 Subject: [PATCH 1/4] Relax kpack validation for MFMA with k_base >= 8 Allow kpack < k_base when k_base >= 8 and k_base % kpack == 0. This enables better utilization of double-rate MFMA instructions (e.g., gfx950 f16/bf16/int8, gfx942 int8/fp8) with kpack=4. Disable LDS transpose for prefetch when kpack < kBase as a necessary fix for the relaxed validation. --- mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp | 24 ++++- .../Transforms/GridwiseGemmToBlockwise.cpp | 7 -- .../lib/Dialect/Rock/utility/AccelEmitter.cpp | 94 +++++++++++++------ .../Dialect/Rock/utility/LdsTransposeLoad.cpp | 12 +++ mlir/test/e2e/LdsTransposeLoadAttention.toml | 25 ++--- .../test/e2e/PrLdsTransposeLoadAttention.toml | 8 +- 6 files changed, 108 insertions(+), 62 deletions(-) diff --git a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp index e8d2d770d605..062e1a716a9c 100644 --- a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp +++ b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp @@ -484,13 +484,29 @@ VectorType MfmaInsn::getRetType(Type elementType) { return VectorType::get({attr.nOutputsOfMfma}, vectorElem); } +// Check if the MFMA instruction is coherent with the K dimension configuration. +// When k_base >= 8, allows kpack < k_base if k_base % kpack == 0. bool MfmaInsn::isCoherentWithK(int64_t kpack, int64_t kPerBlock) { + int64_t totalKPerBlock = kpack * kPerBlock; + if (kpack > 1) { if (kpack < attr.k_base) { - LLVM_DEBUG(llvm::dbgs() - << "Should pack at least k_base elements and avoid waste " - "xdlopsgemm cycles\n"); - return false; + // Relaxed kpack check only for MFMA with k_base >= 8 + if (attr.k_base < 8) { + LLVM_DEBUG(llvm::dbgs() << "kpack (" << kpack << ") must be >= k_base (" + << attr.k_base << ")\n"); + return false; + } + if (attr.k_base % kpack != 0) { + LLVM_DEBUG(llvm::dbgs() + << "kpack must divide k_base when kpack < k_base\n"); + return false; + } + if (totalKPerBlock < attr.k) { + LLVM_DEBUG(llvm::dbgs() << "totalKPerBlock (" << totalKPerBlock + << ") must be >= MFMA K (" << attr.k << ")\n"); + return false; + } } if (attr.isKReduction && kPerBlock < attr.inputSpansPerMfmaIn) { LLVM_DEBUG( diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 8c91c4083bb2..201748cc206b 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -2476,13 +2476,6 @@ struct GridwiseAttentionAccelRewritePattern gemm0TuningParams.getNPerWave(), gemm0kpack, /*doubleBuffering=*/false, /*bLoadsFromLDS=*/qLoadsFromLDS); - // Disable LDS transpose for large head dimensions (HeadDimQK >= 512) - // Note: gemm0N = qShape[2] = head_dim_qk - if (gemm0N >= 512) { - ldsDecisionGemm0.enableA = false; - ldsDecisionGemm0.enableB = false; - } - // create matrix params BlockwiseMatrixParamsAttr matrixParamsK = BlockwiseMatrixParamsAttr::get( rewriter.getContext(), elemTypeK, elemTypeKLoad, diff --git a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp index 711d4a7b4edf..62790611eae1 100644 --- a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp +++ b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp @@ -546,9 +546,8 @@ Value MfmaEmitter::wrapLDSBufferForLoad( TopDownTMBuilder toLDSRowCol(b, {}, {}, loc); // Use LDS transpose compatible K formula when this operand uses LDS - // transpose load (and kVec >= kBase to ensure proper K distribution) - if (useLdsTransposeLoad && kVec >= kBase) { - + // transpose load. Handles both kVec >= kBase and kVec < kBase cases. + if (useLdsTransposeLoad) { // K access pattern must match the transpose load's pattern. // For double-rate MFMA, properly distribute K across threads int64_t instrK = mfmaAttr.k; @@ -568,32 +567,63 @@ Value MfmaEmitter::wrapLDSBufferForLoad( TransformMapAttr splitBlkIdAttr = splitBlkId.get(); transformAttrs.push_back(splitBlkIdAttr); - // Split k_vec into k_mfma and k_base for kpack > kBase - int64_t numMfmaPerKVec = kVec / kBase; - - TopDownTMBuilder splitKVec = - TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr); - splitKVec.passThrough({"wave_m", "wave_n"}, {0, 1}, {"wave_m", "wave_n"}); - splitKVec.passThrough({"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"}, - {2, 3, 4, 5, 6}, - {"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"}); - splitKVec.merge({"k_mfma", "k_base"}, {7, 8}, "k_vec", - {numMfmaPerKVec, kBase}); - TransformMapAttr splitKVecAttr = splitKVec.get(); - transformAttrs.push_back(splitKVecAttr); - - toLDSRowCol = TopDownTMBuilder::below(splitKVec, splitKVecAttr); - - // d = d_iter * dWaves * numBlksInD * inputSpanLen + wave_d * numBlksInD * - // inputSpanLen + blk_d * inputSpanLen + blk_td - toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_d", "blk_td"}, - {dRepeats, dWaves, numBlksInD, inputSpanLen}); - - // k = k_iter * (numMfmaPerKVec * instrK) + k_mfma * instrK + blk_k * - // kBase + k_base - toLDSRowCol.unmerge("k", 1, {"k_iter", "k_mfma", "blk_k", "k_base"}, - {kIter, numMfmaPerKVec, numBlksInK, kBase}); - + if (kVec >= kBase) { + // Case 1: kVec >= kBase - split k_vec into k_mfma and k_base + int64_t numMfmaPerKVec = kVec / kBase; + + TopDownTMBuilder splitKVec = + TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr); + splitKVec.passThrough({"wave_m", "wave_n"}, {0, 1}, + {"wave_m", "wave_n"}); + splitKVec.passThrough({"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"}, + {2, 3, 4, 5, 6}, + {"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"}); + splitKVec.merge({"k_mfma", "k_base"}, {7, 8}, "k_vec", + {numMfmaPerKVec, kBase}); + TransformMapAttr splitKVecAttr = splitKVec.get(); + transformAttrs.push_back(splitKVecAttr); + + toLDSRowCol = TopDownTMBuilder::below(splitKVec, splitKVecAttr); + + // d = d_iter * dWaves * numBlksInD * inputSpanLen + wave_d * numBlksInD + // * inputSpanLen + blk_d * inputSpanLen + blk_td + toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_d", "blk_td"}, + {dRepeats, dWaves, numBlksInD, inputSpanLen}); + + // k = k_iter * (numMfmaPerKVec * instrK) + k_mfma * instrK + blk_k * + // kBase + k_base + toLDSRowCol.unmerge("k", 1, {"k_iter", "k_mfma", "blk_k", "k_base"}, + {kIter, numMfmaPerKVec, numBlksInK, kBase}); + } else { + // Case 2: kVec < kBase - split k_iter to accumulate multiple kVec + // loads into one kBase worth of data (e.g., kVec=4, kBase=8) + int64_t numKVecPerMfma = kBase / kVec; + int64_t kOuter = kIter / numKVecPerMfma; + + TopDownTMBuilder splitKIter = + TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr); + splitKIter.passThrough({"wave_m", "wave_n"}, {0, 1}, + {"wave_m", "wave_n"}); + splitKIter.passThrough({"blk_d", "blk_k", "blk_td", "d_iter"}, + {2, 3, 4, 5}, + {"blk_d", "blk_k", "blk_td", "d_iter"}); + splitKIter.merge({"k_outer", "k_inner"}, {6, 7}, "k_iter", + {kOuter, numKVecPerMfma}); + splitKIter.passThrough({"k_vec"}, {8}, {"k_vec"}); + TransformMapAttr splitKIterAttr = splitKIter.get(); + transformAttrs.push_back(splitKIterAttr); + + toLDSRowCol = TopDownTMBuilder::below(splitKIter, splitKIterAttr); + + // d formula same as kVec >= kBase case + toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_d", "blk_td"}, + {dRepeats, dWaves, numBlksInD, inputSpanLen}); + + // k = k_outer * instrK + blk_k * kBase + k_inner * kVec + k_vec + // This accumulates numKVecPerMfma loads of kVec elements into kBase + toLDSRowCol.unmerge("k", 1, {"k_outer", "blk_k", "k_inner", "k_vec"}, + {kOuter, numBlksInK, numKVecPerMfma, kVec}); + } } else { // Standard formula for regular load scenarios toLDSRowCol = TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr); @@ -760,9 +790,11 @@ MfmaEmitter::createAccelGemmOperandTransforms( TransformMapAttr splitWaveIdAttr = splitWaveId.get(); transformAttrs.push_back(splitWaveIdAttr); // Fourth coordinate transform - // Check if we need LDS transpose compatible K formula + // Check if we need LDS transpose compatible K formula. + // When prefetch is used: kPack >= kBase allows LDS transpose load, + // kPack < kBase disables it (falls back to regular load). bool useLdsTransposeCompatibleK = - otherOperandUsesLdsTranspose && isKReduction && (kPack >= kBase); + otherOperandUsesLdsTranspose && isKReduction; int64_t numBlksInK = instrK / kBase; int64_t numBlksInD = (waveSize / inputSpanLen) / numBlksInK; diff --git a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp index d4f021dc2715..f65e1487c847 100644 --- a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp +++ b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp @@ -250,6 +250,18 @@ LDSTransposeDecision decideLDSTransposeForOperands( << (decB.usable ? "USABLE" : "NOT USABLE") << "\n"); } + // If B is prefetched (doesn't load from LDS) and kpack < kBase, + // disable A's LDS transpose. The prefetched B operand needs a compatible + // K formula that requires kpack >= kBase (floordiv/mod can't be expressed + // with linear transforms). + int64_t kBase = mfmaEmitter->getParams().kBase; + if (!bLoadsFromLDS && kpack < kBase) { + decA.usable = false; + LLVM_DEBUG(llvm::dbgs() + << "[lds_transpose] Disabling A: prefetch with " + << "kPack(" << kpack << ") < kBase(" << kBase << ")\n"); + } + // Enable LDS transpose load for each operand that supports it. // The K access pattern formula in AccelEmitter.cpp (useLdsTransposeLoad) // ensures compatibility when mixing regular load with transpose load. diff --git a/mlir/test/e2e/LdsTransposeLoadAttention.toml b/mlir/test/e2e/LdsTransposeLoadAttention.toml index b2bd4f102575..e53c9d7439bb 100644 --- a/mlir/test/e2e/LdsTransposeLoadAttention.toml +++ b/mlir/test/e2e/LdsTransposeLoadAttention.toml @@ -48,10 +48,10 @@ config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 64 -head_dim_v 128 --transK config = "-seq_len_q 256 -seq_len_k 256 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=true -perf_config attn:v2:64,64,16,16,16,16,8,1,4,2,1" [[suite.test]] -config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=true -perf_config attn:v2:16,16,16,16,16,16,4,1,3,2,1" +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=true -perf_config attn:v2:32,32,16,16,16,16,8,1,3,2,1" [[suite.test]] -config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=true -perf_config attn:v2:16,16,32,16,16,16,4,1,4,2,1" +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=true -perf_config attn:v2:32,32,16,32,32,16,8,1,3,2,1" # ============================================================================ # Suite 3: transK=false, transQ=false - Only K uses LDS transpose @@ -83,13 +83,13 @@ config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --trans config = "-seq_len_q 256 -seq_len_k 256 -head_dim_qk 256 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:64,64,32,32,32,32,8,1,4,2,1" [[suite.test]] -config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v2:32,32,16,16,16,16,4,1,3,2,1" +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v2:32,32,16,16,16,16,8,1,3,2,1" [[suite.test]] -config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:16,16,32,16,16,16,4,1,4,2,1" +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:32,32,16,32,32,16,8,1,3,2,1" [[suite.test]] -config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:32,32,32,32,32,32,4,1,3,2,1" +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:32,32,32,32,32,32,8,1,3,2,1" # ============================================================================ # Suite 5: transK=true, transQ=true - Only Q can use LDS transpose @@ -108,19 +108,12 @@ config = "-seq_len_q 256 -seq_len_k 256 -head_dim_qk 256 -head_dim_v 256 --trans config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 256 -head_dim_v 64 --transK=true --transQ=true -perf_config attn:v2:32,32,16,16,16,16,8,1,3,2,1" # ============================================================================ -# Suite 6: GEMM1 LDS Transpose - V uses LDS transpose, P is prefetched +# Suite 6: GEMM1 fallback - kpack=4 < kBase=8, LDS transpose disabled for V +# Tests that the fallback to regular path works correctly when P is prefetched +# but kpack < kBase prevents LDS transpose for V operand. # ============================================================================ [[suite]] -name = "lds_transpose_gemm1_v_with_p_prefetch" +name = "lds_transpose_gemm1_fallback" -# 32x8 MFMA, medium dimensions [[suite.test]] config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v3:32,32,32,32,32,32,32,4,1,3,2,0,1" - -# 32x8 MFMA, larger dimensions -[[suite.test]] -config = "-seq_len_q 256 -seq_len_k 256 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v3:32,32,32,32,32,32,32,4,1,3,2,0,1" - -# 16x16 MFMA, smaller dimensions -[[suite.test]] -config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v3:16,16,16,16,16,16,16,4,1,3,2,0,1" diff --git a/mlir/test/e2e/PrLdsTransposeLoadAttention.toml b/mlir/test/e2e/PrLdsTransposeLoadAttention.toml index eb9bf30e4e17..e2212715a5a2 100644 --- a/mlir/test/e2e/PrLdsTransposeLoadAttention.toml +++ b/mlir/test/e2e/PrLdsTransposeLoadAttention.toml @@ -18,7 +18,7 @@ prefix = "-t " name = "lds_transpose_both" [[suite.test]] -config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 128 -head_dim_v 64 --transK=false --transQ=true -perf_config attn:v2:32,32,32,32,32,32,4,1,3,2,1" +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 256 -head_dim_v 64 --transK=false --transQ=true -perf_config attn:v2:32,32,32,32,32,32,4,1,3,2,1" [[suite.test]] config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 128 -head_dim_v 64 --transK=false --transQ=true -perf_config attn:v2:16,16,16,16,16,16,4,1,3,2,1" @@ -48,7 +48,7 @@ config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 128 -head_dim_v 64 --transK=f config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 128 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v2:16,16,16,16,16,16,4,1,3,2,1" [[suite.test]] -config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v2:32,32,32,32,32,32,4,1,3,2,1" +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 256 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v2:32,32,32,32,32,32,4,1,3,2,1" # ============================================================================ # Suite 4: Only Q uses LDS transpose (transK=true, transQ=true) @@ -69,13 +69,13 @@ config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 128 -head_dim_v 64 --transK=t name = "lds_transpose_gemm1_v" [[suite.test]] -config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v3:32,32,32,32,32,32,32,4,1,3,2,0,1" +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v3:32,32,32,32,32,32,32,8,1,3,2,0,1" [[suite.test]] config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v3:16,16,16,16,16,16,16,4,1,3,2,0,1" [[suite.test]] -config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v3:32,32,32,32,32,32,32,4,1,3,2,0,1" +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 256 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v3:32,32,32,32,32,32,32,4,1,3,2,0,1" # ============================================================================ # Suite 6: Mixed scenarios with different MFMA sizes From f440b8f3123ee004ee1155b1ac892d2dc34b087a Mon Sep 17 00:00:00 2001 From: stefankoncarevic Date: Fri, 20 Feb 2026 03:13:57 -0600 Subject: [PATCH 2/4] Pass scheduleVersion to MfmaInsnGroup::select and isCoherentWithK Relaxed kpack validation (kpack < k_base) now only applies to double-buffer pipelines (scheduleVersion 2 or 4). --- .../mlir/Dialect/Rock/IR/MfmaInsnGroup.h | 9 ++++-- mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp | 29 ++++++++++++------- .../Rock/Tuning/GridwiseGemmParams.cpp | 13 +++++---- .../lib/Dialect/Rock/utility/AccelEmitter.cpp | 3 +- 4 files changed, 34 insertions(+), 20 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h b/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h index 52c7237dfd6c..b37907d05c8d 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h +++ b/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h @@ -71,7 +71,8 @@ class MfmaInsn { MfmaInsnAttr getAttr() const; Type getArgTypeFor(Type elementTypeA); VectorType getRetType(Type elementType); - bool isCoherentWithK(int64_t kPack, int64_t kPerBlock); + bool isCoherentWithK(int64_t kPack, int64_t kPerBlock, + int64_t scheduleVersion); }; template @@ -138,7 +139,8 @@ class MfmaInsnGroup { public: static FailureOr select(Type elementTypeA, Type elementTypeB, StringRef arch, int64_t mnPerXdl, - int64_t kPack, int64_t kPackPerBlock); + int64_t kPack, int64_t kPackPerBlock, + int64_t scheduleVersion); MfmaInsnGroup(Type elementTypeA, Type elementTypeB, const MfmaInsn &insn, const MfmaInsnGroupAttr &groupAttr); int64_t getMRepeats(int64_t mPerWave); @@ -150,7 +152,8 @@ class MfmaInsnGroup { Type getArgTypeA(); Type getArgTypeB(); VectorType getRetType(); - bool isCoherentWithK(int64_t kPack, int64_t kPerBlock); + bool isCoherentWithK(int64_t kPack, int64_t kPerBlock, + int64_t scheduleVersion); SmallString<16> getROCDLIntrinsicName() { return groupAttr.insn; } }; diff --git a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp index 062e1a716a9c..f02ba3e9012f 100644 --- a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp +++ b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp @@ -485,16 +485,20 @@ VectorType MfmaInsn::getRetType(Type elementType) { } // Check if the MFMA instruction is coherent with the K dimension configuration. -// When k_base >= 8, allows kpack < k_base if k_base % kpack == 0. -bool MfmaInsn::isCoherentWithK(int64_t kpack, int64_t kPerBlock) { +// Double-buffer pipelines allow kpack < k_base if k_base % kpack == 0. +// Single-buffer pipelines require kpack >= k_base to avoid wasting MFMA cycles. +bool MfmaInsn::isCoherentWithK(int64_t kpack, int64_t kPerBlock, + int64_t scheduleVersion) { int64_t totalKPerBlock = kpack * kPerBlock; + // Double-buffer pipelines: scheduleVersion 2 or 4 + bool isDoubleBuffer = (scheduleVersion == 2 || scheduleVersion == 4); if (kpack > 1) { if (kpack < attr.k_base) { - // Relaxed kpack check only for MFMA with k_base >= 8 - if (attr.k_base < 8) { - LLVM_DEBUG(llvm::dbgs() << "kpack (" << kpack << ") must be >= k_base (" - << attr.k_base << ")\n"); + if (!isDoubleBuffer) { + LLVM_DEBUG(llvm::dbgs() + << "Should pack at least k_base elements and avoid waste " + "xdlopsgemm cycles\n"); return false; } if (attr.k_base % kpack != 0) { @@ -581,14 +585,16 @@ static MfmaTypeId convertTypesToId(Type dataTypeA, Type dataTypeB) { FailureOr MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch, - int64_t mnPerXdl, int64_t kPack, int64_t kPackPerBlock) { + int64_t mnPerXdl, int64_t kPack, int64_t kPackPerBlock, + int64_t scheduleVersion) { LLVM_DEBUG(llvm::dbgs() << "Invoke Mfma group selection:\n" << "elementType A: " << elementTypeA << "\n" << "elementType B: " << elementTypeB << "\n" << "arch: " << arch << "\n" << "mnPerXdl: " << mnPerXdl << "\n" << "kPack: " << kPack << "\n" - << "KPackPerBlock: " << kPackPerBlock << "\n"); + << "KPackPerBlock: " << kPackPerBlock << "\n" + << "scheduleVersion: " << scheduleVersion << "\n"); // Use 64x64 as base unit in large waves int64_t mPerMfmaGroup = getLenPerMfmaGroup(mnPerXdl); @@ -621,7 +627,7 @@ MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch, // gfx950 has double rate instructions. Select from those first. selectFrom(getMfmaInsnGroupAttrMapGfx950()); if (succeeded(result)) { - if (result->isCoherentWithK(kPack, kPackPerBlock)) { + if (result->isCoherentWithK(kPack, kPackPerBlock, scheduleVersion)) { LLVM_DEBUG(llvm::dbgs() << "Selected gfx950 double rate instruction\n"); return; } @@ -704,6 +710,7 @@ SmallVector MfmaInsnGroup::getImms() { return groupAttr.imms; } -bool MfmaInsnGroup::isCoherentWithK(int64_t kpack, int64_t kPerBlock) { - return insn.isCoherentWithK(kpack, kPerBlock); +bool MfmaInsnGroup::isCoherentWithK(int64_t kpack, int64_t kPerBlock, + int64_t scheduleVersion) { + return insn.isCoherentWithK(kpack, kPerBlock, scheduleVersion); } diff --git a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp index 8323e8b79217..650daa3bbfa2 100644 --- a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp @@ -505,13 +505,15 @@ PopulateParamsXDL::isValidBlockwiseGemm(RockAccelTuningParamAttrInterface param, } auto maybeMfmaInsnGroup = MfmaInsnGroup::select(dataTypeA, dataTypeB, arch, mnPerXdl, - param.getKpack(), param.getKpackPerBlock()); + param.getKpack(), param.getKpackPerBlock(), + param.getScheduleVersion()); if (failed(maybeMfmaInsnGroup)) { LLVM_DEBUG(llvm::dbgs() << "Failed to select xdlops instruction group.\n"); return failure(); } MfmaInsnGroup mfmaGroup = *maybeMfmaInsnGroup; - if (!mfmaGroup.isCoherentWithK(param.getKpack(), param.getKpackPerBlock())) { + if (!mfmaGroup.isCoherentWithK(param.getKpack(), param.getKpackPerBlock(), + param.getScheduleVersion())) { LLVM_DEBUG( llvm::dbgs() << "Mfma instruction group selection is not compatible with k.\n"); @@ -539,13 +541,14 @@ PopulateParamsXDL::getTuningParameters(OpBuilder &b, KernelType opType, int64_t mnPerXdl = params.getMnPerXdl(); auto maybeMfmaInsnGroup = MfmaInsnGroup::select(dataTypeA, dataTypeB, arch, mnPerXdl, - params.getKpack(), params.getKpackPerBlock()); + params.getKpack(), params.getKpackPerBlock(), + params.getScheduleVersion()); if (failed(maybeMfmaInsnGroup)) { continue; } MfmaInsnGroup mfmaGroup = *maybeMfmaInsnGroup; - if (mfmaGroup.isCoherentWithK(params.getKpack(), - params.getKpackPerBlock())) { + if (mfmaGroup.isCoherentWithK(params.getKpack(), params.getKpackPerBlock(), + params.getScheduleVersion())) { res.push_back(params); } } diff --git a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp index 62790611eae1..258b5369d7bc 100644 --- a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp +++ b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp @@ -1461,7 +1461,8 @@ AccelEmitter::select(GemmFeatures features, Type dataTypeA, Type dataTypeB, if (isMfma) { auto maybeMfmaInsnGroup = MfmaInsnGroup::select( dataTypeA, dataTypeB, arch, tuningParams.getMnPerXdl(), - tuningParams.getKpack(), tuningParams.getKpackPerBlock()); + tuningParams.getKpack(), tuningParams.getKpackPerBlock(), + tuningParams.getScheduleVersion()); if (failed(maybeMfmaInsnGroup)) { return nullptr; } From 3e49dfb87e41921a24b6737ecccf3374476d9a82 Mon Sep 17 00:00:00 2001 From: stefankoncarevic Date: Fri, 20 Feb 2026 08:10:09 -0600 Subject: [PATCH 3/4] Fix clang-format --- .../lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp index 650daa3bbfa2..e3aa8873a586 100644 --- a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp @@ -503,10 +503,9 @@ PopulateParamsXDL::isValidBlockwiseGemm(RockAccelTuningParamAttrInterface param, if (auto derivedParam = dyn_cast(param)) { mnPerXdl = derivedParam.getMnPerXdl(); } - auto maybeMfmaInsnGroup = - MfmaInsnGroup::select(dataTypeA, dataTypeB, arch, mnPerXdl, - param.getKpack(), param.getKpackPerBlock(), - param.getScheduleVersion()); + auto maybeMfmaInsnGroup = MfmaInsnGroup::select( + dataTypeA, dataTypeB, arch, mnPerXdl, param.getKpack(), + param.getKpackPerBlock(), param.getScheduleVersion()); if (failed(maybeMfmaInsnGroup)) { LLVM_DEBUG(llvm::dbgs() << "Failed to select xdlops instruction group.\n"); return failure(); @@ -539,10 +538,9 @@ PopulateParamsXDL::getTuningParameters(OpBuilder &b, KernelType opType, continue; int64_t mnPerXdl = params.getMnPerXdl(); - auto maybeMfmaInsnGroup = - MfmaInsnGroup::select(dataTypeA, dataTypeB, arch, mnPerXdl, - params.getKpack(), params.getKpackPerBlock(), - params.getScheduleVersion()); + auto maybeMfmaInsnGroup = MfmaInsnGroup::select( + dataTypeA, dataTypeB, arch, mnPerXdl, params.getKpack(), + params.getKpackPerBlock(), params.getScheduleVersion()); if (failed(maybeMfmaInsnGroup)) { continue; } From 4759559aceb06ca6957f50f7fd6a9b4967515799 Mon Sep 17 00:00:00 2001 From: stefankoncarevic Date: Mon, 23 Feb 2026 04:24:54 -0600 Subject: [PATCH 4/4] Minor change --- mlir/test/e2e/PrLdsTransposeLoad.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/e2e/PrLdsTransposeLoad.toml b/mlir/test/e2e/PrLdsTransposeLoad.toml index d1606c50cf89..c2fcb0064a05 100644 --- a/mlir/test/e2e/PrLdsTransposeLoad.toml +++ b/mlir/test/e2e/PrLdsTransposeLoad.toml @@ -88,7 +88,7 @@ config = "-g 1 -m 16 -k 64 -n 16 --transA=true --transB=true --perf_config v3:16 config = "-g 1 -m 32 -k 128 -n 32 --transA=true --transB=true --perf_config v3:64,32,16,32,32,1,1,4,2,1,1" [[suite.test]] -config = "-g 1 -m 32 -k 256 -n 32 --transA=true --transB=true --perf_config v3:32,64,16,32,32,1,8,3,2,1,1" +config = "-g 1 -m 32 -k 256 -n 32 --transA=true --transB=true --perf_config v3:32,64,16,32,32,1,1,3,2,1,1" [[suite.test]] config = "-g 256 -m 32 -k 32 -n 32 --transA=true --transB=true --perf_config v3:32,32,32,32,16,1,1,4,2,1,1"