From 035e8806645facc26d338fb6fd1d4da15fb2e03a Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Tue, 23 Dec 2025 08:57:50 -0600 Subject: [PATCH 01/14] Add LDS transpose load support for attention kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit adds LDS transpose load optimization support for the attention kernel's GEMM operations. GEMM0 (K × Q): - Added decideLDSTransposeForOperands() call for GEMM0 - K matrix (operand A): Can use LDS transpose when directToLDS is enabled - Q matrix (operand B): Can use LDS transpose only when NOT prefetched to registers. Added bLoadsFromLDS parameter to decideLDSTransposeForOperands() to handle the prefetch case correctly. GEMM1 (V × P): - Added decideLDSTransposeForOperands() call for GEMM1 - V matrix (operand A): Can use LDS transpose when directToLDS is enabled - P matrix (operand B): Never uses LDS transpose since it comes from registers (softmax output), not from global memory. directToLDS is always false for P. API changes: - Extended decideLDSTransposeForOperands() with optional bLoadsFromLDS parameter (default=true). When false, operand B is immediately marked as NOT USABLE for LDS transpose, regardless of other constraints. BlockwiseMatrixParamsAttr changes: - Replaced mnPerXdl with accelDDim and accelKDim parameters. The MFMA instruction geometry (e.g., 16x16, 16x32, 32x16) requires both the D dimension (M or N) and K dimension to correctly configure the LDS transpose load. mnPerXdl only captured one dimension, which was insufficient for determining the correct hardware transpose behavior and offset calculations. 
--- .../mlir/Dialect/Rock/IR/RockAttrDefs.td | 2 + .../Dialect/Rock/utility/LdsTransposeLoad.h | 6 ++- .../Transforms/BlockwiseGemmToThreadwise.cpp | 4 +- .../BlockwiseLoadTileToThreadwise.cpp | 4 +- .../Transforms/GridwiseGemmToBlockwise.cpp | 45 +++++++++++++++---- .../Dialect/Rock/utility/LdsTransposeLoad.cpp | 26 +++++++---- 6 files changed, 65 insertions(+), 22 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td index e00fc4e2f696..c7afc9d3b163 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td @@ -518,6 +518,7 @@ def Rock_BlockwiseMatrixParamsAttr : Rock_Attr<"BlockwiseMatrixParams", []> { - g: gemm parameter G - d: gemm parameter D (could be M or N) - inDPerThread: How many elements of D (M or N) each thread is going to load from memory. + - accelDDim: Accelerator instruction D dimension (for LDS transpose support, typically 16 or 32). - accelKDim: Accelerator instruction K dimension (for LDS transpose support). 
}]; let parameters = (ins "Type":$elementType, "Type":$elementTypeLoad, @@ -525,6 +526,7 @@ def Rock_BlockwiseMatrixParamsAttr : Rock_Attr<"BlockwiseMatrixParams", []> { "bool":$directToLDS, "bool":$splitKAcrossThreadsFirst, "int64_t":$g, "int64_t":$d, "int64_t":$inDPerThread, DefaultValuedParameter<"bool", "false">:$ldsTransposeEnabled, + DefaultValuedParameter<"int64_t", "0">:$accelDDim, DefaultValuedParameter<"int64_t", "0">:$accelKDim); let assemblyFormat = [{ diff --git a/mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h b/mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h index 2264a6b6c922..07d87d6aca33 100644 --- a/mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h +++ b/mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h @@ -65,13 +65,17 @@ struct LDSTransposeDecision { // Decides whether to enable LDS transpose for operands A and B // based on architecture, MFMA geometry, kpack constraints, and layout config. +// Parameters: +// - bLoadsFromLDS: Whether operand B actually loads from LDS. +// If false (e.g., Q matrix prefetched to registers), B will be disabled +// for LDS transpose regardless of other constraints. 
LDSTransposeDecision decideLDSTransposeForOperands( const rock::accel::AccelEmitter *accelEmitter, StringRef arch, Type elementTypeA, Type elementTypeB, bool directToLDS, const LDSLayoutConfigDim &ldsLayoutConfigA, const LDSLayoutConfigDim &ldsLayoutConfigB, int64_t mPerBlock, int64_t nPerBlock, int64_t kPerBlock, int64_t mPerWave, int64_t nPerWave, - int64_t kpack, bool doubleBuffering); + int64_t kpack, bool doubleBuffering, bool bLoadsFromLDS = true); } // namespace mlir::rock::hwtranspose diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp index cd49dcc315d3..dfec0443cf7f 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp @@ -453,9 +453,9 @@ struct BlockwiseGemmAccelRewritePattern return nullptr; // Get accelerator dimensions from matrix params and tuning params - // accelDDim = mnPerXdl (for MFMA instructions with blocksMfma=1) + // accelDDim = accelDDim from BlockwiseMatrixParamsAttr (for MFMA instructions with blocksMfma=1) // accelKDim = accelKDim from BlockwiseMatrixParamsAttr - int64_t accelDDim = tuningParams.getMnPerXdl(); + int64_t accelDDim = matrixParams.getAccelDDim(); int64_t accelKDim = matrixParams.getAccelKDim(); if (accelDDim <= 0 || accelKDim <= 0) diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp index b6b019dc38eb..744647fbb013 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp @@ -207,9 +207,9 @@ class LoweringBlockwiseLoadTileOp final LDSTransposeConfigAttr transposeAttr = nullptr; if (ldsTransposeEnabled) { // Get accelerator dimensions from matrix params and tuning params - // accelDDim = mnPerXdl (for MFMA instructions with blocksMfma=1) + // accelDDim = accelDDim from BlockwiseMatrixParamsAttr (for MFMA instructions 
with blocksMfma=1) // accelKDim = accelKDim from BlockwiseMatrixParamsAttr - int64_t accelDDim = tuningParams.getMnPerXdl(); + int64_t accelDDim = matrixParams.getAccelDDim(); int64_t accelKDim = matrixParams.getAccelKDim(); assert(accelDDim > 0 && accelKDim > 0 && "ldsTranspose=true requires valid accel geometry in params"); diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 156d35af4842..7af0b8509581 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -2126,6 +2126,10 @@ struct GridwiseAttentionAccelRewritePattern bool directToLDSQ = loadTypeQ == GemmLoadTileType::DirectToLDSDefault || loadTypeQ == GemmLoadTileType::DirectToLDSDoubleBuffer; + // Determine if Q loads from LDS (for LDS transpose decision) + // Q bypasses LDS only when prefetch is active + bool qLoadsFromLDS = !prefetchQTile; + // Note that kPerBlock for Gemm1B is mPerBlock of Gemm0 out // Note that mPerBlock for Gemm1A is mPerBlock of Gemm0 out // Note that nPerBlock for Gemm1B is nPerBlock of Gemm0 out @@ -2412,34 +2416,60 @@ struct GridwiseAttentionAccelRewritePattern runEarlyExit(rewriter, loc, start, end, splitKV, gemm0MPerBlock, op.getPrePadG0M(), isCausal, isKVCache); - // create matrix params (LDS transpose not supported for attention) + // LDS Transpose Decision for GEMM0 (K x Q) + // Pass qLoadsFromLDS to disable LDS transpose for Q when it's prefetched + hwtranspose::LDSTransposeDecision ldsDecisionGemm0 = + hwtranspose::decideLDSTransposeForOperands( + accelEmitterPtrGemm0.get(), arch, elemTypeK, elemTypeQ, directToLDS, + ldsLayoutCfgMG0, ldsLayoutCfgNG0, gemm0MPerBlock, gemm0NPerBlock, + gemm0KPerBlock, gemm0TuningParams.getMPerWave(), + gemm0TuningParams.getNPerWave(), gemm0kpack, + /*doubleBuffering=*/false, /*bLoadsFromLDS=*/qLoadsFromLDS); + + // create matrix params BlockwiseMatrixParamsAttr 
matrixParamsK = BlockwiseMatrixParamsAttr::get( rewriter.getContext(), elemTypeK, elemTypeKLoad, ldsLayoutCfgMG0.doRotateWithK, ldsLayoutCfgMG0.doSwapThreadIterSubDims, ldsLayoutCfgMG0.ldsLayoutDxK, directToLDS, /*splitKAcrossThreadsFirst=*/false, gemm0G, gemm0M, gemm0InMPerThread, - /*ldsTransposeEnabled=*/false, /*accelKDim=*/0); + /*ldsTransposeEnabled=*/ldsDecisionGemm0.enableA, + /*accelDDim=*/ldsDecisionGemm0.mfmaDDim, + /*accelKDim=*/ldsDecisionGemm0.mfmaKDim); BlockwiseMatrixParamsAttr matrixParamsQ = BlockwiseMatrixParamsAttr::get( rewriter.getContext(), elemTypeQ, elemTypeQLoad, ldsLayoutCfgNG0.doRotateWithK, ldsLayoutCfgNG0.doSwapThreadIterSubDims, ldsLayoutCfgNG0.ldsLayoutDxK, directToLDSQ, /*splitKAcrossThreadsFirst=*/false, gemm0G, gemm0N, gemm0InNPerThread, - /*ldsTransposeEnabled=*/false, /*accelKDim=*/0); + /*ldsTransposeEnabled=*/ldsDecisionGemm0.enableB, + /*accelDDim=*/ldsDecisionGemm0.mfmaDDim, + /*accelKDim=*/ldsDecisionGemm0.mfmaKDim); + + // LDS Transpose Decision for GEMM1 (V x P) + // Only V (operand A) can use LDS transpose + hwtranspose::LDSTransposeDecision ldsDecisionGemm1 = + hwtranspose::decideLDSTransposeForOperands( + accelEmitterPtrGemm1.get(), arch, elemTypeV, elemTypeV, directToLDS, + ldsLayoutCfgMG1, ldsLayoutCfgMG1, gemm1MPerBlock, gemm1NPerBlock, + gemm1KPerBlock, gemm1TuningParams.getMPerWave(), + gemm1TuningParams.getNPerWave(), gemm1kpack, + /*doubleBuffering=*/false, /*bLoadsFromLDS=*/false); BlockwiseMatrixParamsAttr matrixParamsV = BlockwiseMatrixParamsAttr::get( rewriter.getContext(), elemTypeV, elemTypeVLoad, ldsLayoutCfgMG1.doRotateWithK, ldsLayoutCfgMG1.doSwapThreadIterSubDims, ldsLayoutCfgMG1.ldsLayoutDxK, directToLDS, doBypassLDSSecondGemm, gemm0G, gemm1M, gemm1InMPerThread, - /*ldsTransposeEnabled=*/false, /*accelKDim=*/0); + /*ldsTransposeEnabled=*/ldsDecisionGemm1.enableA, + /*accelDDim=*/ldsDecisionGemm1.mfmaDDim, + /*accelKDim=*/ldsDecisionGemm1.mfmaKDim); BlockwiseMatrixParamsAttr matrixParamsKxQ = 
BlockwiseMatrixParamsAttr::get( rewriter.getContext(), elemTypeV, elemTypeVLoad, /*rotateDWithK=*/false, /*swapThreadIterSubDims=*/false, /*LDSLayoutDxK=*/false, /*directToLDS=*/false, /*splitKAcrossThreadsFirst=*/false, gemm0G, gemm1N, gemm1InMPerThread, - /*ldsTransposeEnabled=*/false, /*accelKDim=*/0); + /*ldsTransposeEnabled=*/false, /*accelDDim=*/0, /*accelKDim=*/0); // If gemm0K is equal to gemm0KPerBlock that means // effectively there is no K loop. Therefore, we @@ -3256,9 +3286,6 @@ struct GridwiseGemmAccelRewritePattern directToLDS, ldsLayoutConfigA, ldsLayoutConfigB, mPerBlock, nPerBlock, kPerBlock, mPerWave, nPerWave, kpack, doubleBuffering); - // Note: LDS transpose geometry (accelKDim) is now stored in - // BlockwiseMatrixParamsAttr, not in tuning params - LLVM_DEBUG(llvm::dbgs() << "M: " << M << "\n" << "N: " << N << "\n" @@ -3306,6 +3333,7 @@ struct GridwiseGemmAccelRewritePattern ldsLayoutConfigA.doSwapThreadIterSubDims, ldsLayoutConfigA.ldsLayoutDxK, directToLDS, /*splitKAcrossThreadsFirst=*/false, G, M, copyMPerThread, /*ldsTranspose=*/ldsDecision.enableA, + /*accelDDim=*/ldsDecision.mfmaDDim, /*accelKDim=*/ldsDecision.mfmaKDim); BlockwiseMatrixParamsAttr matrixParamsB = BlockwiseMatrixParamsAttr::get( @@ -3314,6 +3342,7 @@ struct GridwiseGemmAccelRewritePattern ldsLayoutConfigB.doSwapThreadIterSubDims, ldsLayoutConfigB.ldsLayoutDxK, directToLDS, /*splitKAcrossThreadsFirst=*/false, G, N, copyNPerThread, /*ldsTranspose=*/ldsDecision.enableB, + /*accelDDim=*/ldsDecision.mfmaDDim, /*accelKDim=*/ldsDecision.mfmaKDim); // Allocate LDS. 
diff --git a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp index 7431b18ce363..5ce33878ccd5 100644 --- a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp +++ b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp @@ -202,7 +202,7 @@ LDSTransposeDecision decideLDSTransposeForOperands( const LDSLayoutConfigDim &ldsLayoutConfigA, const LDSLayoutConfigDim &ldsLayoutConfigB, int64_t mPerBlock, int64_t nPerBlock, int64_t kPerBlock, int64_t mPerWave, int64_t nPerWave, - int64_t kpack, bool doubleBuffering) { + int64_t kpack, bool doubleBuffering, bool bLoadsFromLDS) { LDSTransposeDecision result; @@ -234,13 +234,21 @@ LDSTransposeDecision decideLDSTransposeForOperands( << (decA.usable ? "USABLE" : "NOT USABLE") << "\n"); // Make decision for operand B - Decision decB = - makeDecision(arch, elementTypeA, elementTypeB, directToLDS, shape, - OperandKind::B, ldsLayoutConfigB, mPerBlock, nPerBlock, - kPerBlock, mPerWave, nPerWave, doubleBuffering); - - LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Decision for operand B: " - << (decB.usable ? "USABLE" : "NOT USABLE") << "\n"); + // If B doesn't load from LDS (e.g., prefetched Q matrix), it can't use + // LDS transpose regardless of other constraints + Decision decB; + if (!bLoadsFromLDS) { + decB.usable = false; + LLVM_DEBUG(llvm::dbgs() + << "[lds_transpose] Decision for operand B: NOT USABLE " + << "(bypasses LDS - prefetched to registers)\n"); + } else { + decB = makeDecision(arch, elementTypeA, elementTypeB, directToLDS, shape, + OperandKind::B, ldsLayoutConfigB, mPerBlock, nPerBlock, + kPerBlock, mPerWave, nPerWave, doubleBuffering); + LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Decision for operand B: " + << (decB.usable ? 
"USABLE" : "NOT USABLE") << "\n"); + } // ======================================== // KPACK CONSTRAINT LOGIC @@ -879,7 +887,7 @@ static StrideConfig computeStrideConfiguration(OperandKind operand, } } else { // Operand B (N dimension) - if (waveGrid.wavesInN >= 2 && waveGrid.wavesInN == waveGrid.wavesInM) { + if (waveGrid.wavesInN >= 2) { // BALANCED GRID (2×2, 3×3, 4×4) → tiles are interleaved in N dimension // Special case: balanced grids require interleaved tile access config.tileOffsetStride = waveGrid.wavesInN * dDim; From ce7f0c956fa5fb4409450617c89379511edc0924 Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Tue, 23 Dec 2025 12:28:03 -0600 Subject: [PATCH 02/14] Extend LDS transpose load support for 8 and 16 wave configurations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit extends the LDS transpose load optimization to support workgroups with 8 waves (blockSize=512) and 16 waves (blockSize=1024). Previously, the optimization was limited to 1-4 waves only. This restriction has been lifted to enable LDS transpose load for larger workgroup sizes commonly used in high-performance GEMM configurations. 
Changes: - Extended numWaves limit from 4 to 16 in decideLDSTransposeForOperands() - Added wave grid layout computation for 8 waves: - 2×4, 4×2 (preferred balanced layouts) - 1×8, 8×1 (fallback layouts) - Added wave grid layout computation for 16 waves: - 4×4 (preferred balanced layout) - 2×8, 8×2 (semi-balanced layouts) - 1×16, 16×1 (fallback layouts) Updated tests: - lds_transpose_attributes_toblockwise.mlir: Changed CHECK-NOT to CHECK for 8 and 16 wave tests, confirming LDS transpose is now enabled for these configurations - PrLdsTransposeLoad.toml: Added e2e test cases for 8-wave (4×2, 1×8) and 16-wave (8×2, 1×16) grid configurations --- .../Dialect/Rock/utility/LdsTransposeLoad.cpp | 55 +++++++++++++++++-- .../lds_transpose_attributes_toblockwise.mlir | 4 +- mlir/test/e2e/PrLdsTransposeLoad.toml | 16 ++++++ 3 files changed, 68 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp index 5ce33878ccd5..7ae30397cc3d 100644 --- a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp +++ b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp @@ -286,9 +286,10 @@ LDSTransposeDecision decideLDSTransposeForOperands( } // else - implicitly: neither operand usable, enableA/enableB remain false. - // TODO: adapt code to support numWaves = 8, 16 and 32 (only wmma). 
+ // Check if numWaves is supported (1, 2, 3, 4, 8, 16) + // TODO: support 32 waves for WMMA int64_t numWaves = (mPerBlock * nPerBlock) / (mPerWave * nPerWave); - if (numWaves > 4) { + if (numWaves > 16) { result.enableA = false; result.enableB = false; } @@ -753,9 +754,7 @@ computeWaveGridLayout(PatternRewriter &b, Location loc, Value waveId, // Determine wave grid layout based on physical waves and wave tiles // This distributes waves spatially across M and N dimensions - // Note: numWaves can only be 1, 2, 3, or 4 (for 64, 128, 192, 256 - // threads) - // TODO: numWaves can be 8 and 16 (and 32 for wmma) as well, update this code + // Supported: 1, 2, 3, 4, 8, 16 waves (32 is WMMA only, not yet supported) int64_t wavesInM = 1; int64_t wavesInN = 1; @@ -819,6 +818,52 @@ computeWaveGridLayout(PatternRewriter &b, Location loc, Value waveId, } } break; + + case 8: + // Eight waves: prefer 2×4 or 4×2 (balanced), then 1×8 or 8×1 + if (waveTilesInM >= 2 && waveTilesInN >= 4) { + wavesInM = 2; + wavesInN = 4; + } else if (waveTilesInM >= 4 && waveTilesInN >= 2) { + wavesInM = 4; + wavesInN = 2; + } else if (waveTilesInN >= 8) { + wavesInM = 1; + wavesInN = 8; + } else if (waveTilesInM >= 8) { + wavesInM = 8; + wavesInN = 1; + } else { + // Fallback: prefer 2×4 layout + wavesInM = 2; + wavesInN = 4; + } + break; + + case 16: + // Sixteen waves: prefer 4×4 (balanced), then 2×8, 8×2, 1×16, 16×1 + if (waveTilesInM >= 4 && waveTilesInN >= 4) { + wavesInM = 4; + wavesInN = 4; + } else if (waveTilesInM >= 2 && waveTilesInN >= 8) { + wavesInM = 2; + wavesInN = 8; + } else if (waveTilesInM >= 8 && waveTilesInN >= 2) { + wavesInM = 8; + wavesInN = 2; + } else if (waveTilesInN >= 16) { + wavesInM = 1; + wavesInN = 16; + } else if (waveTilesInM >= 16) { + wavesInM = 16; + wavesInN = 1; + } else { + // Fallback: prefer 4×4 layout + wavesInM = 4; + wavesInN = 4; + } + break; + default: return failure(); } diff --git a/mlir/test/Dialect/Rock/lds_transpose_attributes_toblockwise.mlir 
b/mlir/test/Dialect/Rock/lds_transpose_attributes_toblockwise.mlir index 144b0a893e7c..2408996c5c19 100644 --- a/mlir/test/Dialect/Rock/lds_transpose_attributes_toblockwise.mlir +++ b/mlir/test/Dialect/Rock/lds_transpose_attributes_toblockwise.mlir @@ -25,7 +25,7 @@ module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { %b = rock.transform %arg1 by (d1 * 64 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 32, 64] -> [2048]> : memref<2048xf16> to memref<1x32x64xf16> %c = rock.transform %arg2 by (d1 * 64 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 64, 64] -> [4096]> : memref<4096xf16> to memref<1x64x64xf16> - // CHECK-NOT: ldsTransposeEnabled + // CHECK: ldsTransposeEnabled = true rock.gridwise_gemm_accel(%a, %b, %c) storeMethod(set) features = mfma|dot|atomic_add|atomic_add_bf16|atomic_add_f16|direct_to_lds_32b|direct_to_lds_128b @@ -46,7 +46,7 @@ module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { %b = rock.transform %arg1 by (d1 * 64 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 32, 64] -> [2048]> : memref<2048xf16> to memref<1x32x64xf16> %c = rock.transform %arg2 by (d1 * 64 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 64, 64] -> [4096]> : memref<4096xf16> to memref<1x64x64xf16> - // CHECK-NOT: ldsTransposeEnabled + // CHECK: ldsTransposeEnabled = true rock.gridwise_gemm_accel(%a, %b, %c) storeMethod(set) features = mfma|dot|atomic_add|atomic_add_bf16|atomic_add_f16|direct_to_lds_32b|direct_to_lds_128b diff --git a/mlir/test/e2e/PrLdsTransposeLoad.toml b/mlir/test/e2e/PrLdsTransposeLoad.toml index 2eb5bbbaa111..88f201f705fe 100644 --- a/mlir/test/e2e/PrLdsTransposeLoad.toml +++ b/mlir/test/e2e/PrLdsTransposeLoad.toml @@ -52,6 +52,22 @@ config = "-g 1 -m 4096 -k 256 -n 384 --transA=true --transB=false --perf_config [[suite.test]] config = "-g 256 -m 900 -k 1280 -n 3840 --transA=true --transB=false --perf_config v3:128,32,8,32,16,8,1,3,2,1,1" +# 8 waves: 4×2 grid +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 32 
--transA=true --transB=false --perf_config v4:64,32,8,16,16,16,16,1,4,2,0,0,1,1" + +# 8 waves: 1×8 grid +[[suite.test]] +config = "-g 1 -m 16 -k 128 -n 128 --transA=true --transB=false --perf_config v4:16,128,8,16,16,16,16,1,4,2,0,0,1,1" + +# 16 waves: 8×2 grid +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 32 --transA=true --transB=false --perf_config v4:128,32,8,16,16,16,16,1,4,2,0,0,1,1" + +# 16 waves: 1×16 grid +[[suite.test]] +config = "-g 1 -m 16 -k 128 -n 256 --transA=true --transB=false --perf_config v4:16,256,8,16,16,16,16,1,4,2,0,0,1,1" + # ============================================================================ # Suite 2: ONLY A uses LDS transpose (transA=true, transB=true) # kpack=1 (single operand constraint), scheduleVersion swapped (3<->4) From d28193c5691c129db2e8871fb5b4e52c8f38cc97 Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Fri, 26 Dec 2025 05:04:25 -0600 Subject: [PATCH 03/14] Fix K-access formula for hybrid LDS transpose load scenario When one operand uses regular load and the other uses LDS transpose load, the regular load must use a compatible K-access pattern. The new formula is only applied when: - useLdsTransposeLoad is true (hybrid scenario) - kVec >= kBase (enough elements to decompose) This ensures correct data alignment between regular and transpose loads for MFMA operations, and prevents assertion failures when kpack < kBase. 
Changes: - Add useLdsTransposeLoad parameter to wrapLDSBufferForLoad - Implement hybrid K-access formula with blk_d/blk_k split - Pass LDS transpose state from BlockwiseGemmToThreadwise - Update tests in PrLdsTransposeLoad.toml --- .../mlir/Dialect/Rock/IR/AccelEmitter.h | 13 ++- .../Transforms/BlockwiseGemmToThreadwise.cpp | 20 ++++- .../BlockwiseLoadTileToThreadwise.cpp | 12 ++- .../lib/Dialect/Rock/utility/AccelEmitter.cpp | 85 ++++++++++++++++--- .../Dialect/Rock/utility/LdsTransposeLoad.cpp | 47 ++++------ mlir/test/e2e/PrLdsTransposeLoad.toml | 54 ++++++++---- 6 files changed, 157 insertions(+), 74 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h b/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h index a9b1d3be4919..12b677fbfc15 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h +++ b/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h @@ -99,11 +99,14 @@ struct AccelEmitter { /// Return a wrapped view of the LDS buffer tailored for the accelerator /// load pattern. This is similar to wrapLDSBufferForStore, but while storing /// in LDS follows a similar pattern among accelerators, loading from LDS - /// is dependent on the type of accelerator we are targeting + /// is dependent on the type of accelerator we are targeting. + /// When useLdsTransposeLoad is true, a special K access pattern + /// is used that is compatible with LDS transpose load on the other operand. 
virtual Value wrapLDSBufferForLoad(OpBuilder &b, Location loc, Value buffer, const BlockwiseMatrixParamsAttr &matrixParams, - int64_t blockSize, StringRef dName) const = 0; + int64_t blockSize, StringRef dName, + bool useLdsTransposeLoad = false) const = 0; /// This functions creates the subtile views that is : /// 1) gridSubTileView : @@ -187,7 +190,8 @@ struct MfmaEmitter : public AccelEmitter { Value wrapLDSBufferForLoad(OpBuilder &b, Location loc, Value buffer, const BlockwiseMatrixParamsAttr &matrixParams, - int64_t blockSize, StringRef dName) const override; + int64_t blockSize, StringRef dName, + bool useLdsTransposeLoad = false) const override; FailureOr createAccelGemmOperandTransforms( OpBuilder &b, Location loc, int64_t kIters, @@ -240,7 +244,8 @@ struct WmmaEmitter : public AccelEmitter { Value wrapLDSBufferForLoad(OpBuilder &b, Location loc, Value buffer, const BlockwiseMatrixParamsAttr &matrixParams, - int64_t blockSize, StringRef dName) const override; + int64_t blockSize, StringRef dName, + bool useLdsTransposeLoad = false) const override; FailureOr createAccelGemmOperandTransforms( OpBuilder &b, Location loc, int64_t kIters, diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp index dfec0443cf7f..6182c6512ca8 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp @@ -507,24 +507,36 @@ struct BlockwiseGemmAccelRewritePattern // considered a temporary hack until we have a proper way of "searching" // through different schedules (either heuristically or automatically) + // Determine if the other operand uses LDS transpose load + // This is needed to select the correct K access pattern for regular loads + bool bUsesLdsTranspose = matrixParamsB.getLdsTransposeEnabled(); + bool aUsesLdsTranspose = matrixParamsA.getLdsTransposeEnabled(); + Value wrappedLDSBufferForLoadA, 
wrappedLDSBufferForLoadB; if (loadAFromLDS) { + // When loading A, check if B uses transpose load wrappedLDSBufferForLoadA = accelEmitterPtr->wrapLDSBufferForLoad( - b, loc, op.getMatrixA(), matrixParamsA, op.getBlockSize(), "m"); + b, loc, op.getMatrixA(), matrixParamsA, op.getBlockSize(), "m", + /*useLdsTransposeLoad=*/bUsesLdsTranspose); } if (loadBFromLDS) { + // When loading B, check if A uses transpose load wrappedLDSBufferForLoadB = accelEmitterPtr->wrapLDSBufferForLoad( - b, loc, op.getMatrixB(), matrixParamsB, op.getBlockSize(), "n"); + b, loc, op.getMatrixB(), matrixParamsB, op.getBlockSize(), "n", + /*useLdsTransposeLoad=*/aUsesLdsTranspose); } Value wrappedLDSBufferForScaleA, wrappedLDSBufferForScaleB; if (isScaledGemm) { + // Scaled GEMM (FP4) doesn't support LDS transpose load yet if (loadAFromLDS) { wrappedLDSBufferForScaleA = accelEmitterPtr->wrapLDSBufferForLoad( - b, loc, op.getScaleA(), matrixParamsA, op.getBlockSize(), "m"); + b, loc, op.getScaleA(), matrixParamsA, op.getBlockSize(), "m", + /*useLdsTransposeLoad=*/false); } if (loadBFromLDS) { wrappedLDSBufferForScaleB = accelEmitterPtr->wrapLDSBufferForLoad( - b, loc, op.getScaleB(), matrixParamsB, op.getBlockSize(), "n"); + b, loc, op.getScaleB(), matrixParamsB, op.getBlockSize(), "n", + /*useLdsTransposeLoad=*/false); } } diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp index 744647fbb013..5f238069ac91 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp @@ -79,12 +79,14 @@ class LoweringBlockwiseLoadTileOp final const std::unique_ptr &accelEmitterPtr, Value tid, StringRef dName, Value ldsView, Value regs, int64_t blockSize, bool forceUnroll, const BlockwiseMatrixParamsAttr &matrixParams, - LDSTransposeConfigAttr transposeAttr = nullptr) const { + LDSTransposeConfigAttr transposeAttr = nullptr, + 
bool useLdsTransposeLoad = false) const { // wrapLDSBufferForLoad is reading a single set of Ks into private memory // A/B[m/n, 0:kBasePerThread] Value ldsViewForLoad = accelEmitterPtr->wrapLDSBufferForLoad( - b, loc, ldsView, matrixParams, blockSize, dName); + b, loc, ldsView, matrixParams, blockSize, dName, + useLdsTransposeLoad); // We enhance the transformation from wrapLDSBufferForLoad using a builder // that, given a single index, splits it into "m"("n") and "k" and lets @@ -452,9 +454,13 @@ class LoweringBlockwiseLoadTileOp final ldsViewForGemm = viewBufferAs(b, ldsByteBuffer, ldsReadType); } + // Determine if the other operand uses LDS transpose load + // If we're loading A, check if B uses transpose; if loading B, check A + bool useLdsTransposeLoad = isA ? matrixParamsB.getLdsTransposeEnabled() + : matrixParamsA.getLdsTransposeEnabled(); generateReadLoop(loc, b, accelEmitterPtr, tid, dName, ldsViewForGemm, destRegisters, blockSize, forceUnroll, matrixParams, - transposeAttr); + transposeAttr, useLdsTransposeLoad); if (stageLDSReadNew) rock::YieldOp::create(b, loc); } diff --git a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp index 47f4ff1c252b..3e7b6bc964ea 100644 --- a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp +++ b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp @@ -425,7 +425,7 @@ llvm::FailureOr MfmaEmitter::computeOutputTransforms( Value MfmaEmitter::wrapLDSBufferForLoad( OpBuilder &b, Location loc, Value buffer, const BlockwiseMatrixParamsAttr &matrixParams, int64_t blockSize, - StringRef dName) const { + StringRef dName, bool useLdsTransposeLoad) const { StringRef thisWaveDim = dName == "m" ? "wave_m" : "wave_n"; StringRef otherWaveDim = dName == "m" ? 
"wave_n" : "wave_m"; @@ -542,20 +542,75 @@ Value MfmaEmitter::wrapLDSBufferForLoad( TransformMapAttr splitWaveIdAttr = splitWaveId.get(); transformAttrs.push_back(splitWaveIdAttr); - TopDownTMBuilder toLDSRowCol = - TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr); + TopDownTMBuilder toLDSRowCol(b, {}, {}, loc); + + // Use LDS transpose compatible K formula only when: + // 1. Other operand uses LDS transpose load (hybrid scenario) + // 2. kVec >= kBase (enough elements per load to decompose) + int64_t kBase = accelEmitterParams.kBase; + if (useLdsTransposeLoad && kVec >= kBase) { + // K access pattern must match the transpose load's pattern. + // For double-rate MFMA, properly distribute K across threads + MfmaInsnAttr mfmaAttr = mfmaGroup.getInsnAttr(); + int64_t instrK = mfmaAttr.k; + int64_t numBlksInK = instrK / kBase; + int64_t numBlksInD = (waveSize / inputSpanLen) / numBlksInK; + + // Split blk_id into blk_d (for D dimension) and blk_k (for K dimension) + TopDownTMBuilder splitBlkId = + TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr); + splitBlkId.passThrough({"wave_m", "wave_n"}, {0, 1}, + {"wave_m", "wave_n"}); + splitBlkId.merge({"blk_d", "blk_k"}, {2, 3}, "blk_id", + {numBlksInD, numBlksInK}); + splitBlkId.passThrough({"blk_td", "d_iter", "k_iter", "k_vec"}, + {4, 5, 6, 7}, + {"blk_td", "d_iter", "k_iter", "k_vec"}); + TransformMapAttr splitBlkIdAttr = splitBlkId.get(); + transformAttrs.push_back(splitBlkIdAttr); + + // Split k_vec into k_mfma and k_base for kpack > kBase + int64_t numMfmaPerKVec = kVec / kBase; + + TopDownTMBuilder splitKVec = + TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr); + splitKVec.passThrough({"wave_m", "wave_n"}, {0, 1}, {"wave_m", "wave_n"}); + splitKVec.passThrough({"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"}, + {2, 3, 4, 5, 6}, + {"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"}); + splitKVec.merge({"k_mfma", "k_base"}, {7, 8}, "k_vec", + {numMfmaPerKVec, kBase}); + TransformMapAttr splitKVecAttr = 
splitKVec.get(); + transformAttrs.push_back(splitKVecAttr); + + toLDSRowCol = TopDownTMBuilder::below(splitKVec, splitKVecAttr); + + // d = d_iter * dWaves * numBlksInD * inputSpanLen + wave_d * numBlksInD * + // inputSpanLen + blk_d * inputSpanLen + blk_td + toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_d", "blk_td"}, + {dRepeats, dWaves, numBlksInD, inputSpanLen}); + + // k = k_iter * (numMfmaPerKVec * instrK) + k_mfma * instrK + blk_k * + // kBase + k_base + toLDSRowCol.unmerge("k", 1, {"k_iter", "k_mfma", "blk_k", "k_base"}, + {kIter, numMfmaPerKVec, numBlksInK, kBase}); - // d = blk_td + d_i * waveOffset - toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_td"}, - {dRepeats, dWaves, inputSpanLen}); - if (matrixParams.getSplitKAcrossThreadsFirst()) { - // k = blk_id + (waveSize / inputSpanLen) * k_i - toLDSRowCol.unmerge("k", 1, {"k_iter", "blk_id", "k_vec"}, - {kIter, waveSize / inputSpanLen, kVec}); } else { - // k = k_i + kpackPerBlock * blk_id - toLDSRowCol.unmerge("k", 1, {"blk_id", "k_iter", "k_vec"}, - {waveSize / inputSpanLen, kIter, kVec}); + // Standard formula for regular load scenarios or when kVec < kBase + toLDSRowCol = TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr); + + // d = blk_td + d_i * waveOffset + toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_td"}, + {dRepeats, dWaves, inputSpanLen}); + if (matrixParams.getSplitKAcrossThreadsFirst()) { + // k = blk_id + (waveSize / inputSpanLen) * k_i + toLDSRowCol.unmerge("k", 1, {"k_iter", "blk_id", "k_vec"}, + {kIter, waveSize / inputSpanLen, kVec}); + } else { + // k = k_i + kpackPerBlock * blk_id + toLDSRowCol.unmerge("k", 1, {"blk_id", "k_iter", "k_vec"}, + {waveSize / inputSpanLen, kIter, kVec}); + } } toLDSRowCol.ignore(otherWaveDim); @@ -858,7 +913,9 @@ int64_t WmmaEmitter::getDDim(StringRef dName) const { Value WmmaEmitter::wrapLDSBufferForLoad( OpBuilder &b, Location loc, Value buffer, const BlockwiseMatrixParamsAttr &matrixParams, int64_t blockSize, - 
StringRef dName) const { + StringRef dName, bool useLdsTransposeLoad) const { + // Note: WMMA does not support LDS transpose load, so the parameter is unused. + (void)useLdsTransposeLoad; // Extract relevant tuning parameters int64_t mPerBlock = tuningParams.getMPerBlock(); diff --git a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp index 7ae30397cc3d..d4f021dc2715 100644 --- a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp +++ b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp @@ -250,41 +250,26 @@ LDSTransposeDecision decideLDSTransposeForOperands( << (decB.usable ? "USABLE" : "NOT USABLE") << "\n"); } - // ======================================== - // KPACK CONSTRAINT LOGIC - // ======================================== - bool bothUsable = decA.usable && decB.usable; - bool onlyOneUsable = decA.usable != decB.usable; - - if (bothUsable) { - // Case 1: Both operands can use LDS transpose - always enable - result.enableA = true; - result.enableB = true; - result.mfmaDDim = mfmaDDim; - result.mfmaKDim = mfmaKDim; - LLVM_DEBUG(llvm::dbgs() - << "[lds_transpose] Enabled for BOTH operands (A and B)\n"); - } else if (onlyOneUsable && kpack == 1) { - // Case 2: Only one operand can use it with kpack == 1 - // kpack == 1: Safe to enable for single operand + // Enable LDS transpose load for each operand that supports it. + // The K access pattern formula in AccelEmitter.cpp (useLdsTransposeLoad) + // ensures compatibility when mixing regular load with transpose load. + bool anyUsable = decA.usable || decB.usable; + + if (anyUsable) { result.enableA = decA.usable; result.enableB = decB.usable; result.mfmaDDim = mfmaDDim; result.mfmaKDim = mfmaKDim; - LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Enabled for " - << (decA.usable ? 
"operand A" : "operand B") - << " only (kpack=1)\n"); - } else if (onlyOneUsable) { - // Case 3: Only one operand usable but kpack > 1 - // kpack > 1 with asymmetric support - disable both (current limitation) - result.enableA = false; - result.enableB = false; - // Geometry NOT set - avoids polluting tuning params with unused data - LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] DISABLED: only one " - "operand eligible but kpack=" - << kpack << " > 1 (current limitation)\n"); + + LLVM_DEBUG({ + if (decA.usable && decB.usable) + llvm::dbgs() << "[lds_transpose] Enabled for BOTH operands (A and B)\n"; + else + llvm::dbgs() << "[lds_transpose] Enabled for " + << (decA.usable ? "operand A" : "operand B") << " only\n"; + }); } - // else - implicitly: neither operand usable, enableA/enableB remain false. + // else - neither operand usable, enableA/enableB remain false. // Check if numWaves is supported (1, 2, 3, 4, 8, 16) // TODO: support 32 waves for WMMA @@ -1188,7 +1173,7 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, auto [k_base_local, m_offset_base] = computeLDSBaseOffsets(b, loc, dDim, instrK, lane); - // K stride per tile: KMfma (e.g., 8) + // K stride per tile: instrK (MFMA K dimension) int64_t kTileStride = instrK; Value kTileStrideVal = arith::ConstantIndexOp::create(b, loc, kTileStride); Value ldsStrideVal = arith::ConstantIndexOp::create(b, loc, ldsStride); diff --git a/mlir/test/e2e/PrLdsTransposeLoad.toml b/mlir/test/e2e/PrLdsTransposeLoad.toml index 88f201f705fe..d1606c50cf89 100644 --- a/mlir/test/e2e/PrLdsTransposeLoad.toml +++ b/mlir/test/e2e/PrLdsTransposeLoad.toml @@ -26,7 +26,7 @@ config = "-g 1 -m 16 -k 64 -n 16 --transA=true --transB=false --perf_config v3:1 config = "-g 1 -m 32 -k 128 -n 32 --transA=true --transB=false --perf_config v3:32,32,32,32,32,4,1,3,2,1,1" [[suite.test]] -config = "-g 1 -m 32 -k 256 -n 32 --transA=true --transB=false --perf_config v3:32,32,256,32,32,1,1,4,2,1,1" +config = "-g 1 -m 32 -k 256 -n 32 
--transA=true --transB=false --perf_config v3:32,32,4,32,32,8,1,4,2,1,1" [[suite.test]] config = "-g 256 -m 32 -k 32 -n 32 --transA=true --transB=false --perf_config v3:32,32,32,32,32,1,1,3,2,1,1" @@ -70,7 +70,7 @@ config = "-g 1 -m 16 -k 128 -n 256 --transA=true --transB=false --perf_config v4 # ============================================================================ # Suite 2: ONLY A uses LDS transpose (transA=true, transB=true) -# kpack=1 (single operand constraint), scheduleVersion swapped (3<->4) +# B uses regular load with LDS transpose load # ============================================================================ [[suite]] name = "lds_transpose_A_only" @@ -82,41 +82,50 @@ config = "-g 1 -m 16 -k 16 -n 16 --transA=true --transB=true --perf_config v3:16 config = "-g 1 -m 32 -k 8 -n 32 --transA=true --transB=true --perf_config v3:32,32,8,32,32,1,1,3,2,1,1" [[suite.test]] -config = "-g 1 -m 16 -k 64 -n 16 --transA=true --transB=true --perf_config v3:16,16,32,16,16,1,1,3,2,1,1" +config = "-g 1 -m 16 -k 64 -n 16 --transA=true --transB=true --perf_config v3:16,16,32,16,16,4,1,3,2,1,1" [[suite.test]] config = "-g 1 -m 32 -k 128 -n 32 --transA=true --transB=true --perf_config v3:64,32,16,32,32,1,1,4,2,1,1" [[suite.test]] -config = "-g 1 -m 32 -k 256 -n 32 --transA=true --transB=true --perf_config v3:32,64,16,32,32,1,1,3,2,1,1" +config = "-g 1 -m 32 -k 256 -n 32 --transA=true --transB=true --perf_config v3:32,64,16,32,32,1,8,3,2,1,1" [[suite.test]] config = "-g 256 -m 32 -k 32 -n 32 --transA=true --transB=true --perf_config v3:32,32,32,32,16,1,1,4,2,1,1" [[suite.test]] -config = "-g 256 -m 16 -k 64 -n 16 --transA=true --transB=true --perf_config v3:32,16,32,16,16,1,1,3,2,1,1" +config = "-g 256 -m 16 -k 64 -n 16 --transA=true --transB=true --perf_config v3:128,32,2,32,32,8,1,3,2,1,1" [[suite.test]] config = "-g 1 -m 64 -k 64 -n 64 --transA=true --transB=true --perf_config v3:32,16,16,32,16,1,1,4,2,1,1" [[suite.test]] -config = "-g 1 -m 64 -k 128 -n 64 
--transA=true --transB=true --perf_config v3:64,128,32,64,16,1,1,4,2,1,1" +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=true --perf_config v3:64,128,4,64,16,16,1,4,2,1,1" [[suite.test]] -config = "-g 1 -m 128 -k 64 -n 128 --transA=true --transB=true --perf_config v3:128,128,16,32,32,1,1,3,2,1,1" +config = "-g 1 -m 128 -k 64 -n 128 --transA=true --transB=true --perf_config v3:128,128,8,32,32,16,1,3,2,1,1" [[suite.test]] -config = "-g 1 -m 144 -k 1280 -n 3840 --transA=true --transB=true --perf_config v3:32,128,16,32,32,1,1,4,2,1,1" +config = "-g 1 -m 144 -k 1280 -n 3840 --transA=true --transB=true --perf_config v3:32,128,16,32,32,4,1,4,2,1,1" [[suite.test]] -config = "-g 1 -m 4096 -k 256 -n 384 --transA=true --transB=true --perf_config v3:64,64,32,32,16,1,1,3,2,1,1" +config = "-g 1 -m 1280 -k 256 -n 384 --transA=true --transB=true --perf_config v3:64,64,32,32,16,1,1,3,2,1,1" [[suite.test]] -config = "-g 256 -m 900 -k 1280 -n 3840 --transA=true --transB=true --perf_config v3:128,64,16,32,32,1,1,4,2,1,1" +config = "-g 1 -m 16 -k 64 -n 16 --transA=true --transB=true --perf_config v3:16,16,8,16,16,4,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 16 -k 128 -n 16 --transA=true --transB=true --perf_config v3:16,16,4,16,16,16,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 16 -k 128 -n 16 --transA=true --transB=true --perf_config v3:16,16,8,16,16,16,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 16 -k 256 -n 16 --transA=true --transB=true --perf_config v3:16,16,8,16,16,32,1,3,2,1,1" # ============================================================================ # Suite 3: ONLY B uses LDS transpose (transA=false, transB=false) -# kpack=1 (single operand constraint), scheduleVersion same as Suite 1 +# A uses regular load with LDS transpose load # ============================================================================ [[suite]] name = "lds_transpose_B_only" @@ -131,7 +140,7 @@ config = "-g 1 -m 32 -k 8 -n 32 --transA=false --transB=false --perf_config v3:3 
config = "-g 1 -m 16 -k 64 -n 16 --transA=false --transB=false --perf_config v3:16,16,32,16,16,1,1,4,2,1,1" [[suite.test]] -config = "-g 1 -m 32 -k 128 -n 32 --transA=false --transB=false --perf_config v3:32,32,32,16,16,1,1,3,2,1,1" +config = "-g 1 -m 32 -k 128 -n 32 --transA=false --transB=false --perf_config v3:32,32,8,16,16,32,1,3,2,1,1" [[suite.test]] config = "-g 1 -m 32 -k 256 -n 32 --transA=false --transB=false --perf_config v3:32,64,16,32,32,1,1,4,2,1,1" @@ -140,22 +149,31 @@ config = "-g 1 -m 32 -k 256 -n 32 --transA=false --transB=false --perf_config v3 config = "-g 256 -m 32 -k 32 -n 32 --transA=false --transB=false --perf_config v3:64,32,32,32,16,1,1,3,2,1,1" [[suite.test]] -config = "-g 256 -m 16 -k 64 -n 16 --transA=false --transB=false --perf_config v3:16,16,32,16,16,1,1,4,2,1,1" +config = "-g 256 -m 16 -k 64 -n 16 --transA=false --transB=false --perf_config v3:32,32,2,32,32,16,1,4,2,1,1" [[suite.test]] -config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v3:64,32,16,32,16,1,1,3,2,1,1" +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v3:64,32,16,32,16,4,1,3,2,1,1" [[suite.test]] -config = "-g 1 -m 64 -k 128 -n 64 --transA=false --transB=false --perf_config v3:32,128,32,16,16,1,1,3,2,1,1" +config = "-g 1 -m 64 -k 128 -n 64 --transA=false --transB=false --perf_config v3:32,128,32,16,16,8,1,3,2,1,1" [[suite.test]] -config = "-g 1 -m 128 -k 64 -n 128 --transA=false --transB=false --perf_config v3:128,128,32,64,16,1,1,4,2,1,1" +config = "-g 1 -m 128 -k 64 -n 128 --transA=false --transB=false --perf_config v3:128,128,4,64,16,4,1,4,2,1,1" [[suite.test]] config = "-g 1 -m 144 -k 1280 -n 3840 --transA=false --transB=false --perf_config v3:32,32,8,32,32,1,1,3,2,1,1" [[suite.test]] -config = "-g 1 -m 4096 -k 256 -n 384 --transA=false --transB=false --perf_config v3:64,64,16,32,16,1,1,4,2,1,1" +config = "-g 1 -m 1280 -k 256 -n 384 --transA=false --transB=false --perf_config v3:64,64,16,32,16,1,1,4,2,1,1" + 
+[[suite.test]] +config = "-g 1 -m 16 -k 64 -n 16 --transA=false --transB=false --perf_config v3:64,64,2,32,32,32,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 16 -k 128 -n 16 --transA=false --transB=false --perf_config v3:16,16,4,16,16,32,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 16 -k 128 -n 16 --transA=false --transB=false --perf_config v3:16,16,8,16,16,16,1,3,2,1,1" [[suite.test]] -config = "-g 256 -m 900 -k 1280 -n 3840 --transA=false --transB=false --perf_config v3:128,64,16,64,16,1,1,3,2,1,1" +config = "-g 1 -m 16 -k 256 -n 16 --transA=false --transB=false --perf_config v3:16,16,8,16,16,32,1,3,2,1,1" From 2c4f149ddd97b4bcc4ec59e90eea746d369400b4 Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Fri, 26 Dec 2025 05:41:50 -0600 Subject: [PATCH 04/14] Fix clang format --- .../Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp index 5f238069ac91..1d4c00d8a6d2 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp @@ -85,8 +85,7 @@ class LoweringBlockwiseLoadTileOp final // wrapLDSBufferForLoad is reading a single set of Ks into private memory // A/B[m/n, 0:kBasePerThread] Value ldsViewForLoad = accelEmitterPtr->wrapLDSBufferForLoad( - b, loc, ldsView, matrixParams, blockSize, dName, - useLdsTransposeLoad); + b, loc, ldsView, matrixParams, blockSize, dName, useLdsTransposeLoad); // We enhance the transformation from wrapLDSBufferForLoad using a builder // that, given a single index, splits it into "m"("n") and "k" and lets @@ -455,9 +454,11 @@ class LoweringBlockwiseLoadTileOp final } // Determine if the other operand uses LDS transpose load - // If we're loading A, check if B uses transpose; if loading B, check A - bool useLdsTransposeLoad = 
isA ? matrixParamsB.getLdsTransposeEnabled() - : matrixParamsA.getLdsTransposeEnabled(); + // If we're loading A, check if B uses transpose; if loading B, check + // A + bool useLdsTransposeLoad = + isA ? matrixParamsB.getLdsTransposeEnabled() + : matrixParamsA.getLdsTransposeEnabled(); generateReadLoop(loc, b, accelEmitterPtr, tid, dName, ldsViewForGemm, destRegisters, blockSize, forceUnroll, matrixParams, transposeAttr, useLdsTransposeLoad); From e0b7855f910f5650ff12abb729864d7f9cfa0f63 Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Fri, 26 Dec 2025 07:02:23 -0600 Subject: [PATCH 05/14] Add nightly e2e tests for LDS transpose load (GEMM) --- .../Rock/lowering_load_transpose_lds.mlir | 17 ++ mlir/test/e2e/CMakeLists.txt | 1 + mlir/test/e2e/LdsTransposeLoad.cfg | 3 + mlir/test/e2e/LdsTransposeLoad.toml | 206 ++++++++++++++++++ 4 files changed, 227 insertions(+) create mode 100644 mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir create mode 100644 mlir/test/e2e/LdsTransposeLoad.cfg create mode 100644 mlir/test/e2e/LdsTransposeLoad.toml diff --git a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir new file mode 100644 index 000000000000..29a2f190a06b --- /dev/null +++ b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir @@ -0,0 +1,17 @@ +// RUN: rocmlir-opt --rock-sugar-to-loops %s | FileCheck %s + +// CHECK-LABEL: func @test_load_transpose_fp16 +module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { + func.func @test_load_transpose_fp16(%src: memref<128x256xf16, #gpu.address_space>, %i: index, %j: index) -> vector<4xf16> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<128x256xf16, #gpu.address_space> -> vector<4xf16> + %v = rock.lds_transpose_load %src[%i, %j] : memref<128x256xf16, #gpu.address_space> -> vector<4xf16> + return %v : vector<4xf16> + } + +// CHECK-LABEL: func @test_load_transpose_bf16 + func.func @test_load_transpose_bf16(%src: 
memref<64x128xbf16, #gpu.address_space>, %i: index, %j: index) -> vector<4xbf16> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<64x128xbf16, #gpu.address_space> -> vector<4xbf16> + %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xbf16, #gpu.address_space> -> vector<4xbf16> + return %v : vector<4xbf16> + } +} diff --git a/mlir/test/e2e/CMakeLists.txt b/mlir/test/e2e/CMakeLists.txt index 51065392a450..0af732ca1b3f 100644 --- a/mlir/test/e2e/CMakeLists.txt +++ b/mlir/test/e2e/CMakeLists.txt @@ -98,6 +98,7 @@ if (ROCK_E2E_TEST_ENABLED) GemmElementwiseGemmDirectToLDS ConvElementwiseGemmDirectToLDS AttentionNonPowerOfTwoTileSize + LdsTransposeLoad ) endif() # Create a list for dummy files diff --git a/mlir/test/e2e/LdsTransposeLoad.cfg b/mlir/test/e2e/LdsTransposeLoad.cfg new file mode 100644 index 000000000000..ea089113128c --- /dev/null +++ b/mlir/test/e2e/LdsTransposeLoad.cfg @@ -0,0 +1,3 @@ +if not 'lds_transpose_load' in config.features: + config.unsupported = True + diff --git a/mlir/test/e2e/LdsTransposeLoad.toml b/mlir/test/e2e/LdsTransposeLoad.toml new file mode 100644 index 000000000000..5fd8152a0f1a --- /dev/null +++ b/mlir/test/e2e/LdsTransposeLoad.toml @@ -0,0 +1,206 @@ +directory = "LdsTransposeLoad" +prefix = "rocmlir-gen" +suffix = "--operation gemm --arch %arch %pv %constrained_float_range_random_data %rocmlir_gen_flags | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +[[axis]] +name = "data type" +values = ["f16", "bf16"] +prefix = "-t " + +# ============================================================================ +# Suite 1: BOTH A and B use LDS transpose (transA=true, transB=false) +# Comprehensive coverage of kpack (1,4,8,16,32) and 
kpackPerBlock (2,4,8,16,32) +# ============================================================================ +[[suite]] +name = "lds_transpose_both_operands" + +# kpack=1 with kpackPerBlock >= 16, kpack=4 with kpackPerBlock < 16 +[[suite.test]] +config = "-g 1 -m 1024 -k 1024 -n 1024 --transA=true --transB=false --perf_config v3:64,64,8,32,32,4,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 512 -n 2048 --transA=true --transB=false --perf_config v3:128,64,16,32,32,1,1,4,2,1,1" + +[[suite.test]] +config = "-g 1 -m 512 -k 2048 -n 512 --transA=true --transB=false --perf_config v3:32,32,32,32,32,1,1,3,2,1,1" + +# kpack=4, various kpackPerBlock +[[suite.test]] +config = "-g 1 -m 1024 -k 1024 -n 1024 --transA=true --transB=false --perf_config v3:64,64,8,32,32,4,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 256 -n 2048 --transA=true --transB=false --perf_config v3:128,128,4,32,32,4,1,4,2,1,1" + +[[suite.test]] +config = "-g 1 -m 512 -k 512 -n 512 --transA=true --transB=false --perf_config v3:32,32,16,32,32,4,1,3,2,1,1" + +# kpack=8, various kpackPerBlock +[[suite.test]] +config = "-g 1 -m 1024 -k 512 -n 1024 --transA=true --transB=false --perf_config v3:64,64,4,32,32,8,1,4,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 1024 -n 2048 --transA=true --transB=false --perf_config v3:128,64,8,32,32,8,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 512 -k 256 -n 512 --transA=true --transB=false --perf_config v3:32,32,2,32,32,8,1,3,2,1,1" + +# kpack=16, various kpackPerBlock +[[suite.test]] +config = "-g 1 -m 1024 -k 1024 -n 1024 --transA=true --transB=false --perf_config v3:64,64,4,32,32,16,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 512 -n 2048 --transA=true --transB=false --perf_config v3:128,128,2,32,32,16,1,4,2,1,1" + +[[suite.test]] +config = "-g 1 -m 512 -k 2048 -n 512 --transA=true --transB=false --perf_config v3:32,32,8,16,16,16,1,3,2,1,1" + +# kpack=32, various kpackPerBlock +[[suite.test]] +config = "-g 1 -m 1024 -k 1024 -n 1024 
--transA=true --transB=false --perf_config v3:64,64,2,32,32,32,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 2048 -n 2048 --transA=true --transB=false --perf_config v3:128,64,4,32,32,32,1,4,2,1,1" + +[[suite.test]] +config = "-g 1 -m 512 -k 512 -n 512 --transA=true --transB=false --perf_config v3:32,32,4,16,16,32,1,3,2,1,1" + +# Grouped GEMM +[[suite.test]] +config = "-g 64 -m 1024 -k 256 -n 1024 --transA=true --transB=false --perf_config v3:64,64,8,32,32,4,1,3,2,1,1" + +[[suite.test]] +config = "-g 128 -m 512 -k 512 -n 512 --transA=true --transB=false --perf_config v3:32,32,16,32,32,8,1,4,2,1,1" + +# 8 waves configurations +[[suite.test]] +config = "-g 1 -m 1024 -k 256 -n 512 --transA=true --transB=false --perf_config v4:64,32,8,16,16,16,16,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 256 -k 512 -n 2048 --transA=true --transB=false --perf_config v4:16,128,8,16,16,16,16,1,4,2,0,0,1,1" + +# 16 waves configurations +[[suite.test]] +config = "-g 1 -m 2048 -k 256 -n 512 --transA=true --transB=false --perf_config v4:128,32,8,16,16,16,16,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 256 -k 512 -n 4096 --transA=true --transB=false --perf_config v4:16,256,8,16,16,16,16,1,4,2,0,0,1,1" + +# ============================================================================ +# Suite 2: ONLY A uses LDS transpose (transA=true, transB=true) +# B uses regular load - hybrid scenario +# ============================================================================ +[[suite]] +name = "lds_transpose_A_only_hybrid" + +# kpack=1 +[[suite.test]] +config = "-g 1 -m 1024 -k 512 -n 1024 --transA=true --transB=true --perf_config v3:64,64,16,32,32,1,1,4,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 1024 -n 2048 --transA=true --transB=true --perf_config v3:128,128,8,64,32,4,1,3,2,1,1" + +# kpack=4 +[[suite.test]] +config = "-g 1 -m 1024 -k 256 -n 1024 --transA=true --transB=true --perf_config v3:64,64,4,32,32,4,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 512 -n 
2048 --transA=true --transB=true --perf_config v3:128,64,8,32,16,4,1,4,2,1,1" + +# kpack=8 +[[suite.test]] +config = "-g 1 -m 1024 -k 1024 -n 1024 --transA=true --transB=true --perf_config v3:64,64,2,32,32,8,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 256 -n 2048 --transA=true --transB=true --perf_config v3:128,128,4,64,16,8,1,4,2,1,1" + +# kpack=16 +[[suite.test]] +config = "-g 1 -m 1024 -k 512 -n 1024 --transA=true --transB=true --perf_config v3:64,64,2,32,32,16,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 1024 -n 2048 --transA=true --transB=true --perf_config v3:128,64,4,32,32,16,1,4,2,1,1" + +# kpack=32 +[[suite.test]] +config = "-g 1 -m 1024 -k 2048 -n 1024 --transA=true --transB=true --perf_config v3:64,64,4,32,16,32,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 512 -n 2048 --transA=true --transB=true --perf_config v3:128,128,2,64,32,32,1,4,2,1,1" + +# Grouped GEMM +[[suite.test]] +config = "-g 64 -m 512 -k 512 -n 512 --transA=true --transB=true --perf_config v3:32,64,8,32,32,8,1,3,2,1,1" + +# ============================================================================ +# Suite 3: ONLY B uses LDS transpose (transA=false, transB=false) +# A uses regular load - hybrid scenario +# ============================================================================ +[[suite]] +name = "lds_transpose_B_only_hybrid" + +# kpack=4 (kpackPerBlock <= 8 requires kpack >= 4) +[[suite.test]] +config = "-g 1 -m 1024 -k 256 -n 1024 --transA=false --transB=false --perf_config v3:64,64,8,32,16,4,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 512 -n 2048 --transA=false --transB=false --perf_config v3:128,64,4,32,32,4,1,4,2,1,1" + +# kpack=4 +[[suite.test]] +config = "-g 1 -m 1024 -k 512 -n 1024 --transA=false --transB=false --perf_config v3:64,64,8,32,16,4,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 1024 -n 2048 --transA=false --transB=false --perf_config v3:128,128,4,64,32,4,1,4,2,1,1" + +# kpack=8 +[[suite.test]] +config = "-g 1 
-m 1024 -k 256 -n 1024 --transA=false --transB=false --perf_config v3:64,64,2,32,32,8,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 512 -n 2048 --transA=false --transB=false --perf_config v3:128,64,4,32,32,8,1,4,2,1,1" + +# kpack=16 +[[suite.test]] +config = "-g 1 -m 1024 -k 1024 -n 1024 --transA=false --transB=false --perf_config v3:64,64,4,32,16,16,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 256 -n 2048 --transA=false --transB=false --perf_config v3:128,128,2,64,32,16,1,4,2,1,1" + +# kpack=32 +[[suite.test]] +config = "-g 1 -m 1024 -k 512 -n 1024 --transA=false --transB=false --perf_config v3:64,64,2,32,32,32,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 2048 -k 1024 -n 2048 --transA=false --transB=false --perf_config v3:128,64,2,32,32,32,1,4,2,1,1" + +# Grouped GEMM +[[suite.test]] +config = "-g 64 -m 512 -k 256 -n 512 --transA=false --transB=false --perf_config v3:32,32,4,32,16,8,1,3,2,1,1" + +# ============================================================================ +# Suite 4: Large matrices (stress tests) +# ============================================================================ +[[suite]] +name = "lds_transpose_large_matrices" + +[[suite.test]] +config = "-g 1 -m 4096 -k 4096 -n 4096 --transA=true --transB=false --perf_config v3:128,128,8,64,32,8,1,4,2,1,1" + +[[suite.test]] +config = "-g 1 -m 8192 -k 1024 -n 8192 --transA=true --transB=false --perf_config v3:128,128,4,64,64,16,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 1024 -k 8192 -n 1024 --transA=true --transB=false --perf_config v3:64,64,8,32,32,32,1,4,2,1,1" + +[[suite.test]] +config = "-g 1 -m 3072 -k 1024 -n 3072 --transA=true --transB=false --perf_config v3:128,64,8,32,32,16,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 1536 -k 6144 -n 1536 --transA=true --transB=false --perf_config v3:64,64,16,32,32,8,1,4,2,1,1" + +# Large with grouped +[[suite.test]] +config = "-g 32 -m 2048 -k 1024 -n 2048 --transA=true --transB=false --perf_config 
v3:128,128,8,64,32,8,1,3,2,1,1" + +[[suite.test]] +config = "-g 16 -m 4096 -k 512 -n 4096 --transA=true --transB=false --perf_config v3:128,64,4,32,32,16,1,4,2,1,1" From 2819548acd177816949e99c571a1ad8a081449f8 Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Mon, 29 Dec 2025 10:05:33 -0600 Subject: [PATCH 06/14] Address review comments for LDS transpose load - Move kBase variable declaration earlier in wrapLDSBufferForLoad - Remove duplicate MfmaInsnAttr declaration, reuse existing mfmaAttr - Add negative test for LDS transpose load on unsupported arch (gfx942) --- mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp | 3 +-- mlir/test/Dialect/Rock/lds_transpose_error.mlir | 11 +++++++++++ mlir/test/e2e/LdsTransposeLoad.cfg | 1 - 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp index 3e7b6bc964ea..b7dc2b4db910 100644 --- a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp +++ b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp @@ -445,6 +445,7 @@ Value MfmaEmitter::wrapLDSBufferForLoad( MfmaInsnAttr mfmaAttr = mfmaGroup.getInsnAttr(); int64_t inputSpanLen = mfmaAttr.inputSpanLen; int64_t kpackPerThread = accelEmitterParams.kpackPerThread; + int64_t kBase = accelEmitterParams.kBase; bool isKReduction = mfmaAttr.isKReduction; int64_t kIter = kpackPerThread; int64_t kVec = 1; @@ -547,11 +548,9 @@ Value MfmaEmitter::wrapLDSBufferForLoad( // Use LDS transpose compatible K formula only when: // 1. Other operand uses LDS transpose load (hybrid scenario) // 2. kVec >= kBase (enough elements per load to decompose) - int64_t kBase = accelEmitterParams.kBase; if (useLdsTransposeLoad && kVec >= kBase) { // K access pattern must match the transpose load's pattern. 
// For double-rate MFMA, properly distribute K across threads - MfmaInsnAttr mfmaAttr = mfmaGroup.getInsnAttr(); int64_t instrK = mfmaAttr.k; int64_t numBlksInK = instrK / kBase; int64_t numBlksInD = (waveSize / inputSpanLen) / numBlksInK; diff --git a/mlir/test/Dialect/Rock/lds_transpose_error.mlir b/mlir/test/Dialect/Rock/lds_transpose_error.mlir index 700b14023361..efc7cfc42fa4 100644 --- a/mlir/test/Dialect/Rock/lds_transpose_error.mlir +++ b/mlir/test/Dialect/Rock/lds_transpose_error.mlir @@ -75,3 +75,14 @@ func.func @threadwise_read_into_kperblock_not_divisible( } [] (%source) [] -> %dest : memref<128xf16, #gpu.address_space> -> memref<8xf16, #gpu.address_space> return } + +// ----- + +// Error case: LDS transpose load not supported on gfx942 +module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx942"} { + func.func @lds_transpose_load_unsupported_arch(%src: memref<128x256xf16, #gpu.address_space>, %i: index, %j: index) -> vector<4xf16> { + // expected-error @+1 {{LDS transpose load is not supported on this architecture: amdgcn-amd-amdhsa:gfx942}} + %v = rock.lds_transpose_load %src[%i, %j] : memref<128x256xf16, #gpu.address_space> -> vector<4xf16> + return %v : vector<4xf16> + } +} diff --git a/mlir/test/e2e/LdsTransposeLoad.cfg b/mlir/test/e2e/LdsTransposeLoad.cfg index ea089113128c..46909aa10a02 100644 --- a/mlir/test/e2e/LdsTransposeLoad.cfg +++ b/mlir/test/e2e/LdsTransposeLoad.cfg @@ -1,3 +1,2 @@ if not 'lds_transpose_load' in config.features: config.unsupported = True - From f042d38a2926e40879dc40eae0140b29678c45f9 Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Fri, 2 Jan 2026 05:15:18 -0600 Subject: [PATCH 07/14] Add LDS transpose E2E tests and addressed review comment - Add GEMM1 LDS transpose tests (V transpose + P prefetch) to nightly - Create PrLdsTransposeLoadAttention.toml with 14 quick PR tests --- .../mlir/Dialect/Rock/IR/AccelEmitter.h | 13 +- .../BlockwiseLoadTileToThreadwise.cpp | 15 +- .../Transforms/GridwiseGemmToBlockwise.cpp | 23 
++- .../lib/Dialect/Rock/utility/AccelEmitter.cpp | 138 ++++++++++++++---- mlir/test/e2e/CMakeLists.txt | 2 + mlir/test/e2e/LdsTransposeLoadAttention.cfg | 2 + mlir/test/e2e/LdsTransposeLoadAttention.toml | 126 ++++++++++++++++ mlir/test/e2e/PrLdsTransposeLoadAttention.cfg | 2 + .../test/e2e/PrLdsTransposeLoadAttention.toml | 91 ++++++++++++ 9 files changed, 372 insertions(+), 40 deletions(-) create mode 100644 mlir/test/e2e/LdsTransposeLoadAttention.cfg create mode 100644 mlir/test/e2e/LdsTransposeLoadAttention.toml create mode 100644 mlir/test/e2e/PrLdsTransposeLoadAttention.cfg create mode 100644 mlir/test/e2e/PrLdsTransposeLoadAttention.toml diff --git a/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h b/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h index 12b677fbfc15..cf4689c6420e 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h +++ b/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h @@ -116,11 +116,14 @@ struct AccelEmitter { /// 3) threadSubTileView : /// iter --> ... --> [KPerThread, DPerThread] /// for each operand tile to be used with gemm accelerators. + /// When otherOperandUsesLdsTranspose is true, a special K access pattern + /// is used that is compatible with LDS transpose load on the other operand. 
virtual FailureOr createAccelGemmOperandTransforms( OpBuilder &b, Location loc, int64_t kIters, ArrayRef bidGridLengths, int64_t blockSize, int64_t dInCopyPerThread, StringRef dName, bool isKContiguousDim, - bool rotateDWithK, bool doSplitKAcrossThreadsFirst = false) const = 0; + bool rotateDWithK, bool doSplitKAcrossThreadsFirst = false, + bool otherOperandUsesLdsTranspose = false) const = 0; /// Validate the accelerator structure virtual LogicalResult validateAcceleratorProperties() { return success(); }; @@ -197,8 +200,8 @@ struct MfmaEmitter : public AccelEmitter { OpBuilder &b, Location loc, int64_t kIters, ArrayRef bidGridLengths, int64_t blockSize, int64_t dInCopyPerThread, StringRef dName, bool isKContiguousDim, - bool rotateDWithK, - bool doSplitKAcrossThreadsFirst = false) const override; + bool rotateDWithK, bool doSplitKAcrossThreadsFirst = false, + bool otherOperandUsesLdsTranspose = false) const override; FailureOr computeOutputTransforms( OpBuilder &b, Location loc, int64_t mLen, int64_t nLen, int64_t blockSize, @@ -251,8 +254,8 @@ struct WmmaEmitter : public AccelEmitter { OpBuilder &b, Location loc, int64_t kIters, ArrayRef bidGridLengths, int64_t blockSize, int64_t dInCopyPerThread, StringRef dName, bool isKContiguousDim, - bool rotateDWithK, - bool doSplitKAcrossThreadsFirst = false) const override; + bool rotateDWithK, bool doSplitKAcrossThreadsFirst = false, + bool otherOperandUsesLdsTranspose = false) const override; FailureOr computeOutputTransforms( OpBuilder &b, Location loc, int64_t mLen, int64_t nLen, int64_t blockSize, diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp index 1d4c00d8a6d2..cd42eec64552 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp @@ -277,9 +277,14 @@ class LoweringBlockwiseLoadTileOp final FailureOr maybeBufferViews; if 
(loadType == GemmLoadTileType::BypassLDS) { + // Check if the other operand uses LDS transpose load + bool otherOperandUsesLdsTranspose = + isA ? matrixParamsB.getLdsTransposeEnabled() + : matrixParamsA.getLdsTransposeEnabled(); maybeBufferViews = accelEmitterPtr->createAccelGemmOperandTransforms( b, loc, kIters, bidGridLengths, blockSize, vecDimInfo.inDPerThread, - dName, isKContiguousDim, false); + dName, isKContiguousDim, false, + /*doSplitKAcrossThreadsFirst=*/false, otherOperandUsesLdsTranspose); } else { maybeBufferViews = getLoadRegsAsTileViews( b, loc, source, dName, bidGridOrder, bidGridLengths, blockSize, @@ -339,10 +344,16 @@ class LoweringBlockwiseLoadTileOp final subview = createSliceOfFirstDim(b, loc, subview, di); } + // Check if the other operand uses LDS transpose load + bool otherOperandUsesLdsTranspose = + isA ? matrixParamsB.getLdsTransposeEnabled() + : matrixParamsA.getLdsTransposeEnabled(); FailureOr maybeBufferViews = accelEmitterPtr->createAccelGemmOperandTransforms( b, loc, kIters, bidGridLengths, blockSize, - vecDimInfo.inDPerThread, dName, isKContiguousDim, false); + vecDimInfo.inDPerThread, dName, isKContiguousDim, false, + /*doSplitKAcrossThreadsFirst=*/false, + otherOperandUsesLdsTranspose); if (failed(maybeBufferViews)) return failure(); // InBufferViews provide --> K x D subtile views. 
diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 7af0b8509581..c970003a0211 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -2127,8 +2127,8 @@ struct GridwiseAttentionAccelRewritePattern loadTypeQ == GemmLoadTileType::DirectToLDSDoubleBuffer; // Determine if Q loads from LDS (for LDS transpose decision) - // Q bypasses LDS only when prefetch is active - bool qLoadsFromLDS = !prefetchQTile; + // Q bypasses LDS only when loadTypeQ is BypassLDS + bool qLoadsFromLDS = loadTypeQ != GemmLoadTileType::BypassLDS; // Note that kPerBlock for Gemm1B is mPerBlock of Gemm0 out // Note that mPerBlock for Gemm1A is mPerBlock of Gemm0 out @@ -2446,30 +2446,41 @@ struct GridwiseAttentionAccelRewritePattern /*accelKDim=*/ldsDecisionGemm0.mfmaKDim); // LDS Transpose Decision for GEMM1 (V x P) - // Only V (operand A) can use LDS transpose + // Note: LDS transpose for V is ONLY enabled when P is prefetched + // (doBypassLDSSecondGemm = true). 
hwtranspose::LDSTransposeDecision ldsDecisionGemm1 = hwtranspose::decideLDSTransposeForOperands( accelEmitterPtrGemm1.get(), arch, elemTypeV, elemTypeV, directToLDS, ldsLayoutCfgMG1, ldsLayoutCfgMG1, gemm1MPerBlock, gemm1NPerBlock, gemm1KPerBlock, gemm1TuningParams.getMPerWave(), gemm1TuningParams.getNPerWave(), gemm1kpack, - /*doubleBuffering=*/false, /*bLoadsFromLDS=*/false); + /*doubleBuffering=*/false, + /*bLoadsFromLDS=*/!doBypassLDSSecondGemm); + + // Enable LDS transpose for V only when P is prefetched + bool enableLdsTransposeForV = + doBypassLDSSecondGemm && ldsDecisionGemm1.enableA; BlockwiseMatrixParamsAttr matrixParamsV = BlockwiseMatrixParamsAttr::get( rewriter.getContext(), elemTypeV, elemTypeVLoad, ldsLayoutCfgMG1.doRotateWithK, ldsLayoutCfgMG1.doSwapThreadIterSubDims, ldsLayoutCfgMG1.ldsLayoutDxK, directToLDS, doBypassLDSSecondGemm, gemm0G, gemm1M, gemm1InMPerThread, - /*ldsTransposeEnabled=*/ldsDecisionGemm1.enableA, + /*ldsTransposeEnabled=*/enableLdsTransposeForV, /*accelDDim=*/ldsDecisionGemm1.mfmaDDim, /*accelKDim=*/ldsDecisionGemm1.mfmaKDim); + // P matrix (operand B) - when prefetched, uses LDS transpose compatible + // K formula via otherOperandUsesLdsTranspose in + // createAccelGemmOperandTransforms BlockwiseMatrixParamsAttr matrixParamsKxQ = BlockwiseMatrixParamsAttr::get( rewriter.getContext(), elemTypeV, elemTypeVLoad, /*rotateDWithK=*/false, /*swapThreadIterSubDims=*/false, /*LDSLayoutDxK=*/false, /*directToLDS=*/false, /*splitKAcrossThreadsFirst=*/false, gemm0G, gemm1N, gemm1InMPerThread, - /*ldsTransposeEnabled=*/false, /*accelDDim=*/0, /*accelKDim=*/0); + /*ldsTransposeEnabled=*/false, + /*accelDDim=*/ldsDecisionGemm1.mfmaDDim, + /*accelKDim=*/ldsDecisionGemm1.mfmaKDim); // If gemm0K is equal to gemm0KPerBlock that means // effectively there is no K loop. 
Therefore, we diff --git a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp index b7dc2b4db910..e21982727b40 100644 --- a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp +++ b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp @@ -545,10 +545,10 @@ Value MfmaEmitter::wrapLDSBufferForLoad( TopDownTMBuilder toLDSRowCol(b, {}, {}, loc); - // Use LDS transpose compatible K formula only when: - // 1. Other operand uses LDS transpose load (hybrid scenario) - // 2. kVec >= kBase (enough elements per load to decompose) + // Use LDS transpose compatible K formula when this operand uses LDS + // transpose load (and kVec >= kBase to ensure proper K distribution) if (useLdsTransposeLoad && kVec >= kBase) { + // K access pattern must match the transpose load's pattern. // For double-rate MFMA, properly distribute K across threads int64_t instrK = mfmaAttr.k; @@ -595,7 +595,7 @@ Value MfmaEmitter::wrapLDSBufferForLoad( {kIter, numMfmaPerKVec, numBlksInK, kBase}); } else { - // Standard formula for regular load scenarios or when kVec < kBase + // Standard formula for regular load scenarios toLDSRowCol = TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr); // d = blk_td + d_i * waveOffset @@ -660,7 +660,8 @@ MfmaEmitter::createAccelGemmOperandTransforms( OpBuilder &b, Location loc, int64_t kIters, ArrayRef bidGridLengths, int64_t blockSize, int64_t dInCopyPerThread, StringRef dName, bool isKContiguousDim, - bool rotateDWithK, bool doSplitKAcrossThreadsFirst) const { + bool rotateDWithK, bool doSplitKAcrossThreadsFirst, + bool otherOperandUsesLdsTranspose) const { StringRef thisWaveDim = dName == "m" ? "wave_m" : "wave_n"; StringRef otherWaveDim = dName == "m" ? "wave_n" : "wave_m"; StringRef thisBlockDim = dName == "m" ? 
"m_block" : "n_block"; @@ -679,7 +680,9 @@ MfmaEmitter::createAccelGemmOperandTransforms( MfmaInsnAttr mfmaAttr = mfmaGroup.getInsnAttr(); int64_t inputSpanLen = mfmaAttr.inputSpanLen; int64_t kpackPerThread = accelEmitterParams.kpackPerThread; + int64_t kBase = accelEmitterParams.kBase; bool isKReduction = mfmaAttr.isKReduction; + int64_t instrK = mfmaAttr.k; // Extract relevant derived parameters int64_t mWaves = mPerBlock / mPerWave; @@ -757,9 +760,86 @@ MfmaEmitter::createAccelGemmOperandTransforms( TransformMapAttr splitWaveIdAttr = splitWaveId.get(); transformAttrs.push_back(splitWaveIdAttr); // Fourth coordinate transform - TopDownTMBuilder toLDSRowCol = - TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr); - { + // Check if we need LDS transpose compatible K formula + bool useLdsTransposeCompatibleK = + otherOperandUsesLdsTranspose && isKReduction && (kPack >= kBase); + int64_t numBlksInK = instrK / kBase; + int64_t numBlksInD = (waveSize / inputSpanLen) / numBlksInK; + + TransformMapAttr toLDSRowColAttr; + if (useLdsTransposeCompatibleK) { + // LDS transpose compatible path: split blk_id into blk_d and blk_k + // Also split kpack into k_mfma and k_base to match LDS transpose pattern + int64_t numMfmaPerKPack = kPack / kBase; + + // First, add a transform to split blk_id + TopDownTMBuilder splitBlkId = + TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr); + splitBlkId.passThrough({"k_loop", "g_block"}); + splitBlkId.passThrough({thisBlockDim}, {2}, {thisBlockDim}); + splitBlkId.passThrough({"kpack"}, {3}, {"kpack"}); + splitBlkId.passThrough({"wave_m", "wave_n"}, {4, 5}, + {"wave_m", "wave_n"}); + splitBlkId.merge({"blk_d", "blk_k"}, {6, 7}, "blk_id", + {numBlksInD, numBlksInK}); + splitBlkId.passThrough({"blk_td", "d_iter", "k_iter"}, {8, 9, 10}, + {"blk_td", "d_iter", "k_iter"}); + TransformMapAttr splitBlkIdAttr = splitBlkId.get(); + transformAttrs.push_back(splitBlkIdAttr); + + // Split kpack into k_mfma and k_base (similar to 
wrapLDSBufferForLoad) + TopDownTMBuilder splitKpack = + TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr); + splitKpack.passThrough({"k_loop", "g_block"}); + splitKpack.passThrough({thisBlockDim}, {2}, {thisBlockDim}); + splitKpack.merge({"k_mfma", "k_base"}, {3, 4}, "kpack", + {numMfmaPerKPack, kBase}); + splitKpack.passThrough({"wave_m", "wave_n"}, {5, 6}, + {"wave_m", "wave_n"}); + splitKpack.passThrough({"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"}, + {7, 8, 9, 10, 11}, + {"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"}); + TransformMapAttr splitKpackAttr = splitKpack.get(); + transformAttrs.push_back(splitKpackAttr); + + // Then create the coordinate transform + TopDownTMBuilder toLDSRowCol = + TopDownTMBuilder::below(splitKpack, splitKpackAttr); + toLDSRowCol.passThrough({"k_loop", "g_block"}); + toLDSRowCol.passThrough({thisBlockDim}, {2}, {thisBlockDim}); + + // d = d_iter * dWaves * numBlksInD * inputSpanLen + wave_d * numBlksInD * + // inputSpanLen + blk_d * inputSpanLen + blk_td + toLDSRowCol.unmerge("d", 3, {"d_iter", thisWaveDim, "blk_d", "blk_td"}, + {dRepeats, dWaves, numBlksInD, inputSpanLen}); + + // k = k_iter * (numMfmaPerKPack * instrK) + k_mfma * instrK + blk_k * + // kBase + k_base This matches the formula in wrapLDSBufferForLoad + toLDSRowCol.unmerge("k", 4, {"k_iter", "k_mfma", "blk_k", "k_base"}, + {kpackPerThread, numMfmaPerKPack, numBlksInK, kBase}); + + toLDSRowCol.ignore(otherWaveDim); + toLDSRowColAttr = toLDSRowCol.get(); + transformAttrs.push_back(toLDSRowColAttr); + + // Fifth coordinate transform for LDS transpose compatible path + // Note: rotateDWithK should not be used with LDS transpose compatible + { + TopDownTMBuilder offset = + TopDownTMBuilder::below(toLDSRowCol, toLDSRowColAttr); + offset.passThrough({"G"}, {0}, {"g_block"}); + offset.unmerge({"K"}, 1, {"k_loop", "k"}, + {kIters, kPackPerBlock * kPack}); + offset.unmerge("D", 2, {thisBlockDim, "d"}, + {thisDimNumBlocks, dPerBlock}); + TransformMapAttr offsetAttr = 
offset.get(); + transformAttrs.push_back(offsetAttr); + } + ret.gridSubTile = b.getArrayAttr(transformAttrs); + } else { + // Regular path + TopDownTMBuilder toLDSRowCol = + TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr); toLDSRowCol.passThrough({"k_loop", "g_block"}); toLDSRowCol.passThrough({thisBlockDim}, {2}, {thisBlockDim}); toLDSRowCol.passThrough({"kpack"}, {3}, {"kpack"}); @@ -784,25 +864,26 @@ MfmaEmitter::createAccelGemmOperandTransforms( toLDSRowCol.passThrough({"k"}, 5, {"k_iter"}); } toLDSRowCol.ignore(otherWaveDim); + toLDSRowColAttr = toLDSRowCol.get(); + transformAttrs.push_back(toLDSRowColAttr); + + // Fifth coordinate transform + { + int64_t stride = (kPack == 1 ? dInCopyPerThread : 1); + auto offset = rotateIf(rotateDWithK, toLDSRowCol, toLDSRowColAttr, + stride, "d", dPerBlock, 3, "k", kPackPerBlock, + {"k_loop", "g_block", thisBlockDim, "kpack"}, + {"k"}, transformAttrs); + offset.passThrough({"G"}, {0}, {"g_block"}); + offset.unmerge({"K"}, 1, {"k_loop", "k", "kpack"}, + {kIters, kPackPerBlock, kPack}); + offset.unmerge("D", 2, {thisBlockDim, "d"}, + {thisDimNumBlocks, dPerBlock}); + TransformMapAttr offsetAttr = offset.get(); + transformAttrs.push_back(offsetAttr); + } + ret.gridSubTile = b.getArrayAttr(transformAttrs); } - TransformMapAttr toLDSRowColAttr = toLDSRowCol.get(); - transformAttrs.push_back(toLDSRowColAttr); - // Fifth coordinate transform - { - int64_t stride = (kPack == 1 ? 
dInCopyPerThread : 1); - auto offset = rotateIf(rotateDWithK, toLDSRowCol, toLDSRowColAttr, stride, - "d", dPerBlock, 3, "k", kPackPerBlock, - {"k_loop", "g_block", thisBlockDim, "kpack"}, - {"k"}, transformAttrs); - offset.passThrough({"G"}, {0}, {"g_block"}); - offset.unmerge({"K"}, 1, {"k_loop", "k", "kpack"}, - {kIters, kPackPerBlock, kPack}); - offset.unmerge("D", 2, {thisBlockDim, "d"}, - {thisDimNumBlocks, dPerBlock}); - TransformMapAttr offsetAttr = offset.get(); - transformAttrs.push_back(offsetAttr); - } - ret.gridSubTile = b.getArrayAttr(transformAttrs); } // compute block sub tile transforms { @@ -1011,7 +1092,10 @@ WmmaEmitter::createAccelGemmOperandTransforms( OpBuilder &b, Location loc, int64_t kIters, ArrayRef bidGridLengths, int64_t blockSize, int64_t dInCopyPerThread, StringRef dName, bool isKContiguousDim, - bool rotateDWithK, bool doSplitKAcrossThreadsFirst) const { + bool rotateDWithK, bool doSplitKAcrossThreadsFirst, + bool otherOperandUsesLdsTranspose) const { + // Note: WMMA does not support LDS transpose load, so the parameter is unused + (void)otherOperandUsesLdsTranspose; StringRef thisWaveDim = dName == "m" ? "wave_m" : "wave_n"; StringRef otherWaveDim = dName == "m" ? "wave_n" : "wave_m"; StringRef thisBlockDim = dName == "m" ? 
"m_block" : "n_block"; diff --git a/mlir/test/e2e/CMakeLists.txt b/mlir/test/e2e/CMakeLists.txt index 0af732ca1b3f..317918b9cc68 100644 --- a/mlir/test/e2e/CMakeLists.txt +++ b/mlir/test/e2e/CMakeLists.txt @@ -50,6 +50,7 @@ if (ROCMLIR_DRIVER_PR_E2E_TEST_ENABLED) PrConvElementwiseGemmBF16SplitK PrGemmDirectToLDS PrLdsTransposeLoad + PrLdsTransposeLoadAttention PrConvDirectToLDS PrAttentionDirectToLDS ) @@ -99,6 +100,7 @@ if (ROCK_E2E_TEST_ENABLED) ConvElementwiseGemmDirectToLDS AttentionNonPowerOfTwoTileSize LdsTransposeLoad + LdsTransposeLoadAttention ) endif() # Create a list for dummy files diff --git a/mlir/test/e2e/LdsTransposeLoadAttention.cfg b/mlir/test/e2e/LdsTransposeLoadAttention.cfg new file mode 100644 index 000000000000..46909aa10a02 --- /dev/null +++ b/mlir/test/e2e/LdsTransposeLoadAttention.cfg @@ -0,0 +1,2 @@ +if not 'lds_transpose_load' in config.features: + config.unsupported = True diff --git a/mlir/test/e2e/LdsTransposeLoadAttention.toml b/mlir/test/e2e/LdsTransposeLoadAttention.toml new file mode 100644 index 000000000000..181eaa5d8d1a --- /dev/null +++ b/mlir/test/e2e/LdsTransposeLoadAttention.toml @@ -0,0 +1,126 @@ +directory = "LdsTransposeLoadAttention" +prefix = "rocmlir-gen" +suffix = "--operation attention --arch %arch %pv %constrained_float_range_random_data %rocmlir_gen_flags -relDiff_threshold 0.3 -absDiff_threshold 0.3 -RMS_threshold 0.15 | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +[[axis]] +name = "data type" +values = ["f16", "bf16"] +prefix = "-t " + +# ============================================================================ +# LDS Transpose Load for Attention +# All tests use schedule_version=3 or 4 (direct-to-LDS) +# +# Prefetch: 
head_dim_qk <= kpackPerBlock * kpack +# No prefetch: head_dim_qk > kpackPerBlock * kpack +# ============================================================================ + +# ============================================================================ +# Suite 1: transK=false, transQ=true - LDS transpose on BOTH K and Q +# Q loads from LDS (no prefetch), both can use LDS transpose +# ============================================================================ +[[suite]] +name = "lds_transpose_both_k_and_q" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 256 -head_dim_v 128 --transK=false --transQ=true -perf_config attn:v2:32,32,32,32,32,32,4,1,3,2,1" + +[[suite.test]] +config = "-seq_len_q 256 -seq_len_k 256 -head_dim_qk 512 -head_dim_v 128 --transK=false --transQ=true -perf_config attn:v2:64,64,32,32,32,32,4,1,4,2,1" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 256 -head_dim_v 128 --transK=false --transQ=true -perf_config attn:v2:32,32,16,16,16,16,8,1,3,2,1" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 512 -head_dim_v 128 --transK=false --transQ=true -perf_config attn:v2:32,32,32,16,16,16,8,1,4,2,1" + +# ============================================================================ +# Suite 2: transK=false, transQ=true - Q prefetched to registers +# K uses LDS transpose, Q bypasses LDS (prefetch) +# ============================================================================ +[[suite]] +name = "lds_transpose_k_only_q_prefetch" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 64 -head_dim_v 128 --transK=false --transQ=true -perf_config attn:v2:32,32,32,32,32,32,8,1,3,2,1" + +[[suite.test]] +config = "-seq_len_q 256 -seq_len_k 256 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=true -perf_config attn:v2:64,64,16,16,16,16,8,1,4,2,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=true -perf_config 
attn:v2:16,16,16,16,16,16,4,1,3,2,1" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=true -perf_config attn:v2:16,16,32,16,16,16,4,1,4,2,1" + +# ============================================================================ +# Suite 3: transK=false, transQ=false - Only K uses LDS transpose +# Q not usable for LDS transpose, no prefetch +# ============================================================================ +[[suite]] +name = "lds_transpose_k_only_no_prefetch" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 256 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:32,32,32,32,32,32,8,1,3,2,1" + +[[suite.test]] +config = "-seq_len_q 256 -seq_len_k 256 -head_dim_qk 256 -head_dim_v 256 --transK=false --transQ=false -perf_config attn:v2:64,64,16,16,16,16,16,1,4,2,1" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 512 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v2:16,16,16,16,16,16,4,1,3,2,1" + +# ============================================================================ +# Suite 4: transK=false, transQ=false - Only K uses LDS transpose, Q prefetch +# K uses LDS transpose, Q bypasses LDS (prefetch) +# ============================================================================ +[[suite]] +name = "lds_transpose_k_only_q_prefetch_hybrid" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:32,32,32,32,32,32,8,1,3,2,1" + +[[suite.test]] +config = "-seq_len_q 256 -seq_len_k 256 -head_dim_qk 256 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:64,64,32,32,32,32,8,1,4,2,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v2:32,32,16,16,16,16,4,1,3,2,1" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 
-head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:16,16,32,16,16,16,4,1,4,2,1" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:32,32,32,32,32,32,4,1,3,2,1" + +# ============================================================================ +# Suite 5: transK=true, transQ=true - Only Q can use LDS transpose +# K not usable (transposed), Q usable +# ============================================================================ +[[suite]] +name = "lds_transpose_q_only" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 256 -head_dim_v 128 --transK=true --transQ=true -perf_config attn:v2:32,32,32,32,32,32,8,1,3,2,1" + +[[suite.test]] +config = "-seq_len_q 256 -seq_len_k 256 -head_dim_qk 256 -head_dim_v 256 --transK=true --transQ=true -perf_config attn:v2:64,64,16,16,16,16,16,1,4,2,1" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 256 -head_dim_v 64 --transK=true --transQ=true -perf_config attn:v2:32,32,16,16,16,16,8,1,3,2,1" + +# ============================================================================ +# Suite 6: GEMM1 LDS Transpose - V uses LDS transpose, P is prefetched +# ============================================================================ +[[suite]] +name = "lds_transpose_gemm1_v_with_p_prefetch" + +# 32x8 MFMA, medium dimensions +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v3:32,32,32,32,32,32,32,4,1,3,2,0,1" + +# 32x8 MFMA, larger dimensions +[[suite.test]] +config = "-seq_len_q 256 -seq_len_k 256 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v3:32,32,32,32,32,32,32,4,1,3,2,0,1" + +# 16x16 MFMA, smaller dimensions +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config 
attn:v3:16,16,16,16,16,16,16,4,1,3,2,0,1" diff --git a/mlir/test/e2e/PrLdsTransposeLoadAttention.cfg b/mlir/test/e2e/PrLdsTransposeLoadAttention.cfg new file mode 100644 index 000000000000..46909aa10a02 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadAttention.cfg @@ -0,0 +1,2 @@ +if not 'lds_transpose_load' in config.features: + config.unsupported = True diff --git a/mlir/test/e2e/PrLdsTransposeLoadAttention.toml b/mlir/test/e2e/PrLdsTransposeLoadAttention.toml new file mode 100644 index 000000000000..eb9bf30e4e17 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadAttention.toml @@ -0,0 +1,91 @@ +directory = "PrLdsTransposeLoadAttention" +prefix = "rocmlir-gen" +suffix = "--operation attention --arch %arch -pv %constrained_float_range_random_data %rocmlir_gen_flags -relDiff_threshold 0.3 -absDiff_threshold 0.3 -RMS_threshold 0.15 | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +[[axis]] +name = "data type" +values = ["f16", "bf16"] +prefix = "-t " + +# ============================================================================ +# LDS Transpose Load for Attention - PR Tests (quick validation) +# ============================================================================ + +# ============================================================================ +# Suite 1: Both K and Q use LDS transpose (transK=false, transQ=true) +# ============================================================================ +[[suite]] +name = "lds_transpose_both" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 128 -head_dim_v 64 --transK=false --transQ=true -perf_config attn:v2:32,32,32,32,32,32,4,1,3,2,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 
128 -head_dim_v 64 --transK=false --transQ=true -perf_config attn:v2:16,16,16,16,16,16,4,1,3,2,1" + +# ============================================================================ +# Suite 2: K uses LDS transpose, Q prefetch (transK=false, transQ=true) +# ============================================================================ +[[suite]] +name = "lds_transpose_k_q_prefetch" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=true -perf_config attn:v2:32,32,32,32,32,32,8,1,3,2,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=true -perf_config attn:v2:16,16,16,16,16,16,4,1,3,2,1" + +# ============================================================================ +# Suite 3: Only K uses LDS transpose (transK=false, transQ=false) +# ============================================================================ +[[suite]] +name = "lds_transpose_k_only" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 128 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v2:32,32,32,32,32,32,8,1,3,2,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 128 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v2:16,16,16,16,16,16,4,1,3,2,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v2:32,32,32,32,32,32,4,1,3,2,1" + +# ============================================================================ +# Suite 4: Only Q uses LDS transpose (transK=true, transQ=true) +# ============================================================================ +[[suite]] +name = "lds_transpose_q_only" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 128 -head_dim_v 64 --transK=true --transQ=true -perf_config attn:v2:32,32,32,32,32,32,8,1,3,2,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 
128 -head_dim_v 64 --transK=true --transQ=true -perf_config attn:v2:16,16,16,16,16,16,4,1,3,2,1" + +# ============================================================================ +# Suite 5: GEMM1 V uses LDS transpose, P prefetch (attn:v3) +# ============================================================================ +[[suite]] +name = "lds_transpose_gemm1_v" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v3:32,32,32,32,32,32,32,4,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v3:16,16,16,16,16,16,16,4,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v3:32,32,32,32,32,32,32,4,1,3,2,0,1" + +# ============================================================================ +# Suite 6: Mixed scenarios with different MFMA sizes +# ============================================================================ +[[suite]] +name = "lds_transpose_mixed" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 128 --transK=false --transQ=true -perf_config attn:v2:32,32,16,16,16,16,4,1,3,2,1" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v2:32,32,32,32,32,32,8,1,3,2,1" + From 9027654fe6e899736934f082b3b2586858a6e519 Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Tue, 20 Jan 2026 05:49:40 -0600 Subject: [PATCH 08/14] Add FP8/BF8 support for LDS transpose load Implement ds_read_tr8_b64 offset formulas for FP8/BF8 MFMA (16x32, 32x16). Enable mixed fp8/bf8 type combinations for GEMM operations on gfx950. 
--- mlir/include/mlir/Dialect/Rock/IR/RockOps.td | 15 +- mlir/lib/Dialect/Rock/IR/RockDialect.cpp | 21 ++ .../Dialect/Rock/utility/LdsTransposeLoad.cpp | 280 ++++++++++++------ .../Rock/lowering_load_transpose_lds.mlir | 14 + mlir/test/Dialect/Rock/ops.mlir | 21 ++ 5 files changed, 261 insertions(+), 90 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index 32a8d458c096..3147f65f8339 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -1197,9 +1197,10 @@ defvar SameShapeVectorOfI1 = [{ def Rock_LDSTransposeLoadOp : Rock_Op<"lds_transpose_load", [DeclareOpInterfaceMethods< MemoryEffectsOpInterface>]>, - Arguments<(ins Arg, "LDS source buffer">:$source, + Arguments<(ins Arg, + "LDS source buffer">:$source, Variadic:$indices)>, - Results<(outs VectorOfLengthAndType<[4], [F16, BF16]>:$result)> { + Results<(outs AnyVectorOfNonZeroRank:$result)> { let summary = "Hardware-assisted LDS transpose load for matrix accelerator tile"; let description = [{ @@ -1209,8 +1210,14 @@ def Rock_LDSTransposeLoadOp instruction geometry (`dDim × instrK`), where: - `dDim` is the accelerator M/N dimension (e.g., 16 or 32) - `instrK` is the accelerator K dimension (e.g., 8, 16, or 32) - The operation returns a vector of 4 elements per thread containing - transposed elements in a layout suitable for matrix accelerator instructions. 
+ + For 16-bit types (f16, bf16): + - Uses ds_read_tr16_b64 instruction + - Returns vector<4xtype> (4 elements per thread) + + For 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): + - Uses ds_read_tr8_b64 instruction + - Returns vector<8xtype> (8 elements per thread) Benefits: - Reduces LDS bank conflicts through optimized access patterns diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 7851e4982877..48c0192e18b3 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -2123,6 +2123,27 @@ LogicalResult LDSTransposeLoadOp::verify() { << srcElemType << ")"; } + // Verify result vector length based on element type: + // - 16-bit types (f16, bf16): ds_read_tr16_b64 returns 4 elements + // - 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): ds_read_tr8_b64 + // returns 8 elements + int64_t expectedVecLen; + if (srcElemType.isF16() || srcElemType.isBF16()) { + expectedVecLen = 4; + } else if (isa(srcElemType) || + isa(srcElemType)) { + expectedVecLen = 8; + } else { + return emitOpError("unsupported element type for LDS transpose load: ") + << srcElemType; + } + + if (resultType.getNumElements() != expectedVecLen) { + return emitOpError("expected result vector of ") + << expectedVecLen << " elements for " << srcElemType + << " type, but got " << resultType.getNumElements(); + } + // Check hardware support using AmdArchDb StringAttr archAttr = rock::getArch(*this).value_or(StringAttr::get(getContext(), "gfx00")); diff --git a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp index d4f021dc2715..01e426c55db7 100644 --- a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp +++ b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp @@ -21,8 +21,13 @@ // to the LDS transpose load operation in an accelerator-friendly layout. 
// // It is intended to simplify the IR generation logic and ensure -// consistent handling of f16/bf16 matrix accelerator tile loads from LDS -// memory. +// consistent handling of f16/bf16/fp8/bf8 matrix accelerator tile loads from +// LDS memory. +// +// Supported element types: +// - f16, bf16: uses ds_read_tr16_b64 (returns 4 elements per thread) +// - f8E4M3FN, f8E5M2 (OCP FP8): uses ds_read_tr8_b64 (returns 8 elements per +// thread) // //===----------------------------------------------------------------------===// @@ -47,6 +52,32 @@ namespace { bool archSupported(StringRef arch) { return arch.contains("gfx950"); } +// Check if element type is supported for LDS transpose load +// - f16, bf16: ds_read_tr16_b64 (4 elements) +// - f8E4M3FN, f8E5M2 (OCP FP8 for gfx950): ds_read_tr8_b64 (8 elements) +static bool isSupportedElementType(Type t) { + return t.isF16() || t.isBF16() || isa(t) || + isa(t); +} + +// Check if element type is 8-bit float (FP8 E4M3 or BF8 E5M2) +// Used for: +// 1. Selecting ds_read_tr8_b64 vs ds_read_tr16_b64 +// 2. Checking mixed-type compatibility (fp8+bf8 combinations are valid) +static bool isFp8Type(Type t) { + return isa(t) || isa(t); +} + +// Returns the number of elements returned by LDS transpose load instruction +static int64_t getTransposeLoadVectorLength(Type elemType) { + if (elemType.isF16() || elemType.isBF16()) { + return 4; // ds_read_tr16_b64 + } else if (isFp8Type(elemType)) { + return 8; // ds_read_tr8_b64 + } + llvm_unreachable("Unsupported element type for LDS transpose load"); +} + // Validates MFMA geometry for LDS transpose support. 
// Only specific combinations are supported: (16,16), (16,32), (32,8), (32,16) static bool isValidMfmaGeometry(int64_t dDim, int64_t kDim) { @@ -119,7 +150,16 @@ static Decision makeDecision(StringRef arch, Type elemTypeA, Type elemTypeB, return dec; } - if (elemTypeA != elemTypeB || !(elemTypeA.isF16() || elemTypeA.isBF16())) + // Check type compatibility: + // - Same types are always allowed (f16==f16, fp8==fp8, etc.) + // - Mixed FP8/BF8 combinations are allowed (hardware supports mixed fp8/bf8 + // MFMA: + // mfma_f32_16x16x32_fp8_fp8, mfma_f32_16x16x32_fp8_bf8, etc.) + // - Other mixed types are NOT allowed (e.g., f16 with fp8) + bool typesCompatible = (elemTypeA == elemTypeB) || + (isFp8Type(elemTypeA) && isFp8Type(elemTypeB)); + if (!typesCompatible || !isSupportedElementType(elemTypeA) || + !isSupportedElementType(elemTypeB)) return dec; // Validate MFMA geometry @@ -477,8 +517,9 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, // emitPanelLoad - Emit an LDS transpose load operation // // Computes the final LDS offset and emits a hardware LDS transpose load -// instruction (ds_read_tr16_b64). This instruction always returns vector<4xf16> -// regardless of the layout. +// instruction: +// - ds_read_tr16_b64 for f16/bf16: returns vector<4> +// - ds_read_tr8_b64 for fp8/bf8: returns vector<8> // // The final offset is computed as: final_offset = k_base * ldsStride + m_base // where ldsStride depends on the operand (mPerBlock for A, nPerBlock for B). 
@@ -491,10 +532,10 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, // kBase - K dimension base offset for this panel // mBase - M/N dimension base offset for this panel // ldsStride - Stride between K rows in LDS (mPerBlock or nPerBlock) -// panelVecType - Result type (always vector<4xf16> or vector<4xbf16>) +// panelVecType - Result type (vector<4> for f16/bf16, vector<8> for fp8/bf8) // // Returns: -// The loaded panel vector (vector<4xf16/bf16>) +// The loaded panel vector (vector<4> for f16/bf16, vector<8> for fp8/bf8) //===----------------------------------------------------------------------===// static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, Value kBase, Value mBase, Value ldsStride, @@ -503,7 +544,7 @@ static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, Value kOffset = arith::MulIOp::create(b, loc, kBase, ldsStride); Value finalOffset = arith::AddIOp::create(b, loc, mBase, kOffset); - // Emit hardware LDS transpose load: ds_read_tr16_b64 + // Emit hardware LDS transpose load auto loadOp = rock::LDSTransposeLoadOp::create(b, loc, panelVecType, rawSrc, ValueRange{finalOffset}); @@ -518,9 +559,9 @@ static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, // elements (ds_read_tr16_b64 always returns vector<4xf16>). 
// // Parameters: -// panelVectors - Array of loaded panel vectors (each is vector<4xf16>) -// dest - Destination memref (rank-1, scalar layout) -// targetElems - Maximum number of elements to write +// panelVectors - Panel vectors (vector<4> for f16/bf16, vector<8> for fp8/bf8) +// dest - Destination memref (rank-1, scalar layout) +// targetElems - Maximum number of elements to write // // Returns: // success() if all target elements were written @@ -534,14 +575,16 @@ writePanelVectorsToDestination(PatternRewriter &b, Location loc, int64_t produced = 0; // Extract elements per vector from the actual vector type - // Hardware instruction ds_read_tr16_b64 always returns vector<4xf16> + // Hardware instructions: + // - ds_read_tr16_b64 returns vector<4xf16/bf16> + // - ds_read_tr8_b64 returns vector<8xfp8> assert(!panelVectors.empty() && "Panel vectors array must not be empty"); auto panelVecType = cast(panelVectors[0].getType()); int64_t elementsPerVector = panelVecType.getShape()[0]; - // Verify hardware constraint: ds_read_tr16_b64 returns exactly 4 elements - assert(elementsPerVector == 4 && - "LDS transpose load must produce vector<4xf16> per panel"); + // Verify hardware constraint: 4 elements for 16-bit, 8 elements for 8-bit + assert((elementsPerVector == 4 || elementsPerVector == 8) && + "LDS transpose load must produce vector<4> or vector<8> per panel"); LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Writing " << panelVectors.size() << " panel vectors (" @@ -597,60 +640,114 @@ //===----------------------------------------------------------------------===// static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, int64_t dDim, int64_t kDim, - Value lane) { - Value c16 = arith::ConstantIndexOp::create(b, loc, 16); - Value c4 = arith::ConstantIndexOp::create(b, loc, 4); + Value lane, Type elemType) { + // Common constants used by both FP8 and F16/BF16 Value c2 = 
arith::ConstantIndexOp::create(b, loc, 2); + Value c4 = arith::ConstantIndexOp::create(b, loc, 4); + Value c16 = arith::ConstantIndexOp::create(b, loc, 16); - Value blockId = arith::DivUIOp::create(b, loc, lane, c16); - Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); - - // Base offset calculations - Value mOffsetBase = arith::MulIOp::create( - b, loc, arith::RemUIOp::create(b, loc, laneInBlock, c4), c4); - Value kOffsetBase = arith::DivUIOp::create(b, loc, laneInBlock, c4); + Value kOffsetBase, mOffsetBase; + + if (isFp8Type(elemType)) { + Value c8 = arith::ConstantIndexOp::create(b, loc, 8); + + if (dDim == 16 && kDim == 32) { + // FP8/BF8 16x32: 4-block formula + // Block layout: + // Block 0 (T0-T15): K=0..7 + // Block 1 (T16-T31): K=8..15 + // Block 2 (T32-T47): K=16..23 + // Block 3 (T48-T63): K=24..31 + + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); + Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); + + // kOffsetBase = k_local + block_id * 8 + Value blockKOffset = arith::MulIOp::create(b, loc, blockId, c8); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, blockKOffset); + + // mOffsetBase = m_parity * 8 + mOffsetBase = arith::MulIOp::create(b, loc, mParity, c8); + + } else if (dDim == 32 && kDim == 16) { + // FP8 32x16: 4-block formula (VERIFIED from HIP testing) + // Block layout: + // Block 0 (T0-T15): M=0..15, K=0..7 + // Block 1 (T16-T31): M=16..31, K=0..7 + // Block 2 (T32-T47): M=0..15, K=8..15 + // Block 3 (T48-T63): M=16..31, K=8..15 + + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); + Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); + + // m_block = block_id % 2, k_block = block_id / 2 + Value mBlock = 
arith::RemUIOp::create(b, loc, blockId, c2); + Value kBlock = arith::DivUIOp::create(b, loc, blockId, c2); + + // kOffsetBase = k_local + k_block * 8 + Value kBlockOffset = arith::MulIOp::create(b, loc, kBlock, c8); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, kBlockOffset); + + // mOffsetBase = m_parity * 8 + m_block * 16 + Value mParityOffset = arith::MulIOp::create(b, loc, mParity, c8); + Value mBlockOffset = arith::MulIOp::create(b, loc, mBlock, c16); + mOffsetBase = arith::AddIOp::create(b, loc, mParityOffset, mBlockOffset); - SmallVector panelOffsets; + } else { + llvm_unreachable("Unsupported FP8 MFMA geometry in getBasePanelOffsets"); + } - if (dDim == 16 && kDim == 32) { - // 16x32 layout - panelOffsets = {kOffsetBase, mOffsetBase}; - } else if (dDim == 16 && kDim == 16) { - // 16x16 layout - // kbase = kOffsetBase + (blockId * 4) - Value kBase = arith::AddIOp::create( - b, loc, arith::MulIOp::create(b, loc, blockId, c4), kOffsetBase); - panelOffsets = {kBase, mOffsetBase}; - } else if (dDim == 32 && kDim == 16) { - // 32x16 layout - // mbase = mOffsetBase + (blockId % 2) * 16 - Value mBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::RemUIOp::create(b, loc, blockId, c2), c16), - mOffsetBase); - panelOffsets = {kOffsetBase, mBase}; - } else if (dDim == 32 && kDim == 8) { - // 32x8 layout - // k_base_local = kOffsetBase + (blockId / 2) * 4 - Value kBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::DivUIOp::create(b, loc, blockId, c2), c4), - kOffsetBase); - - // m_offset_base = mOffsetBase + (blockId % 2) * 16 - Value mBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::RemUIOp::create(b, loc, blockId, c2), c16), - mOffsetBase); - panelOffsets = {kBase, mBase}; } else { - llvm_unreachable("Unsupported MFMA geometry in getBasePanelOffsets"); + // F16/BF16 uses block-based lane mapping + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value 
laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + + // Base calculations common to all F16/BF16 geometries + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c4); + Value mLocal = arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, laneInBlock, c4), c4); + + if (dDim == 16 && kDim == 32) { + // 16x32: direct mapping + kOffsetBase = kLocal; + mOffsetBase = mLocal; + + } else if (dDim == 16 && kDim == 16) { + // 16x16: k += blockId * 4 + kOffsetBase = arith::AddIOp::create( + b, loc, kLocal, arith::MulIOp::create(b, loc, blockId, c4)); + mOffsetBase = mLocal; + + } else if (dDim == 32 && kDim == 16) { + // 32x16: m += (blockId % 2) * 16 + kOffsetBase = kLocal; + mOffsetBase = arith::AddIOp::create( + b, loc, mLocal, + arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, blockId, c2), c16)); + + } else if (dDim == 32 && kDim == 8) { + // 32x8: k += (blockId / 2) * 4, m += (blockId % 2) * 16 + kOffsetBase = arith::AddIOp::create( + b, loc, kLocal, + arith::MulIOp::create( + b, loc, arith::DivUIOp::create(b, loc, blockId, c2), c4)); + mOffsetBase = arith::AddIOp::create( + b, loc, mLocal, + arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, blockId, c2), c16)); + + } else { + llvm_unreachable( + "Unsupported F16/BF16 MFMA geometry in getBasePanelOffsets"); + } } - return panelOffsets; + return {kOffsetBase, mOffsetBase}; } //===----------------------------------------------------------------------===// @@ -671,6 +768,7 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, // dDim - MFMA D dimension (M or N, 16 or 32) // kDim - MFMA K dimension (8, 16, or 32) // lane - Thread's lane ID within the workgroup +// elemType - Element type (f16, bf16, fp8, or bf8) for selecting lane mapping // // Returns: // std::pair: @@ -679,8 +777,10 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, //===----------------------------------------------------------------------===// static 
std::pair computeLDSBaseOffsets(PatternRewriter &b, Location loc, int64_t dDim, - int64_t kDim, Value lane) { - SmallVector offsets = getBasePanelOffsets(b, loc, dDim, kDim, lane); + int64_t kDim, Value lane, + Type elemType) { + SmallVector offsets = + getBasePanelOffsets(b, loc, dDim, kDim, lane, elemType); LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Computed LDS base offsets for " << dDim << "x" << kDim << ": " @@ -1086,8 +1186,9 @@ static Value computeFinalMNOffset(PatternRewriter &b, Location loc, // emitThreadwiseHWTranspose - Lower threadwise_read_into to HW transpose loads //===----------------------------------------------------------------------===// // Lowers threadwise_read_into with LDS transpose config into hardware transpose -// load instructions (ds_read_tr16_b64) that read from LDS in MFMA-friendly -// order. +// load instructions that read from LDS in MFMA-friendly order: +// - ds_read_tr16_b64 for f16/bf16 (returns vector<4>) +// - ds_read_tr8_b64 for fp8/bf8 (returns vector<8>) // // Algorithm: // 1. Extract config: MFMA geometry (dDim, kDim), tiling params, operand kind @@ -1100,8 +1201,8 @@ static Value computeFinalMNOffset(PatternRewriter &b, Location loc, // - Outer: M/N tiles (all at once for double-buffering, one at a time // otherwise) // - Inner: K tiles (1 load for single-rate, 2 loads for double-rate layouts) -// 6. For each iteration: compute final LDS offset, emit ds_read_tr16_b64 -// instruction (returns vector<4xf16>) +// 6. For each iteration: compute final LDS offset, emit LDS transpose load +// instruction (ds_read_tr16_b64 for f16/bf16, ds_read_tr8_b64 for fp8/bf8) // 7. Extract elements from panel vectors and write sequentially to destination // // Example: 16x32 layout, 1 M-tile, 2 K-tiles → 2 ds_read_tr16_b64 calls → 8 @@ -1155,23 +1256,27 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, int64_t ldsStride = (operand == OperandKind::A) ? 
mPerBlock : nPerBlock; // Determine if this is a double-rate instruction - // Double-rate ONLY for (32,16) and (16,32) MFMA - // (16,16) and (32,8) are SINGLE-RATE - bool isDoubleRate = - (dDim == 32 && instrK == 16) || (dDim == 16 && instrK == 32); - - // Each ds_read_tr16_b64 call ALWAYS returns vector<4xf16> - // For double-rate, we make 2 calls and store all 8 elements separately - VectorType panelVecType = VectorType::get({4}, elemType); + // Double-rate ONLY for (32,16) and (16,32) MFMA with F16/BF16 + // FP8/BF8 uses ds_read_tr8_b64 which returns 8 elements, so (16,32) and + // (32,16) are SINGLE-RATE for FP8/BF8 (16,16) and (32,8) are always + // SINGLE-RATE + bool isDoubleRate = !isFp8Type(elemType) && ((dDim == 32 && instrK == 16) || + (dDim == 16 && instrK == 32)); + + // Determine vector length based on element type: + // - f16/bf16: ds_read_tr16_b64 returns vector<4> + // - fp8/bf8: ds_read_tr8_b64 returns vector<8> + int64_t vecLen = getTransposeLoadVectorLength(elemType); + VectorType panelVecType = VectorType::get({vecLen}, elemType); // panelVectors will contain: - // - Single-rate: 1 vector<4xf16> per K tile - // - Double-rate: 2 vector<4xf16> per K tile (low + high) + // - Single-rate: 1 vector per K tile + // - Double-rate (f16/bf16 only): 2 vectors per K tile (low + high) SmallVector panelVectors; // Get base offsets using computeLDSBaseOffsets helper auto [k_base_local, m_offset_base] = - computeLDSBaseOffsets(b, loc, dDim, instrK, lane); + computeLDSBaseOffsets(b, loc, dDim, instrK, lane, elemType); // K stride per tile: instrK (MFMA K dimension) int64_t kTileStride = instrK; @@ -1226,20 +1331,21 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, useDynamicMnIndex, waveOffsetStrideVal, tileOffsetStrideVal); if (!isDoubleRate) { - // SINGLE-RATE (L32x8, L16x16): One load per K tile + // SINGLE-RATE (L32x8, L16x16, or FP8/BF8): One load per K tile Value k_base = computePanelFinalOffset(b, loc, isDoubleRate, k_base_local, 
kOffsetBase, kIdx, kTileStrideVal); - // Emit LDS transpose load for this K tile (single-rate: one per K tile) + // Emit LDS transpose load for this K tile Value panelVec = emitPanelLoad(b, loc, rawSrc, k_base, m_base, ldsStrideVal, panelVecType); panelVectors.push_back(panelVec); } else { - // DOUBLE-RATE (L32x16, L16x32): TWO loads per K tile - // Each load returns vector<4xf16>, total 8 elements per K tile - // Compute K offsets for low and high halves + // DOUBLE-RATE (L32x16, L16x32 with F16/BF16 only): TWO loads per K tile + // Each load returns vector<4> for f16/bf16, total 8 elements per K tile + // Note: FP8/BF8 is NEVER double-rate (ds_read_tr8_b64 returns 8 elements) + // Compute K offsets for low and high halves Value k_base_low = computePanelFinalOffset( b, loc, isDoubleRate, k_base_local, kOffsetBase, kIdx, kTileStrideVal, /*isHighHalf=*/false); @@ -1269,8 +1375,10 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, int64_t loadsPerKTile = isDoubleRate ? 2 : 1; int64_t expectedLoads = actualMnTiles * kPanels * loadsPerKTile; - // Each load ALWAYS produces 4 elements (ds_read_tr16_b64 → vector<4xf16>) - int64_t sliceElems = expectedLoads * 4; + // Each load produces vecLen elements: + // - f16/bf16: 4 elements (ds_read_tr16_b64) + // - fp8/bf8: 8 elements (ds_read_tr8_b64) + int64_t sliceElems = expectedLoads * vecLen; // Verify we generated the expected number of loads if (panelVectors.size() != (size_t)expectedLoads) { diff --git a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir index 29a2f190a06b..b6a64a6985f7 100644 --- a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir +++ b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir @@ -14,4 +14,18 @@ module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xbf16, #gpu.address_space> -> vector<4xbf16> return %v : vector<4xbf16> } + +// CHECK-LABEL: func 
@test_load_transpose_fp8_e4m3 + func.func @test_load_transpose_fp8_e4m3(%src: memref<128x256xf8E4M3FN, #gpu.address_space>, %i: index, %j: index) -> vector<8xf8E4M3FN> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<128x256xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + %v = rock.lds_transpose_load %src[%i, %j] : memref<128x256xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + return %v : vector<8xf8E4M3FN> + } + +// CHECK-LABEL: func @test_load_transpose_fp8_e5m2 + func.func @test_load_transpose_fp8_e5m2(%src: memref<64x128xf8E5M2, #gpu.address_space>, %i: index, %j: index) -> vector<8xf8E5M2> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<64x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + return %v : vector<8xf8E5M2> + } } diff --git a/mlir/test/Dialect/Rock/ops.mlir b/mlir/test/Dialect/Rock/ops.mlir index 3e19d0dcd4ea..b21b937a1686 100644 --- a/mlir/test/Dialect/Rock/ops.mlir +++ b/mlir/test/Dialect/Rock/ops.mlir @@ -433,6 +433,27 @@ func.func @rock_lds_transpose_load_full_arch(%lds_buffer: memref<128x64xf16, #gp return } +// CHECK-LABEL: func.func @rock_lds_transpose_load_fp8_e4m3 +// CHECK: rock.lds_transpose_load +func.func @rock_lds_transpose_load_fp8_e4m3(%lds_buffer: memref<128x64xf8E4M3FN, #gpu.address_space>) + attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { + %c0 = arith.constant 0 : index + %fragment = rock.lds_transpose_load %lds_buffer[%c0, %c0] + : memref<128x64xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + return +} + +// CHECK-LABEL: func.func @rock_lds_transpose_load_fp8_e5m2 +// CHECK: rock.lds_transpose_load +func.func @rock_lds_transpose_load_fp8_e5m2(%lds_buffer: memref<256x128xf8E5M2, #gpu.address_space>) + attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %fragment = rock.lds_transpose_load 
%lds_buffer[%c32, %c0] + : memref<256x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + return +} + // CHECK-LABEL: func.func @test_lds_transpose_config_attr_16x32 // CHECK: ldsTransposeConfig = #rock.lds_transpose_config func.func @test_lds_transpose_config_attr_16x32(%src: memref<8192xf16, #gpu.address_space>, From 84c9425c8acbdccf88321802cd91d09bddec9826 Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Tue, 20 Jan 2026 08:34:42 -0600 Subject: [PATCH 09/14] Add PR CI tests for FP8/BF8 LDS transpose load GEMM operations --- mlir/test/e2e/CMakeLists.txt | 1 + mlir/test/e2e/PrLdsTransposeLoadFp8.cfg | 2 + mlir/test/e2e/PrLdsTransposeLoadFp8.toml | 126 +++++++++++++++++++++++ 3 files changed, 129 insertions(+) create mode 100644 mlir/test/e2e/PrLdsTransposeLoadFp8.cfg create mode 100644 mlir/test/e2e/PrLdsTransposeLoadFp8.toml diff --git a/mlir/test/e2e/CMakeLists.txt b/mlir/test/e2e/CMakeLists.txt index 317918b9cc68..7f500161c660 100644 --- a/mlir/test/e2e/CMakeLists.txt +++ b/mlir/test/e2e/CMakeLists.txt @@ -50,6 +50,7 @@ if (ROCMLIR_DRIVER_PR_E2E_TEST_ENABLED) PrConvElementwiseGemmBF16SplitK PrGemmDirectToLDS PrLdsTransposeLoad + PrLdsTransposeLoadFp8 PrLdsTransposeLoadAttention PrConvDirectToLDS PrAttentionDirectToLDS diff --git a/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg b/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg new file mode 100644 index 000000000000..46909aa10a02 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg @@ -0,0 +1,2 @@ +if not 'lds_transpose_load' in config.features: + config.unsupported = True diff --git a/mlir/test/e2e/PrLdsTransposeLoadFp8.toml b/mlir/test/e2e/PrLdsTransposeLoadFp8.toml new file mode 100644 index 000000000000..41b36b149e6a --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadFp8.toml @@ -0,0 +1,126 @@ +directory = "PrLdsTransposeLoadFp8" +prefix = "rocmlir-gen" +suffix = "--operation gemm --arch %arch %pv %constrained_float_range_random_data %rocmlir_gen_flags | rocmlir-driver -c | mlir-runner -O2 
--shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +[[axis]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8", "fp8_bf8", "bf8_fp8"] +prefix = "-t " + +# ============================================================================ +# Suite 1: BOTH A and B use LDS transpose (transA=true, transB=false) +# 5 tests covering different MFMA geometries and kpack/kPerBlock combinations +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_both_operands_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v3:64,64,32,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 64 -n 64 --transA=true --transB=false --perf_config v3:128,64,16,32,16,16,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 256 -k 256 -n 64 --transA=true --transB=false --perf_config v3:256,64,8,64,16,32,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=false --perf_config v3:128,128,16,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 256 -k 128 -n 128 --transA=true --transB=false --perf_config v3:256,128,32,64,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 2: ONLY A uses LDS transpose (transA=true, transB=true) +# 3 tests - B uses 
regular load +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_A_only_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=true --transB=true --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 64 --transA=true --transB=true --perf_config v3:128,64,8,32,16,16,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 128 --transA=true --transB=true --perf_config v3:128,128,32,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 3: ONLY B uses LDS transpose (transA=false, transB=false) +# 3 tests - A uses regular load +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_B_only_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 64 --transA=false --transB=false --perf_config v3:128,64,32,32,16,1,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 128 --transA=false --transB=false --perf_config v3:128,128,8,32,32,32,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 4: Mixed FP8/BF8 types (fp8_bf8 and bf8_fp8 only) +# 3 tests: both operands, only A, only B +# ============================================================================ 
+[[suite]] +name = "lds_transpose_mixed_fp8_bf8" + +# Test 1: Mixed types, both operands use LDS transpose, 16x32 MFMA +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v3:64,64,32,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"] + +# Test 2: Mixed types, only A uses LDS transpose, 32x16 MFMA +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=true --perf_config v3:128,128,16,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"] + +# Test 3: Mixed types, only B uses LDS transpose, 16x32 MFMA +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"] From e3ecb268849c3a6dbbbdade9661df9a72f21fd47 Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Fri, 23 Jan 2026 08:35:20 -0600 Subject: [PATCH 10/14] Add INT8 LDS transpose load support for GEMM and Attention - Add INT8 (i8) support in LdsTransposeLoad.cpp for ds_read_tr8_b64 - Support mfma_i32_16x16x32_i8, mfma_i32_16x16x64_i8, mfma_i32_32x32x16_i8, mfma_i32_32x32x32_i8 - Add INT8 16x64 and 32x32 MFMA geometries with double-rate K coverage - Handle kpack=1 case for INT8 MFMAs with kBase=16 in AccelEmitter.cpp - Add validation for INT8 MFMA geometries in RockDialect.cpp - Add e2e tests for INT8 LDS transpose in GEMM and Attention --- mlir/include/mlir/Dialect/Rock/IR/RockOps.td | 6 +- mlir/lib/Dialect/Rock/IR/RockDialect.cpp | 10 +- .../lib/Dialect/Rock/utility/AccelEmitter.cpp | 58 ++++++ .../Dialect/Rock/utility/LdsTransposeLoad.cpp | 165 ++++++++++++------ .../Rock/lds_transpose_attributes.mlir | 68 ++++++++ .../Dialect/Rock/lds_transpose_error.mlir | 14 +- .../Rock/lowering_load_transpose_lds.mlir | 7 + mlir/test/e2e/CMakeLists.txt | 2 + .../e2e/PrLdsTransposeLoadAttentionI8.cfg | 2 + 
.../e2e/PrLdsTransposeLoadAttentionI8.toml | 47 +++++ mlir/test/e2e/PrLdsTransposeLoadI8.cfg | 2 + mlir/test/e2e/PrLdsTransposeLoadI8.toml | 86 +++++++++ 12 files changed, 404 insertions(+), 63 deletions(-) create mode 100644 mlir/test/e2e/PrLdsTransposeLoadAttentionI8.cfg create mode 100644 mlir/test/e2e/PrLdsTransposeLoadAttentionI8.toml create mode 100644 mlir/test/e2e/PrLdsTransposeLoadI8.cfg create mode 100644 mlir/test/e2e/PrLdsTransposeLoadI8.toml diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index 3147f65f8339..8272b543ed38 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -1197,7 +1197,7 @@ defvar SameShapeVectorOfI1 = [{ def Rock_LDSTransposeLoadOp : Rock_Op<"lds_transpose_load", [DeclareOpInterfaceMethods< MemoryEffectsOpInterface>]>, - Arguments<(ins Arg, + Arguments<(ins Arg, "LDS source buffer">:$source, Variadic:$indices)>, Results<(outs AnyVectorOfNonZeroRank:$result)> { @@ -1209,13 +1209,13 @@ def Rock_LDSTransposeLoadOp The tile dimensions match the selected matrix-multiply accelerator instruction geometry (`dDim × instrK`), where: - `dDim` is the accelerator M/N dimension (e.g., 16 or 32) - - `instrK` is the accelerator K dimension (e.g., 8, 16, or 32) + - `instrK` is the accelerator K dimension (e.g., 8, 16, 32, or 64) For 16-bit types (f16, bf16): - Uses ds_read_tr16_b64 instruction - Returns vector<4xtype> (4 elements per thread) - For 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): + For 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8, i8 - INT8 for gfx950): - Uses ds_read_tr8_b64 instruction - Returns vector<8xtype> (8 elements per thread) diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 48c0192e18b3..37314f953552 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -559,9 +559,13 @@ LogicalResult TransformMapAttr::verify( 
} // Helper function to check valid MFMA geometry for LDS transpose +// Supported geometries: +// - F16/BF16: (16,16), (16,32), (32,8), (32,16) +// - FP8/BF8: (16,32), (32,16) +// - INT8: (16,32), (16,64), (32,16), (32,32) static bool isValidLdsTransposeMfmaGeometry(int64_t dDim, int64_t kDim) { - return (dDim == 16 && (kDim == 16 || kDim == 32)) || - (dDim == 32 && (kDim == 8 || kDim == 16)); + return (dDim == 16 && (kDim == 16 || kDim == 32 || kDim == 64)) || + (dDim == 32 && (kDim == 8 || kDim == 16 || kDim == 32)); } LogicalResult LDSTransposeConfigAttr::verify( @@ -573,7 +577,7 @@ LogicalResult LDSTransposeConfigAttr::verify( if (!isValidLdsTransposeMfmaGeometry(dDim, kDim)) { return emitError() << "invalid MFMA geometry (" << dDim << "x" << kDim << ") for LDS transpose - valid combinations: " - "(16,16), (16,32), (32,8), (32,16)"; + "(16,16), (16,32), (16,64), (32,8), (32,16), (32,32)"; } // Validate positive dimensions diff --git a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp index 711d4a7b4edf..6ab4fce4425b 100644 --- a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp +++ b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp @@ -594,6 +594,64 @@ Value MfmaEmitter::wrapLDSBufferForLoad( toLDSRowCol.unmerge("k", 1, {"k_iter", "k_mfma", "blk_k", "k_base"}, {kIter, numMfmaPerKVec, numBlksInK, kBase}); + } else if (useLdsTransposeLoad && kBase == 16) { + // This handles INT8 32x32 (instrK=32, kBase=16) and INT8 16x64 + // (instrK=64, kBase=16) when kpack=1 + int64_t instrK = mfmaAttr.k; + int64_t numBlksInK = instrK / kBase; + int64_t numBlksInD = (waveSize / inputSpanLen) / numBlksInK; + + // Split blk_id into blk_d (for D dimension) and blk_k (for K dimension) + TopDownTMBuilder splitBlkId = + TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr); + splitBlkId.passThrough({"wave_m", "wave_n"}, {0, 1}, + {"wave_m", "wave_n"}); + splitBlkId.merge({"blk_d", "blk_k"}, {2, 3}, "blk_id", + {numBlksInD, numBlksInK}); + 
splitBlkId.passThrough({"blk_td", "d_iter", "k_iter", "k_vec"}, + {4, 5, 6, 7}, + {"blk_td", "d_iter", "k_iter", "k_vec"}); + TransformMapAttr splitBlkIdAttr = splitBlkId.get(); + transformAttrs.push_back(splitBlkIdAttr); + + // For kVec < kBase, we split k_iter into outer_k (MFMA iterations) + // and inner_k (within kBase for each blk_k) + // Total K per thread = kIter * kVec + // Each thread covers K range: blk_k * kBase to blk_k * kBase + kBase - 1 + // Number of MFMA iterations = (kIter * kVec) / kBase (iterations per + // blk_k) K iterations within each kBase = kBase / kVec Constraint: + // numMfmaIters * kIterPerBlk = kIter + int64_t kIterPerBlk = kBase / kVec; // iterations to cover kBase elements + int64_t numMfmaIters = kIter / kIterPerBlk; // number of outer iterations + + TopDownTMBuilder splitKIter = + TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr); + splitKIter.passThrough({"wave_m", "wave_n"}, {0, 1}, + {"wave_m", "wave_n"}); + splitKIter.passThrough({"blk_d", "blk_k", "blk_td", "d_iter"}, + {2, 3, 4, 5}, + {"blk_d", "blk_k", "blk_td", "d_iter"}); + splitKIter.merge({"outer_k", "inner_k"}, {6, 7}, "k_iter", + {numMfmaIters, kIterPerBlk}); + splitKIter.passThrough({"k_vec"}, {8}, {"k_vec"}); + TransformMapAttr splitKIterAttr = splitKIter.get(); + transformAttrs.push_back(splitKIterAttr); + + toLDSRowCol = TopDownTMBuilder::below(splitKIter, splitKIterAttr); + + // d = d_iter * dWaves * numBlksInD * inputSpanLen + wave_d * numBlksInD * + // inputSpanLen + blk_d * inputSpanLen + blk_td + toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_d", "blk_td"}, + {dRepeats, dWaves, numBlksInD, inputSpanLen}); + + // k = outer_k * instrK + blk_k * kBase + inner_k * kVec + k_vec + // This ensures K access pattern matches transpose load: + // - outer_k selects which MFMA iteration + // - blk_k selects which K block (0..kBase-1 or kBase..instrK-1) + // - inner_k * kVec + k_vec gives position within the K block + toLDSRowCol.unmerge("k", 1, {"outer_k", 
"blk_k", "inner_k", "k_vec"}, + {numMfmaIters, numBlksInK, kIterPerBlk, kVec}); + } else { // Standard formula for regular load scenarios toLDSRowCol = TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr); diff --git a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp index 01e426c55db7..18297c496dd9 100644 --- a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp +++ b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp @@ -21,13 +21,14 @@ // to the LDS transpose load operation in an accelerator-friendly layout. // // It is intended to simplify the IR generation logic and ensure -// consistent handling of f16/bf16/fp8/bf8 matrix accelerator tile loads from +// consistent handling of f16/bf16/fp8/bf8/i8 matrix accelerator tile loads from // LDS memory. // // Supported element types: // - f16, bf16: uses ds_read_tr16_b64 (returns 4 elements per thread) // - f8E4M3FN, f8E5M2 (OCP FP8): uses ds_read_tr8_b64 (returns 8 elements per -// thread) +// thread) +// - i8 (INT8): uses ds_read_tr8_b64 (returns 8 elements per thread) // //===----------------------------------------------------------------------===// @@ -55,9 +56,10 @@ bool archSupported(StringRef arch) { return arch.contains("gfx950"); } // Check if element type is supported for LDS transpose load // - f16, bf16: ds_read_tr16_b64 (4 elements) // - f8E4M3FN, f8E5M2 (OCP FP8 for gfx950): ds_read_tr8_b64 (8 elements) +// - i8 (INT8 for gfx950): ds_read_tr8_b64 (8 elements) static bool isSupportedElementType(Type t) { return t.isF16() || t.isBF16() || isa(t) || - isa(t); + isa(t) || t.isInteger(8); } // Check if element type is 8-bit float (FP8 E4M3 or BF8 E5M2) @@ -68,21 +70,34 @@ static bool isFp8Type(Type t) { return isa(t) || isa(t); } +// Check if element type is INT8 (i8) +// INT8 uses ds_read_tr8_b64 instruction like FP8 +static bool isInt8Type(Type t) { return t.isInteger(8); } + +// Check if element type uses 8-bit transpose load (FP8, BF8, or INT8) +// 
All these types use ds_read_tr8_b64 which returns 8 elements per thread +static bool uses8BitTransposeLoad(Type t) { + return isFp8Type(t) || isInt8Type(t); +} + // Returns the number of elements returned by LDS transpose load instruction static int64_t getTransposeLoadVectorLength(Type elemType) { if (elemType.isF16() || elemType.isBF16()) { return 4; // ds_read_tr16_b64 - } else if (isFp8Type(elemType)) { - return 8; // ds_read_tr8_b64 + } else if (uses8BitTransposeLoad(elemType)) { + return 8; // ds_read_tr8_b64 (FP8, BF8, INT8) } llvm_unreachable("Unsupported element type for LDS transpose load"); } // Validates MFMA geometry for LDS transpose support. -// Only specific combinations are supported: (16,16), (16,32), (32,8), (32,16) +// Supported combinations: +// - F16/BF16: (16,16), (16,32), (32,8), (32,16) +// - FP8/BF8: (16,32), (32,16) +// - INT8: (16,32), (16,64), (32,16), (32,32) static bool isValidMfmaGeometry(int64_t dDim, int64_t kDim) { - return (dDim == 16 && (kDim == 16 || kDim == 32)) || - (dDim == 32 && (kDim == 8 || kDim == 16)); + return (dDim == 16 && (kDim == 16 || kDim == 32 || kDim == 64)) || + (dDim == 32 && (kDim == 8 || kDim == 16 || kDim == 32)); } // Shape of a single MFMA instruction (internal use only). 
@@ -334,7 +349,7 @@ LDSTransposeConfigAttr buildTransposeAttrFromParams( "MFMA geometry must be set when building transpose attributes"); assert(isValidMfmaGeometry(mfmaDDim, mfmaKDim) && "Invalid MFMA geometry for LDS transpose - valid: (16,16), (16,32), " - "(32,8), (32,16)"); + "(16,64), (32,8), (32,16), (32,32)"); // Create structured attribute with all parameters return LDSTransposeConfigAttr::get(rewriter.getContext(), mfmaDDim, mfmaKDim, @@ -431,14 +446,29 @@ static Value getDoubleRateKOffsetBase(PatternRewriter &b, Location loc, Value kOffsetBase; if (dDim == 32 && kDim == 16) { - // 32x16 layout + // 32x16 layout (F16/BF16) kOffsetBase = arith::MulIOp::create( b, loc, arith::DivUIOp::create(b, loc, blockId, c2), c8); } else if (dDim == 16 && kDim == 32) { - // 16x32 layout + // 16x32 layout (F16/BF16) kOffsetBase = arith::MulIOp::create(b, loc, blockId, c8); + } else if (dDim == 16 && kDim == 64) { + // 16x64 layout (INT8): k_offset_base = block_id * 16 + // Each block covers 16 K values, with first/second call covering + // K=0..7/8..15 within each block's 16-K range + Value c16val = arith::ConstantIndexOp::create(b, loc, 16); + kOffsetBase = arith::MulIOp::create(b, loc, blockId, c16val); + } else if (dDim == 32 && kDim == 32) { + // 32x32 layout (INT8): k_offset_base = k_block * 16 + // where k_block = block_id / 2 + // Block 0,1 (k_block=0) cover K=0..15, Block 2,3 (k_block=1) cover K=16..31 + // First/second call covers K+0..7 / K+8..15 within each range + Value c16val = arith::ConstantIndexOp::create(b, loc, 16); + Value kBlock = arith::DivUIOp::create(b, loc, blockId, c2); + kOffsetBase = arith::MulIOp::create(b, loc, kBlock, c16val); } else { - llvm_unreachable("Invalid double-rate geometry - must be 32x16 or 16x32"); + llvm_unreachable( + "Invalid double-rate geometry - must be 32x16, 16x32, 16x64, or 32x32"); } return kOffsetBase; @@ -455,15 +485,17 @@ static Value getDoubleRateKOffsetBase(PatternRewriter &b, Location loc, // Single-rate: k_final 
= k_base_local + (kTileIdx * kTileStride) // Double-rate: k_final = k_base_local + kOffsetBase + (kTileIdx * // kTileStride) + halfOffset -// where halfOffset = 0 for low half, 4 for high half +// where halfOffset = 0 for low half, 4 for F16/BF16 high half, +// 8 for INT8 high half // // Parameters: -// isDoubleRate - Whether this is a double-rate layout (L32x16, L16x32) -// kBaseLocal - Local K base offset from computeLDSBaseOffsets() +// isDoubleRate - Whether this is a double-rate layout (L32x16, L16x32, +// L16x64) kBaseLocal - Local K base offset from computeLDSBaseOffsets() // kOffsetBase - Double-rate K offset base (from getDoubleRateKOffsetBase) // kTileIdx - Current K tile index (0, 1, 2, ...) -// kTileStride - K stride per tile (instrK, e.g., 8 or 16) -// isHighHalf - For double-rate: true = high half (+4), false = low half +// kTileStride - K stride per tile (instrK, e.g., 8, 16, or 64) +// isHighHalf - For double-rate: true = high half, false = low half +// halfOffset - Offset to add for high half (4 for F16/BF16, 8 for INT8) // // Returns: // Final K offset value to use for emitPanelLoad() @@ -471,13 +503,13 @@ static Value getDoubleRateKOffsetBase(PatternRewriter &b, Location loc, static Value computePanelFinalOffset(PatternRewriter &b, Location loc, bool isDoubleRate, Value kBaseLocal, Value kOffsetBase, int64_t kTileIdx, - Value kTileStride, - bool isHighHalf = false) { + Value kTileStride, bool isHighHalf = false, + int64_t halfOffset = 4) { Value kBase = kBaseLocal; if (isDoubleRate) { - // Double-rate: k_offset = kOffsetBase + kTileIdx * kTileStride [+ 4 for - // high] + // Double-rate: k_offset = kOffsetBase + kTileIdx * kTileStride [+ + // halfOffset] Value kTileOffset; if (kTileIdx > 0) { Value kIdxVal = arith::ConstantIndexOp::create(b, loc, kTileIdx); @@ -488,10 +520,10 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, Value k_offset = arith::AddIOp::create(b, loc, kOffsetBase, kTileOffset); - // For high half, add 4 
+ // For high half, add halfOffset (4 for F16/BF16, 8 for INT8) if (isHighHalf) { - Value c4 = arith::ConstantIndexOp::create(b, loc, 4); - k_offset = arith::AddIOp::create(b, loc, k_offset, c4); + Value cHalfOffset = arith::ConstantIndexOp::create(b, loc, halfOffset); + k_offset = arith::AddIOp::create(b, loc, k_offset, cHalfOffset); } // k_base = k_base_local + k_offset @@ -648,22 +680,24 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, Value kOffsetBase, mOffsetBase; - if (isFp8Type(elemType)) { + // 8-bit types (FP8, BF8, INT8) use ds_read_tr8_b64 + // They share the same offset formulas for 16x32 and 32x16 geometries + if (uses8BitTransposeLoad(elemType)) { Value c8 = arith::ConstantIndexOp::create(b, loc, 8); + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); + Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); + if (dDim == 16 && kDim == 32) { - // FP8/BF8 16x32: 4-block formula + // 16x32: 4-block formula (FP8/BF8/INT8) // Block layout: // Block 0 (T0-T15): K=0..7 // Block 1 (T16-T31): K=8..15 // Block 2 (T32-T47): K=16..23 // Block 3 (T48-T63): K=24..31 - Value blockId = arith::DivUIOp::create(b, loc, lane, c16); - Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); - Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); - Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); - // kOffsetBase = k_local + block_id * 8 Value blockKOffset = arith::MulIOp::create(b, loc, blockId, c8); kOffsetBase = arith::AddIOp::create(b, loc, kLocal, blockKOffset); @@ -672,18 +706,13 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, mOffsetBase = arith::MulIOp::create(b, loc, mParity, c8); } else if (dDim == 32 && kDim == 16) { - // FP8 32x16: 4-block formula (VERIFIED from HIP testing) + // 32x16: 4-block formula (FP8/BF8/INT8) // Block 
layout: // Block 0 (T0-T15): M=0..15, K=0..7 // Block 1 (T16-T31): M=16..31, K=0..7 // Block 2 (T32-T47): M=0..15, K=8..15 // Block 3 (T48-T63): M=16..31, K=8..15 - Value blockId = arith::DivUIOp::create(b, loc, lane, c16); - Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); - Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); - Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); - // m_block = block_id % 2, k_block = block_id / 2 Value mBlock = arith::RemUIOp::create(b, loc, blockId, c2); Value kBlock = arith::DivUIOp::create(b, loc, blockId, c2); @@ -697,8 +726,30 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, Value mBlockOffset = arith::MulIOp::create(b, loc, mBlock, c16); mOffsetBase = arith::AddIOp::create(b, loc, mParityOffset, mBlockOffset); + } else if (isInt8Type(elemType) && dDim == 16 && kDim == 64) { + // INT8 16x64: Double-rate layout (INT8 only) + // k_base_local = k_local (block offset comes from + // getDoubleRateKOffsetBase) m_base = m_parity * 8 + + kOffsetBase = kLocal; + mOffsetBase = arith::MulIOp::create(b, loc, mParity, c8); + + } else if (isInt8Type(elemType) && dDim == 32 && kDim == 32) { + // INT8 32x32: Double-rate layout (INT8 only) + // k_base_local = k_local (k_block * 16 comes from + // getDoubleRateKOffsetBase) m_base = m_parity * 8 + m_block * 16 + + Value mBlock = arith::RemUIOp::create(b, loc, blockId, c2); + + kOffsetBase = kLocal; + + Value mParityOffset = arith::MulIOp::create(b, loc, mParity, c8); + Value mBlockOffset = arith::MulIOp::create(b, loc, mBlock, c16); + mOffsetBase = arith::AddIOp::create(b, loc, mParityOffset, mBlockOffset); + } else { - llvm_unreachable("Unsupported FP8 MFMA geometry in getBasePanelOffsets"); + llvm_unreachable( + "Unsupported 8-bit type MFMA geometry in getBasePanelOffsets"); } } else { @@ -1256,12 +1307,20 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, int64_t ldsStride = (operand == OperandKind::A) ? 
mPerBlock : nPerBlock; // Determine if this is a double-rate instruction - // Double-rate ONLY for (32,16) and (16,32) MFMA with F16/BF16 - // FP8/BF8 uses ds_read_tr8_b64 which returns 8 elements, so (16,32) and - // (32,16) are SINGLE-RATE for FP8/BF8 (16,16) and (32,8) are always - // SINGLE-RATE - bool isDoubleRate = !isFp8Type(elemType) && ((dDim == 32 && instrK == 16) || - (dDim == 16 && instrK == 32)); + // Double-rate layouts require TWO ds_read calls per K tile: + // - F16/BF16 (32,16) and (16,32): ds_read_tr16_b64 returns 4 elements, + // need 2 calls for 8 elements total + // - INT8 (16,64) and (32,32): ds_read_tr8_b64 returns 8 elements, need 2 + // calls for 16 i8 values (packed into v4i32 for mfma_i32_16x16x64_i8) + // Single-rate layouts: + // - FP8/BF8/INT8 (16,32), (32,16): ds_read_tr8_b64 returns 8 elements + // directly + // - F16/BF16 (16,16), (32,8): single ds_read_tr16_b64 per K tile + bool isDoubleRate = + (!uses8BitTransposeLoad(elemType) && + ((dDim == 32 && instrK == 16) || (dDim == 16 && instrK == 32))) || + (isInt8Type(elemType) && + ((dDim == 16 && instrK == 64) || (dDim == 32 && instrK == 32))); // Determine vector length based on element type: // - f16/bf16: ds_read_tr16_b64 returns vector<4> @@ -1342,16 +1401,22 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, panelVectors.push_back(panelVec); } else { - // DOUBLE-RATE (L32x16, L16x32 with F16/BF16 only): TWO loads per K tile - // Each load returns vector<4> for f16/bf16, total 8 elements per K tile - // Note: FP8/BF8 is NEVER double-rate (ds_read_tr8_b64 returns 8 - // elements) Compute K offsets for low and high halves + // DOUBLE-RATE: TWO loads per K tile + // - F16/BF16 (L32x16, L16x32): vector<4> per load, halfOffset=4 + // - INT8 (L16x64): vector<8> per load, halfOffset=8 + // FP8/BF8 is NEVER double-rate (single ds_read_tr8_b64 returns 8 elems) + + // Determine halfOffset based on element type: + // - F16/BF16: 4 (each ds_read_tr16_b64 returns 4 elements) + // - 
INT8: 8 (each ds_read_tr8_b64 returns 8 elements) + int64_t halfOffsetVal = isInt8Type(elemType) ? 8 : 4; + Value k_base_low = computePanelFinalOffset( b, loc, isDoubleRate, k_base_local, kOffsetBase, kIdx, - kTileStrideVal, /*isHighHalf=*/false); + kTileStrideVal, /*isHighHalf=*/false, halfOffsetVal); Value k_base_high = computePanelFinalOffset( b, loc, isDoubleRate, k_base_local, kOffsetBase, kIdx, - kTileStrideVal, /*isHighHalf=*/true); + kTileStrideVal, /*isHighHalf=*/true, halfOffsetVal); // Emit low half load Value panelVecLow = emitPanelLoad(b, loc, rawSrc, k_base_low, m_base, diff --git a/mlir/test/Dialect/Rock/lds_transpose_attributes.mlir b/mlir/test/Dialect/Rock/lds_transpose_attributes.mlir index 5dfc22d8ac21..3d1a2d96ced4 100644 --- a/mlir/test/Dialect/Rock/lds_transpose_attributes.mlir +++ b/mlir/test/Dialect/Rock/lds_transpose_attributes.mlir @@ -65,3 +65,71 @@ module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { // CHECK-SAME: memref<256x8xf16, #gpu.address_space> -> memref<8xf16, #gpu.address_space> // CHECK: rock.threadwise_read_into {forceUnroll, ldsTransposeConfig = #rock.lds_transpose_config, useIndexDiffs} // CHECK-SAME: memref<256x32xf16, #gpu.address_space> -> memref<32xf16, #gpu.address_space> + +// ----- + +// Test INT8 32x32 MFMA (mfma_i32_32x32x32_i8) with LDS transpose attributes +#params_int8_32x32 = #rock.accel_gemm_params< + kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, + kpack = 1, mPerWave = 32, nPerWave = 32, + mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 3, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> + +module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { + // CHECK-LABEL: func.func @test_lds_transpose_int8_32x32 + func.func @test_lds_transpose_int8_32x32( + %arg0: memref<1024xi8>, + %arg1: memref<1024xi8>, + %arg2: memref<1024xi32>) + attributes {block_size = 64 : i32, grid_size = 1 : i32, + enable_splitk_for_tuning, kernel, + num_cu = 256 : i64} { + %a = rock.transform 
%arg0 by (d1 * 32 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 32, 32] -> [1024]> : memref<1024xi8> to memref<1x32x32xi8> + %b = rock.transform %arg1 by (d1 * 32 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 32, 32] -> [1024]> : memref<1024xi8> to memref<1x32x32xi8> + %c = rock.transform %arg2 by (d1 * 32 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 32, 32] -> [1024]> : memref<1024xi32> to memref<1x32x32xi32> + + rock.gridwise_gemm_accel(%a, %b, %c) + storeMethod(set) + features = mfma|dot|atomic_add|atomic_add_bf16|atomic_add_f16|direct_to_lds_32b|direct_to_lds_128b + {blockSize = 64 : i32, gridSize = 1 : i32, params = #params_int8_32x32} + : memref<1x32x32xi8>, memref<1x32x32xi8>, memref<1x32x32xi32> + return + } +} + +// CHECK: rock.threadwise_read_into {forceUnroll, ldsTransposeConfig = #rock.lds_transpose_config, useIndexDiffs} +// CHECK: rock.threadwise_read_into {forceUnroll, ldsTransposeConfig = #rock.lds_transpose_config, useIndexDiffs} + +// ----- + +// Test INT8 16x64 MFMA (mfma_i32_16x16x64_i8) with LDS transpose attributes +#params_int8_16x64 = #rock.accel_gemm_params< + kpackPerBlock = 64, mPerBlock = 16, nPerBlock = 16, + kpack = 1, mPerWave = 16, nPerWave = 16, + mnPerXdl = 16, splitKFactor = 1, scheduleVersion = 3, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> + +module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { + // CHECK-LABEL: func.func @test_lds_transpose_int8_16x64 + func.func @test_lds_transpose_int8_16x64( + %arg0: memref<1024xi8>, + %arg1: memref<1024xi8>, + %arg2: memref<256xi32>) + attributes {block_size = 64 : i32, grid_size = 1 : i32, + enable_splitk_for_tuning, kernel, + num_cu = 256 : i64} { + %a = rock.transform %arg0 by (d1 * 16 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 64, 16] -> [1024]> : memref<1024xi8> to memref<1x64x16xi8> + %b = rock.transform %arg1 by (d1 * 16 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 64, 16] -> [1024]> : memref<1024xi8> to 
memref<1x64x16xi8> + %c = rock.transform %arg2 by (d1 * 16 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 16, 16] -> [256]> : memref<256xi32> to memref<1x16x16xi32> + + rock.gridwise_gemm_accel(%a, %b, %c) + storeMethod(set) + features = mfma|dot|atomic_add|atomic_add_bf16|atomic_add_f16|direct_to_lds_32b|direct_to_lds_128b + {blockSize = 64 : i32, gridSize = 1 : i32, params = #params_int8_16x64} + : memref<1x64x16xi8>, memref<1x64x16xi8>, memref<1x16x16xi32> + return + } +} + +// CHECK: rock.threadwise_read_into {forceUnroll, ldsTransposeConfig = #rock.lds_transpose_config, useIndexDiffs} +// CHECK: rock.threadwise_read_into {forceUnroll, ldsTransposeConfig = #rock.lds_transpose_config, useIndexDiffs} diff --git a/mlir/test/Dialect/Rock/lds_transpose_error.mlir b/mlir/test/Dialect/Rock/lds_transpose_error.mlir index efc7cfc42fa4..61e0195fdecc 100644 --- a/mlir/test/Dialect/Rock/lds_transpose_error.mlir +++ b/mlir/test/Dialect/Rock/lds_transpose_error.mlir @@ -2,13 +2,13 @@ // Error case: Invalid MFMA geometry (16x8 is not valid) // This tests that LDSTransposeConfigAttr::verify() catches invalid MFMA -// geometry combinations. Valid combinations are: (16,16), (16,32), (32,8), (32,16) +// geometry combinations. 
Valid combinations are: (16,16), (16,32), (16,64), (32,8), (32,16), (32,32) func.func @threadwise_read_into_invalid_mfma_geometry_16x8( %source: memref<128xf16, #gpu.address_space>, %dest: memref<8xf16, #gpu.address_space>) attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { rock.threadwise_read_into { - // expected-error @+1 {{invalid MFMA geometry (16x8) for LDS transpose - valid combinations: (16,16), (16,32), (32,8), (32,16)}} + // expected-error @+1 {{invalid MFMA geometry (16x8) for LDS transpose - valid combinations: (16,16), (16,32), (16,64), (32,8), (32,16), (32,32)}} ldsTransposeConfig = #rock.lds_transpose_config< dDim = 16, kDim = 8, mPerBlock = 128, nPerBlock = 128, kPerBlock = 32, @@ -21,15 +21,15 @@ func.func @threadwise_read_into_invalid_mfma_geometry_16x8( // ----- -// Error case: Invalid MFMA geometry (32x32 is not valid) -func.func @threadwise_read_into_invalid_mfma_geometry_32x32( +// Error case: Invalid MFMA geometry (64x16 is not valid) +func.func @threadwise_read_into_invalid_mfma_geometry_64x16( %source: memref<128xf16, #gpu.address_space>, %dest: memref<8xf16, #gpu.address_space>) attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { rock.threadwise_read_into { - // expected-error @+1 {{invalid MFMA geometry (32x32) for LDS transpose - valid combinations: (16,16), (16,32), (32,8), (32,16)}} + // expected-error @+1 {{invalid MFMA geometry (64x16) for LDS transpose - valid combinations: (16,16), (16,32), (16,64), (32,8), (32,16), (32,32)}} ldsTransposeConfig = #rock.lds_transpose_config< - dDim = 32, kDim = 32, + dDim = 64, kDim = 16, mPerBlock = 128, nPerBlock = 128, kPerBlock = 32, mPerWave = 64, nPerWave = 64, doubleBuffering = false, isOperandA = true @@ -46,7 +46,7 @@ func.func @threadwise_read_into_invalid_mfma_geometry_8x8( %dest: memref<8xf16, #gpu.address_space>) attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { rock.threadwise_read_into { - // expected-error @+1 {{invalid MFMA geometry (8x8) for LDS transpose - valid combinations: (16,16), 
(16,32), (32,8), (32,16)}} + // expected-error @+1 {{invalid MFMA geometry (8x8) for LDS transpose - valid combinations: (16,16), (16,32), (16,64), (32,8), (32,16), (32,32)}} ldsTransposeConfig = #rock.lds_transpose_config< dDim = 8, kDim = 8, mPerBlock = 128, nPerBlock = 128, kPerBlock = 32, diff --git a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir index b6a64a6985f7..20521a23a52b 100644 --- a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir +++ b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir @@ -28,4 +28,11 @@ module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> return %v : vector<8xf8E5M2> } + +// CHECK-LABEL: func @test_load_transpose_i8 + func.func @test_load_transpose_i8(%src: memref<128x256xi8, #gpu.address_space>, %i: index, %j: index) -> vector<8xi8> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<128x256xi8, #gpu.address_space> -> vector<8xi8> + %v = rock.lds_transpose_load %src[%i, %j] : memref<128x256xi8, #gpu.address_space> -> vector<8xi8> + return %v : vector<8xi8> + } } diff --git a/mlir/test/e2e/CMakeLists.txt b/mlir/test/e2e/CMakeLists.txt index 7f500161c660..59e644b38f03 100644 --- a/mlir/test/e2e/CMakeLists.txt +++ b/mlir/test/e2e/CMakeLists.txt @@ -51,7 +51,9 @@ if (ROCMLIR_DRIVER_PR_E2E_TEST_ENABLED) PrGemmDirectToLDS PrLdsTransposeLoad PrLdsTransposeLoadFp8 + PrLdsTransposeLoadI8 PrLdsTransposeLoadAttention + PrLdsTransposeLoadAttentionI8 PrConvDirectToLDS PrAttentionDirectToLDS ) diff --git a/mlir/test/e2e/PrLdsTransposeLoadAttentionI8.cfg b/mlir/test/e2e/PrLdsTransposeLoadAttentionI8.cfg new file mode 100644 index 000000000000..46909aa10a02 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadAttentionI8.cfg @@ -0,0 +1,2 @@ +if not 'lds_transpose_load' in config.features: + config.unsupported = True diff --git 
a/mlir/test/e2e/PrLdsTransposeLoadAttentionI8.toml b/mlir/test/e2e/PrLdsTransposeLoadAttentionI8.toml new file mode 100644 index 000000000000..31f83121bb04 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadAttentionI8.toml @@ -0,0 +1,47 @@ +directory = "PrLdsTransposeLoadAttentionI8" +prefix = "rocmlir-gen" +suffix = "--operation attention -t i8 --arch %arch -pv %random_data %rocmlir_gen_flags -RMS_threshold 0.01 | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +# ============================================================================ +# Suite 1: Both K and Q use LDS transpose (transQ=true, transK=false) +# 1 test per MFMA = 4 tests +# ============================================================================ +[[suite]] +name = "lds_transpose_both_i8" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transQ=true --transK=false -perf_config attn:v3:32,32,32,32,16,16,16,8,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 128 -head_dim_v 64 --transQ=true --transK=false -perf_config attn:v3:32,32,32,64,16,16,16,8,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transQ=true --transK=false -perf_config attn:v3:32,32,32,16,32,32,32,8,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transQ=true --transK=false -perf_config attn:v3:32,32,32,32,32,32,32,8,1,3,2,0,1" + +# ============================================================================ +# Suite 2: Only K uses LDS transpose (transQ=false, transK=false) +# 2 tests per MFMA = 8 tests +# 
============================================================================ +[[suite]] +name = "lds_transpose_k_only_i8" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transQ=false --transK=false -perf_config attn:v3:32,32,32,32,16,16,16,8,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 64 -head_dim_v 64 --transQ=false --transK=false -perf_config attn:v3:64,64,64,32,16,16,16,16,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 128 -head_dim_v 64 --transQ=false --transK=false -perf_config attn:v3:32,32,32,64,16,16,16,8,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transQ=false --transK=false -perf_config attn:v3:32,32,32,16,32,32,32,8,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 64 -head_dim_v 64 --transQ=false --transK=false -perf_config attn:v3:64,64,64,16,32,32,32,16,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transQ=false --transK=false -perf_config attn:v3:32,32,32,32,32,32,32,8,1,3,2,0,1" diff --git a/mlir/test/e2e/PrLdsTransposeLoadI8.cfg b/mlir/test/e2e/PrLdsTransposeLoadI8.cfg new file mode 100644 index 000000000000..46909aa10a02 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadI8.cfg @@ -0,0 +1,2 @@ +if not 'lds_transpose_load' in config.features: + config.unsupported = True diff --git a/mlir/test/e2e/PrLdsTransposeLoadI8.toml b/mlir/test/e2e/PrLdsTransposeLoadI8.toml new file mode 100644 index 000000000000..340ececfe4c1 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadI8.toml @@ -0,0 +1,86 @@ +directory = "PrLdsTransposeLoadI8" +prefix = "rocmlir-gen" +suffix = "--operation gemm --arch %arch %random_data %pv %rocmlir_gen_flags | rocmlir-driver -c | mlir-runner -O2 
--shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +[[axis]] +name = "data type" +values = ["i8"] +prefix = "-t " + +# ============================================================================ +# Suite 1: BOTH A and B use LDS transpose (transA=true, transB=false) +# 2 tests per MFMA = 8 tests total +# ============================================================================ +[[suite]] +name = "lds_transpose_both_operands_i8" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v4:64,64,32,32,16,16,1,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 64 --transA=true --transB=false --perf_config v3:128,64,64,16,16,8,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v4:64,64,64,16,16,16,1,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 64 --transA=true --transB=false --perf_config v3:128,32,128,16,16,1,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=true --transB=false --perf_config v4:32,128,16,32,32,32,8,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=false --perf_config v3:256,64,32,32,32,8,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v4:64,64,32,32,32,32,1,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 128 --transA=true --transB=false --perf_config v3:128,128,16,32,32,16,1,3,2,1,1" + +# ============================================================================ +# Suite 2: ONLY A uses LDS transpose (transA=true, transB=true) +# 2 tests per MFMA = 8 tests total +# 
============================================================================ +[[suite]] +name = "lds_transpose_A_only_i8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=true --transB=true --perf_config v4:64,64,32,32,16,16,1,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 64 --transA=true --transB=true --perf_config v3:128,64,64,16,16,8,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=true --perf_config v4:64,64,64,16,16,16,1,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=true --transB=true --perf_config v4:64,64,16,32,32,32,16,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=true --perf_config v3:128,128,32,32,32,8,1,3,2,1,1" + +# ============================================================================ +# Suite 3: ONLY B uses LDS transpose (transA=false, transB=false) +# 2 tests per MFMA = 8 tests total +# ============================================================================ +[[suite]] +name = "lds_transpose_B_only_i8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v4:64,64,32,16,16,16,1,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 64 --transA=false --transB=false --perf_config v3:128,64,64,16,16,8,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=false --transB=false --perf_config v4:64,64,64,16,16,16,1,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=false --transB=false --perf_config v3:128,128,32,32,32,16,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v4:64,64,32,32,32,32,1,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=false --transB=false --perf_config v3:128,128,64,32,32,8,1,3,2,1,1" From 9170e8174e8772e6e721a2dc2ca5a6f753034fdb Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Tue, 20 Jan 
2026 05:49:40 -0600 Subject: [PATCH 11/14] Add FP8/BF8 support for LDS transpose load Implement ds_read_tr8_b64 offset formulas for FP8/BF8 MFMA (16x32, 32x16). Enable mixed fp8/bf8 type combinations for GEMM operations on gfx950. --- mlir/include/mlir/Dialect/Rock/IR/RockOps.td | 15 +- mlir/lib/Dialect/Rock/IR/RockDialect.cpp | 21 ++ .../Dialect/Rock/utility/LdsTransposeLoad.cpp | 280 ++++++++++++------ .../Rock/lowering_load_transpose_lds.mlir | 14 + mlir/test/Dialect/Rock/ops.mlir | 21 ++ 5 files changed, 261 insertions(+), 90 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index 8d70078b06f0..a9fd3f28951f 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -1221,9 +1221,10 @@ defvar SameShapeVectorOfI1 = [{ def Rock_LDSTransposeLoadOp : Rock_Op<"lds_transpose_load", [DeclareOpInterfaceMethods< MemoryEffectsOpInterface>]>, - Arguments<(ins Arg, "LDS source buffer">:$source, + Arguments<(ins Arg, + "LDS source buffer">:$source, Variadic:$indices)>, - Results<(outs VectorOfLengthAndType<[4], [F16, BF16]>:$result)> { + Results<(outs AnyVectorOfNonZeroRank:$result)> { let summary = "Hardware-assisted LDS transpose load for matrix accelerator tile"; let description = [{ @@ -1233,8 +1234,14 @@ def Rock_LDSTransposeLoadOp instruction geometry (`dDim × instrK`), where: - `dDim` is the accelerator M/N dimension (e.g., 16 or 32) - `instrK` is the accelerator K dimension (e.g., 8, 16, or 32) - The operation returns a vector of 4 elements per thread containing - transposed elements in a layout suitable for matrix accelerator instructions. 
+ + For 16-bit types (f16, bf16): + - Uses ds_read_tr16_b64 instruction + - Returns vector<4xtype> (4 elements per thread) + + For 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): + - Uses ds_read_tr8_b64 instruction + - Returns vector<8xtype> (8 elements per thread) Benefits: - Reduces LDS bank conflicts through optimized access patterns diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 5dadcd437f1f..b740d537106e 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -2166,6 +2166,27 @@ LogicalResult LDSTransposeLoadOp::verify() { << srcElemType << ")"; } + // Verify result vector length based on element type: + // - 16-bit types (f16, bf16): ds_read_tr16_b64 returns 4 elements + // - 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): ds_read_tr8_b64 + // returns 8 elements + int64_t expectedVecLen; + if (srcElemType.isF16() || srcElemType.isBF16()) { + expectedVecLen = 4; + } else if (isa(srcElemType) || + isa(srcElemType)) { + expectedVecLen = 8; + } else { + return emitOpError("unsupported element type for LDS transpose load: ") + << srcElemType; + } + + if (resultType.getNumElements() != expectedVecLen) { + return emitOpError("expected result vector of ") + << expectedVecLen << " elements for " << srcElemType + << " type, but got " << resultType.getNumElements(); + } + // Check hardware support using AmdArchDb StringRef arch = rock::getArchValue(*this); AmdArchInfo archInfo = rock::lookupArchInfo(arch); diff --git a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp index d4f021dc2715..01e426c55db7 100644 --- a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp +++ b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp @@ -21,8 +21,13 @@ // to the LDS transpose load operation in an accelerator-friendly layout. 
// // It is intended to simplify the IR generation logic and ensure -// consistent handling of f16/bf16 matrix accelerator tile loads from LDS -// memory. +// consistent handling of f16/bf16/fp8/bf8 matrix accelerator tile loads from +// LDS memory. +// +// Supported element types: +// - f16, bf16: uses ds_read_tr16_b64 (returns 4 elements per thread) +// - f8E4M3FN, f8E5M2 (OCP FP8): uses ds_read_tr8_b64 (returns 8 elements per +// thread) // //===----------------------------------------------------------------------===// @@ -47,6 +52,32 @@ namespace { bool archSupported(StringRef arch) { return arch.contains("gfx950"); } +// Check if element type is supported for LDS transpose load +// - f16, bf16: ds_read_tr16_b64 (4 elements) +// - f8E4M3FN, f8E5M2 (OCP FP8 for gfx950): ds_read_tr8_b64 (8 elements) +static bool isSupportedElementType(Type t) { + return t.isF16() || t.isBF16() || isa(t) || + isa(t); +} + +// Check if element type is 8-bit float (FP8 E4M3 or BF8 E5M2) +// Used for: +// 1. Selecting ds_read_tr8_b64 vs ds_read_tr16_b64 +// 2. Checking mixed-type compatibility (fp8+bf8 combinations are valid) +static bool isFp8Type(Type t) { + return isa(t) || isa(t); +} + +// Returns the number of elements returned by LDS transpose load instruction +static int64_t getTransposeLoadVectorLength(Type elemType) { + if (elemType.isF16() || elemType.isBF16()) { + return 4; // ds_read_tr16_b64 + } else if (isFp8Type(elemType)) { + return 8; // ds_read_tr8_b64 + } + llvm_unreachable("Unsupported element type for LDS transpose load"); +} + // Validates MFMA geometry for LDS transpose support. 
// Only specific combinations are supported: (16,16), (16,32), (32,8), (32,16) static bool isValidMfmaGeometry(int64_t dDim, int64_t kDim) { @@ -119,7 +150,16 @@ static Decision makeDecision(StringRef arch, Type elemTypeA, Type elemTypeB, return dec; } - if (elemTypeA != elemTypeB || !(elemTypeA.isF16() || elemTypeA.isBF16())) + // Check type compatibility: + // - Same types are always allowed (f16==f16, fp8==fp8, etc.) + // - Mixed FP8/BF8 combinations are allowed (hardware supports mixed fp8/bf8 + // MFMA: + // mfma_f32_16x16x32_fp8_fp8, mfma_f32_16x16x32_fp8_bf8, etc.) + // - Other mixed types are NOT allowed (e.g., f16 with fp8) + bool typesCompatible = (elemTypeA == elemTypeB) || + (isFp8Type(elemTypeA) && isFp8Type(elemTypeB)); + if (!typesCompatible || !isSupportedElementType(elemTypeA) || + !isSupportedElementType(elemTypeB)) return dec; // Validate MFMA geometry @@ -477,8 +517,9 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, // emitPanelLoad - Emit an LDS transpose load operation // // Computes the final LDS offset and emits a hardware LDS transpose load -// instruction (ds_read_tr16_b64). This instruction always returns vector<4xf16> -// regardless of the layout. +// instruction: +// - ds_read_tr16_b64 for f16/bf16: returns vector<4> +// - ds_read_tr8_b64 for fp8/bf8: returns vector<8> // // The final offset is computed as: final_offset = k_base * ldsStride + m_base // where ldsStride depends on the operand (mPerBlock for A, nPerBlock for B). 
@@ -491,10 +532,10 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, // kBase - K dimension base offset for this panel // mBase - M/N dimension base offset for this panel // ldsStride - Stride between K rows in LDS (mPerBlock or nPerBlock) -// panelVecType - Result type (always vector<4xf16> or vector<4xbf16>) +// panelVecType - Result type (vector<4> for f16/bf16, vector<8> for fp8/bf8) // // Returns: -// The loaded panel vector (vector<4xf16/bf16>) +// The loaded panel vector (vector<4> for f16/bf16, vector<8> for fp8/bf8) //===----------------------------------------------------------------------===// static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, Value kBase, Value mBase, Value ldsStride, @@ -503,7 +544,7 @@ static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, Value kOffset = arith::MulIOp::create(b, loc, kBase, ldsStride); Value finalOffset = arith::AddIOp::create(b, loc, mBase, kOffset); - // Emit hardware LDS transpose load: ds_read_tr16_b64 + // Emit hardware LDS transpose load auto loadOp = rock::LDSTransposeLoadOp::create(b, loc, panelVecType, rawSrc, ValueRange{finalOffset}); @@ -518,9 +559,9 @@ static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, // elements (ds_read_tr16_b64 always returns vector<4xf16>). 
// // Parameters: -// panelVectors - Array of loaded panel vectors (each is vector<4xf16>) -// dest - Destination memref (rank-1, scalar layout) -// targetElems - Maximum number of elements to write +// panelVectors - Array of loaded panel vectors (vector<4> for f16/bf16, +// vector<8> for fp8/bf8) dest - Destination memref (rank-1, scalar +// layout) targetElems - Maximum number of elements to write // // Returns: // success() if all target elements were written @@ -534,14 +575,16 @@ writePanelVectorsToDestination(PatternRewriter &b, Location loc, int64_t produced = 0; // Extract elements per vector from the actual vector type - // Hardware instruction ds_read_tr16_b64 always returns vector<4xf16> + // Hardware instructions: + // - ds_read_tr16_b64 returns vector<4xf16/bf16> + // - ds_read_tr8_b64 returns vector<8xfp8> assert(!panelVectors.empty() && "Panel vectors array must not be empty"); auto panelVecType = cast(panelVectors[0].getType()); int64_t elementsPerVector = panelVecType.getShape()[0]; - // Verify hardware constraint: ds_read_tr16_b64 returns exactly 4 elements - assert(elementsPerVector == 4 && - "LDS transpose load must produce vector<4xf16> per panel"); + // Verify hardware constraint: 4 elements for 16-bit, 8 elements for 8-bit + assert((elementsPerVector == 4 || elementsPerVector == 8) && + "LDS transpose load must produce vector<4> or vector<8> per panel"); LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Writing " << panelVectors.size() << " panel vectors (" @@ -597,60 +640,114 @@ writePanelVectorsToDestination(PatternRewriter &b, Location loc, //===----------------------------------------------------------------------===// static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, int64_t dDim, int64_t kDim, - Value lane) { - Value c16 = arith::ConstantIndexOp::create(b, loc, 16); - Value c4 = arith::ConstantIndexOp::create(b, loc, 4); + Value lane, Type elemType) { + // Common constants used by both FP8 and F16/BF16 Value c2 = 
arith::ConstantIndexOp::create(b, loc, 2); + Value c4 = arith::ConstantIndexOp::create(b, loc, 4); + Value c16 = arith::ConstantIndexOp::create(b, loc, 16); - Value blockId = arith::DivUIOp::create(b, loc, lane, c16); - Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); - - // Base offset calculations - Value mOffsetBase = arith::MulIOp::create( - b, loc, arith::RemUIOp::create(b, loc, laneInBlock, c4), c4); - Value kOffsetBase = arith::DivUIOp::create(b, loc, laneInBlock, c4); + Value kOffsetBase, mOffsetBase; + + if (isFp8Type(elemType)) { + Value c8 = arith::ConstantIndexOp::create(b, loc, 8); + + if (dDim == 16 && kDim == 32) { + // FP8/BF8 16x32: 4-block formula + // Block layout: + // Block 0 (T0-T15): K=0..7 + // Block 1 (T16-T31): K=8..15 + // Block 2 (T32-T47): K=16..23 + // Block 3 (T48-T63): K=24..31 + + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); + Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); + + // kOffsetBase = k_local + block_id * 8 + Value blockKOffset = arith::MulIOp::create(b, loc, blockId, c8); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, blockKOffset); + + // mOffsetBase = m_parity * 8 + mOffsetBase = arith::MulIOp::create(b, loc, mParity, c8); + + } else if (dDim == 32 && kDim == 16) { + // FP8 32x16: 4-block formula (VERIFIED from HIP testing) + // Block layout: + // Block 0 (T0-T15): M=0..15, K=0..7 + // Block 1 (T16-T31): M=16..31, K=0..7 + // Block 2 (T32-T47): M=0..15, K=8..15 + // Block 3 (T48-T63): M=16..31, K=8..15 + + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); + Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); + + // m_block = block_id % 2, k_block = block_id / 2 + Value mBlock = 
arith::RemUIOp::create(b, loc, blockId, c2); + Value kBlock = arith::DivUIOp::create(b, loc, blockId, c2); + + // kOffsetBase = k_local + k_block * 8 + Value kBlockOffset = arith::MulIOp::create(b, loc, kBlock, c8); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, kBlockOffset); + + // mOffsetBase = m_parity * 8 + m_block * 16 + Value mParityOffset = arith::MulIOp::create(b, loc, mParity, c8); + Value mBlockOffset = arith::MulIOp::create(b, loc, mBlock, c16); + mOffsetBase = arith::AddIOp::create(b, loc, mParityOffset, mBlockOffset); - SmallVector panelOffsets; + } else { + llvm_unreachable("Unsupported FP8 MFMA geometry in getBasePanelOffsets"); + } - if (dDim == 16 && kDim == 32) { - // 16x32 layout - panelOffsets = {kOffsetBase, mOffsetBase}; - } else if (dDim == 16 && kDim == 16) { - // 16x16 layout - // kbase = kOffsetBase + (blockId * 4) - Value kBase = arith::AddIOp::create( - b, loc, arith::MulIOp::create(b, loc, blockId, c4), kOffsetBase); - panelOffsets = {kBase, mOffsetBase}; - } else if (dDim == 32 && kDim == 16) { - // 32x16 layout - // mbase = mOffsetBase + (blockId % 2) * 16 - Value mBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::RemUIOp::create(b, loc, blockId, c2), c16), - mOffsetBase); - panelOffsets = {kOffsetBase, mBase}; - } else if (dDim == 32 && kDim == 8) { - // 32x8 layout - // k_base_local = kOffsetBase + (blockId / 2) * 4 - Value kBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::DivUIOp::create(b, loc, blockId, c2), c4), - kOffsetBase); - - // m_offset_base = mOffsetBase + (blockId % 2) * 16 - Value mBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::RemUIOp::create(b, loc, blockId, c2), c16), - mOffsetBase); - panelOffsets = {kBase, mBase}; } else { - llvm_unreachable("Unsupported MFMA geometry in getBasePanelOffsets"); + // F16/BF16 uses block-based lane mapping + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value 
laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + + // Base calculations common to all F16/BF16 geometries + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c4); + Value mLocal = arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, laneInBlock, c4), c4); + + if (dDim == 16 && kDim == 32) { + // 16x32: direct mapping + kOffsetBase = kLocal; + mOffsetBase = mLocal; + + } else if (dDim == 16 && kDim == 16) { + // 16x16: k += blockId * 4 + kOffsetBase = arith::AddIOp::create( + b, loc, kLocal, arith::MulIOp::create(b, loc, blockId, c4)); + mOffsetBase = mLocal; + + } else if (dDim == 32 && kDim == 16) { + // 32x16: m += (blockId % 2) * 16 + kOffsetBase = kLocal; + mOffsetBase = arith::AddIOp::create( + b, loc, mLocal, + arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, blockId, c2), c16)); + + } else if (dDim == 32 && kDim == 8) { + // 32x8: k += (blockId / 2) * 4, m += (blockId % 2) * 16 + kOffsetBase = arith::AddIOp::create( + b, loc, kLocal, + arith::MulIOp::create( + b, loc, arith::DivUIOp::create(b, loc, blockId, c2), c4)); + mOffsetBase = arith::AddIOp::create( + b, loc, mLocal, + arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, blockId, c2), c16)); + + } else { + llvm_unreachable( + "Unsupported F16/BF16 MFMA geometry in getBasePanelOffsets"); + } } - return panelOffsets; + return {kOffsetBase, mOffsetBase}; } //===----------------------------------------------------------------------===// @@ -671,6 +768,7 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, // dDim - MFMA D dimension (M or N, 16 or 32) // kDim - MFMA K dimension (8, 16, or 32) // lane - Thread's lane ID within the workgroup +// elemType - Element type (f16, bf16, fp8, or bf8) for selecting lane mapping // // Returns: // std::pair: @@ -679,8 +777,10 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, //===----------------------------------------------------------------------===// static 
std::pair computeLDSBaseOffsets(PatternRewriter &b, Location loc, int64_t dDim, - int64_t kDim, Value lane) { - SmallVector offsets = getBasePanelOffsets(b, loc, dDim, kDim, lane); + int64_t kDim, Value lane, + Type elemType) { + SmallVector offsets = + getBasePanelOffsets(b, loc, dDim, kDim, lane, elemType); LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Computed LDS base offsets for " << dDim << "x" << kDim << ": " @@ -1086,8 +1186,9 @@ static Value computeFinalMNOffset(PatternRewriter &b, Location loc, // emitThreadwiseHWTranspose - Lower threadwise_read_into to HW transpose loads //===----------------------------------------------------------------------===// // Lowers threadwise_read_into with LDS transpose config into hardware transpose -// load instructions (ds_read_tr16_b64) that read from LDS in MFMA-friendly -// order. +// load instructions that read from LDS in MFMA-friendly order: +// - ds_read_tr16_b64 for f16/bf16 (returns vector<4>) +// - ds_read_tr8_b64 for fp8/bf8 (returns vector<8>) // // Algorithm: // 1. Extract config: MFMA geometry (dDim, kDim), tiling params, operand kind @@ -1100,8 +1201,8 @@ static Value computeFinalMNOffset(PatternRewriter &b, Location loc, // - Outer: M/N tiles (all at once for double-buffering, one at a time // otherwise) // - Inner: K tiles (1 load for single-rate, 2 loads for double-rate layouts) -// 6. For each iteration: compute final LDS offset, emit ds_read_tr16_b64 -// instruction (returns vector<4xf16>) +// 6. For each iteration: compute final LDS offset, emit LDS transpose load +// instruction (ds_read_tr16_b64 for f16/bf16, ds_read_tr8_b64 for fp8/bf8) // 7. Extract elements from panel vectors and write sequentially to destination // // Example: 16x32 layout, 1 M-tile, 2 K-tiles → 2 ds_read_tr16_b64 calls → 8 @@ -1155,23 +1256,27 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, int64_t ldsStride = (operand == OperandKind::A) ? 
mPerBlock : nPerBlock; // Determine if this is a double-rate instruction - // Double-rate ONLY for (32,16) and (16,32) MFMA - // (16,16) and (32,8) are SINGLE-RATE - bool isDoubleRate = - (dDim == 32 && instrK == 16) || (dDim == 16 && instrK == 32); - - // Each ds_read_tr16_b64 call ALWAYS returns vector<4xf16> - // For double-rate, we make 2 calls and store all 8 elements separately - VectorType panelVecType = VectorType::get({4}, elemType); + // Double-rate ONLY for (32,16) and (16,32) MFMA with F16/BF16 + // FP8/BF8 uses ds_read_tr8_b64 which returns 8 elements, so (16,32) and + // (32,16) are SINGLE-RATE for FP8/BF8. (16,16) and (32,8) are always + // SINGLE-RATE bool isDoubleRate = !isFp8Type(elemType) && ((dDim == 32 && instrK == 16) || + (dDim == 16 && instrK == 32)); + + // Determine vector length based on element type: + // - f16/bf16: ds_read_tr16_b64 returns vector<4> + // - fp8/bf8: ds_read_tr8_b64 returns vector<8> + int64_t vecLen = getTransposeLoadVectorLength(elemType); + VectorType panelVecType = VectorType::get({vecLen}, elemType); // panelVectors will contain: - // - Single-rate: 1 vector<4xf16> per K tile - // - Double-rate: 2 vector<4xf16> per K tile (low + high) + // - Single-rate: 1 vector per K tile + // - Double-rate (f16/bf16 only): 2 vectors per K tile (low + high) SmallVector panelVectors; // Get base offsets using computeLDSBaseOffsets helper auto [k_base_local, m_offset_base] = - computeLDSBaseOffsets(b, loc, dDim, instrK, lane); + computeLDSBaseOffsets(b, loc, dDim, instrK, lane, elemType); // K stride per tile: instrK (MFMA K dimension) int64_t kTileStride = instrK; @@ -1226,20 +1331,21 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, useDynamicMnIndex, waveOffsetStrideVal, tileOffsetStrideVal); if (!isDoubleRate) { - // SINGLE-RATE (L32x8, L16x16): One load per K tile + // SINGLE-RATE (L32x8, L16x16, or FP8/BF8): One load per K tile Value k_base = computePanelFinalOffset(b, loc, isDoubleRate, k_base_local,
kOffsetBase, kIdx, kTileStrideVal); - // Emit LDS transpose load for this K tile (single-rate: one per K tile) + // Emit LDS transpose load for this K tile Value panelVec = emitPanelLoad(b, loc, rawSrc, k_base, m_base, ldsStrideVal, panelVecType); panelVectors.push_back(panelVec); } else { - // DOUBLE-RATE (L32x16, L16x32): TWO loads per K tile - // Each load returns vector<4xf16>, total 8 elements per K tile - // Compute K offsets for low and high halves + // DOUBLE-RATE (L32x16, L16x32 with F16/BF16 only): TWO loads per K tile + // Each load returns vector<4> for f16/bf16, total 8 elements per K tile + // Note: FP8/BF8 is NEVER double-rate (ds_read_tr8_b64 returns 8 + // elements). Compute K offsets for low and high halves Value k_base_low = computePanelFinalOffset( b, loc, isDoubleRate, k_base_local, kOffsetBase, kIdx, kTileStrideVal, /*isHighHalf=*/false); @@ -1269,8 +1375,10 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, int64_t loadsPerKTile = isDoubleRate ? 2 : 1; int64_t expectedLoads = actualMnTiles * kPanels * loadsPerKTile; - // Each load ALWAYS produces 4 elements (ds_read_tr16_b64 → vector<4xf16>) - int64_t sliceElems = expectedLoads * 4; + // Each load produces vecLen elements: + // - f16/bf16: 4 elements (ds_read_tr16_b64) + // - fp8/bf8: 8 elements (ds_read_tr8_b64) + int64_t sliceElems = expectedLoads * vecLen; // Verify we generated the expected number of loads if (panelVectors.size() != (size_t)expectedLoads) { diff --git a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir index 29a2f190a06b..b6a64a6985f7 100644 --- a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir +++ b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir @@ -14,4 +14,18 @@ module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xbf16, #gpu.address_space> -> vector<4xbf16> return %v : vector<4xbf16> } + +// CHECK-LABEL: func
@test_load_transpose_fp8_e4m3 + func.func @test_load_transpose_fp8_e4m3(%src: memref<128x256xf8E4M3FN, #gpu.address_space>, %i: index, %j: index) -> vector<8xf8E4M3FN> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<128x256xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + %v = rock.lds_transpose_load %src[%i, %j] : memref<128x256xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + return %v : vector<8xf8E4M3FN> + } + +// CHECK-LABEL: func @test_load_transpose_fp8_e5m2 + func.func @test_load_transpose_fp8_e5m2(%src: memref<64x128xf8E5M2, #gpu.address_space>, %i: index, %j: index) -> vector<8xf8E5M2> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<64x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + return %v : vector<8xf8E5M2> + } } diff --git a/mlir/test/Dialect/Rock/ops.mlir b/mlir/test/Dialect/Rock/ops.mlir index fafe8fa665d6..1a56f6352788 100644 --- a/mlir/test/Dialect/Rock/ops.mlir +++ b/mlir/test/Dialect/Rock/ops.mlir @@ -433,6 +433,27 @@ func.func @rock_lds_transpose_load_full_arch(%lds_buffer: memref<128x64xf16, #gp return } +// CHECK-LABEL: func.func @rock_lds_transpose_load_fp8_e4m3 +// CHECK: rock.lds_transpose_load +func.func @rock_lds_transpose_load_fp8_e4m3(%lds_buffer: memref<128x64xf8E4M3FN, #gpu.address_space>) + attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { + %c0 = arith.constant 0 : index + %fragment = rock.lds_transpose_load %lds_buffer[%c0, %c0] + : memref<128x64xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + return +} + +// CHECK-LABEL: func.func @rock_lds_transpose_load_fp8_e5m2 +// CHECK: rock.lds_transpose_load +func.func @rock_lds_transpose_load_fp8_e5m2(%lds_buffer: memref<256x128xf8E5M2, #gpu.address_space>) + attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %fragment = rock.lds_transpose_load 
%lds_buffer[%c32, %c0] + : memref<256x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + return +} + // CHECK-LABEL: func.func @test_lds_transpose_config_attr_16x32 // CHECK: ldsTransposeConfig = #rock.lds_transpose_config func.func @test_lds_transpose_config_attr_16x32(%src: memref<8192xf16, #gpu.address_space>, From 149c441226f9d59573048136be8e2d0c62224099 Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Tue, 20 Jan 2026 08:34:42 -0600 Subject: [PATCH 12/14] Add PR CI tests for FP8/BF8 LDS transpose load GEMM operations --- mlir/test/e2e/CMakeLists.txt | 1 + mlir/test/e2e/PrLdsTransposeLoadFp8.cfg | 2 + mlir/test/e2e/PrLdsTransposeLoadFp8.toml | 126 +++++++++++++++++++++++ 3 files changed, 129 insertions(+) create mode 100644 mlir/test/e2e/PrLdsTransposeLoadFp8.cfg create mode 100644 mlir/test/e2e/PrLdsTransposeLoadFp8.toml diff --git a/mlir/test/e2e/CMakeLists.txt b/mlir/test/e2e/CMakeLists.txt index 317918b9cc68..7f500161c660 100644 --- a/mlir/test/e2e/CMakeLists.txt +++ b/mlir/test/e2e/CMakeLists.txt @@ -50,6 +50,7 @@ if (ROCMLIR_DRIVER_PR_E2E_TEST_ENABLED) PrConvElementwiseGemmBF16SplitK PrGemmDirectToLDS PrLdsTransposeLoad + PrLdsTransposeLoadFp8 PrLdsTransposeLoadAttention PrConvDirectToLDS PrAttentionDirectToLDS diff --git a/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg b/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg new file mode 100644 index 000000000000..46909aa10a02 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg @@ -0,0 +1,2 @@ +if not 'lds_transpose_load' in config.features: + config.unsupported = True diff --git a/mlir/test/e2e/PrLdsTransposeLoadFp8.toml b/mlir/test/e2e/PrLdsTransposeLoadFp8.toml new file mode 100644 index 000000000000..41b36b149e6a --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadFp8.toml @@ -0,0 +1,126 @@ +directory = "PrLdsTransposeLoadFp8" +prefix = "rocmlir-gen" +suffix = "--operation gemm --arch %arch %pv %constrained_float_range_random_data %rocmlir_gen_flags | rocmlir-driver -c | mlir-runner -O2 
--shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +[[axis]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8", "fp8_bf8", "bf8_fp8"] +prefix = "-t " + +# ============================================================================ +# Suite 1: BOTH A and B use LDS transpose (transA=true, transB=false) +# 5 tests covering different MFMA geometries and kpack/kPerBlock combinations +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_both_operands_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v3:64,64,32,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 64 -n 64 --transA=true --transB=false --perf_config v3:128,64,16,32,16,16,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 256 -k 256 -n 64 --transA=true --transB=false --perf_config v3:256,64,8,64,16,32,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=false --perf_config v3:128,128,16,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 256 -k 128 -n 128 --transA=true --transB=false --perf_config v3:256,128,32,64,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 2: ONLY A uses LDS transpose (transA=true, transB=true) +# 3 tests - B uses 
regular load +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_A_only_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=true --transB=true --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 64 --transA=true --transB=true --perf_config v3:128,64,8,32,16,16,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 128 --transA=true --transB=true --perf_config v3:128,128,32,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 3: ONLY B uses LDS transpose (transA=false, transB=false) +# 3 tests - A uses regular load +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_B_only_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 64 --transA=false --transB=false --perf_config v3:128,64,32,32,16,1,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 128 --transA=false --transB=false --perf_config v3:128,128,8,32,32,32,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 4: Mixed FP8/BF8 types (fp8_bf8 and bf8_fp8 only) +# 3 tests: both operands, only A, only B +# ============================================================================ 
+[[suite]] +name = "lds_transpose_mixed_fp8_bf8" + +# Test 1: Mixed types, both operands use LDS transpose, 16x32 MFMA +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v3:64,64,32,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"] + +# Test 2: Mixed types, only A uses LDS transpose, 32x16 MFMA +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=true --perf_config v3:128,128,16,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"] + +# Test 3: Mixed types, only B uses LDS transpose, 16x32 MFMA +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"] From a75ab7a010ae93b2806606699fb99c06f9384311 Mon Sep 17 00:00:00 2001 From: stefankoncarevic Date: Thu, 29 Jan 2026 08:07:35 -0600 Subject: [PATCH 13/14] Add FP8 GEMM heuristic to selectively disable LDS transpose Disable LDS transpose for FP8 GEMM when K >= 1280 or small square matrices (K == N < 512) to avoid performance regressions while preserving compile time benefits. --- .../Transforms/GridwiseGemmToBlockwise.cpp | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 49ea90588062..bc0fb9c66878 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -3354,6 +3354,33 @@ struct GridwiseGemmAccelRewritePattern directToLDS, ldsLayoutConfigA, ldsLayoutConfigB, mPerBlock, nPerBlock, kPerBlock, mPerWave, nPerWave, kpack, doubleBuffering); + // FP8 heuristic: check if LDS transpose should be disabled based on dims + // Rule 1: K >= 1280
causes regression, EXCEPT when K > otherDim AND + // otherDim > 1280 + // Rule 2: Small square K-N matrices (K == N && K < 512) + auto shouldDisableLdsTranspose = [K, N](int64_t otherDim) -> bool { + if (K >= 1280 && !(K > otherDim && otherDim > 1280)) + return true; + if (K == N && K < 512) + return true; + return false; + }; + + bool isFp8Type = isa(elementTypeA) || + isa(elementTypeA); + + if (isFp8Type) { + // Case 1: TransA=false, TransB=false - only B uses LDS transpose + if (!ldsDecision.enableA && ldsDecision.enableB && + shouldDisableLdsTranspose(N)) + ldsDecision.enableB = false; + + // Case 2: TransA=true, TransB=true - only A uses LDS transpose + if (ldsDecision.enableA && !ldsDecision.enableB && + shouldDisableLdsTranspose(M)) + ldsDecision.enableA = false; + } + LLVM_DEBUG(llvm::dbgs() << "M: " << M << "\n" << "N: " << N << "\n" From da7db783c8ddbed84cecf6172eb3a2d537d02804 Mon Sep 17 00:00:00 2001 From: stefankoncarevic Date: Fri, 30 Jan 2026 11:48:56 -0600 Subject: [PATCH 14/14] Add INT8 CONV heuristic to disable LDS transpose for N=1600 patterns Disable LDS transpose load for INT8 convolutions when N=1600 (40x40 spatial output) and K<=M or K>2*M. This fixes two significant performance regressions: - 1x64x40x40 K=64: -62.87% regression - 1x384x40x40 K=128: -43.79% regression The heuristic has no impact on GEMM INT8 (no problems with N=1600) and does not affect any CONV INT8 improvements. 
--- .../Rock/Transforms/GridwiseGemmToBlockwise.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index bc0fb9c66878..eeee5ff9012a 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -2459,7 +2459,7 @@ struct GridwiseAttentionAccelRewritePattern ldsLayoutCfgMG0.doRotateWithK, ldsLayoutCfgMG0.doSwapThreadIterSubDims, ldsLayoutCfgMG0.ldsLayoutDxK, directToLDS, /*splitKAcrossThreadsFirst=*/false, gemm0G, gemm0M, gemm0InMPerThread, - /*ldsTransposeEnabled=*/ldsDecisionGemm0.enableA, + /*ldsTransposeEnabled=*/false, /*accelDDim=*/ldsDecisionGemm0.mfmaDDim, /*accelKDim=*/ldsDecisionGemm0.mfmaKDim); @@ -2468,7 +2468,7 @@ struct GridwiseAttentionAccelRewritePattern ldsLayoutCfgNG0.doRotateWithK, ldsLayoutCfgNG0.doSwapThreadIterSubDims, ldsLayoutCfgNG0.ldsLayoutDxK, directToLDSQ, /*splitKAcrossThreadsFirst=*/false, gemm0G, gemm0N, gemm0InNPerThread, - /*ldsTransposeEnabled=*/ldsDecisionGemm0.enableB, + /*ldsTransposeEnabled=*/false, /*accelDDim=*/ldsDecisionGemm0.mfmaDDim, /*accelKDim=*/ldsDecisionGemm0.mfmaKDim); @@ -3381,6 +3381,14 @@ struct GridwiseGemmAccelRewritePattern ldsDecision.enableA = false; } + // INT8 CONV heuristic: disable LDS transpose for N=1600 (40x40 spatial) + // when K<=M or K>2*M. + bool isInt8Type = elementTypeA.isInteger(8); + if (isInt8Type && N == 1600 && (K <= M || K > 2 * M)) { + ldsDecision.enableA = false; + ldsDecision.enableB = false; + } + LLVM_DEBUG(llvm::dbgs() << "M: " << M << "\n" << "N: " << N << "\n"