From b551602f2002250b7319a696eb7016fcb7305c2c Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Tue, 20 Jan 2026 05:49:40 -0600 Subject: [PATCH 1/4] Add FP8/BF8 support for LDS transpose load Implement ds_read_tr8_b64 offset formulas for FP8/BF8 MFMA (16x32, 32x16). Enable mixed fp8/bf8 type combinations for GEMM operations on gfx950. --- mlir/include/mlir/Dialect/Rock/IR/RockOps.td | 15 +- mlir/lib/Dialect/Rock/IR/RockDialect.cpp | 21 ++ .../Dialect/Rock/utility/LdsTransposeLoad.cpp | 280 ++++++++++++------ .../Rock/lowering_load_transpose_lds.mlir | 14 + mlir/test/Dialect/Rock/ops.mlir | 21 ++ 5 files changed, 261 insertions(+), 90 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index 8d70078b06f0..a9fd3f28951f 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -1221,9 +1221,10 @@ defvar SameShapeVectorOfI1 = [{ def Rock_LDSTransposeLoadOp : Rock_Op<"lds_transpose_load", [DeclareOpInterfaceMethods< MemoryEffectsOpInterface>]>, - Arguments<(ins Arg, "LDS source buffer">:$source, + Arguments<(ins Arg, + "LDS source buffer">:$source, Variadic:$indices)>, - Results<(outs VectorOfLengthAndType<[4], [F16, BF16]>:$result)> { + Results<(outs AnyVectorOfNonZeroRank:$result)> { let summary = "Hardware-assisted LDS transpose load for matrix accelerator tile"; let description = [{ @@ -1233,8 +1234,14 @@ def Rock_LDSTransposeLoadOp instruction geometry (`dDim × instrK`), where: - `dDim` is the accelerator M/N dimension (e.g., 16 or 32) - `instrK` is the accelerator K dimension (e.g., 8, 16, or 32) - The operation returns a vector of 4 elements per thread containing - transposed elements in a layout suitable for matrix accelerator instructions. + + For 16-bit types (f16, bf16): + - Uses ds_read_tr16_b64 instruction + - Returns vector<4xtype> (4 elements per thread) + + For 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): + - Uses ds_read_tr8_b64 instruction + - Returns vector<8xtype> (8 elements per thread) Benefits: - Reduces LDS bank conflicts through optimized access patterns diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 5b7c3131817e..27cce5cc9681 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -2161,6 +2161,27 @@ LogicalResult LDSTransposeLoadOp::verify() { << srcElemType << ")"; } + // Verify result vector length based on element type: + // - 16-bit types (f16, bf16): ds_read_tr16_b64 returns 4 elements + // - 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): ds_read_tr8_b64 + // returns 8 elements + int64_t expectedVecLen; + if (srcElemType.isF16() || srcElemType.isBF16()) { + expectedVecLen = 4; + } else if (isa(srcElemType) || + isa(srcElemType)) { + expectedVecLen = 8; + } else { + return emitOpError("unsupported element type for LDS transpose load: ") + << srcElemType; + } + + if (resultType.getNumElements() != expectedVecLen) { + return emitOpError("expected result vector of ") + << expectedVecLen << " elements for " << srcElemType + << " type, but got " << resultType.getNumElements(); + } + // Check hardware support using AmdArchDb StringRef arch = rock::getArchValue(*this); AmdArchInfo archInfo = rock::lookupArchInfo(arch); diff --git a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp index f65e1487c847..55f35c8b9084 100644 --- a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp +++ b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp @@ -21,8 +21,13 @@ // to the LDS transpose load operation in an accelerator-friendly layout. // // It is intended to simplify the IR generation logic and ensure -// consistent handling of f16/bf16 matrix accelerator tile loads from LDS -// memory. +// consistent handling of f16/bf16/fp8/bf8 matrix accelerator tile loads from +// LDS memory. +// +// Supported element types: +// - f16, bf16: uses ds_read_tr16_b64 (returns 4 elements per thread) +// - f8E4M3FN, f8E5M2 (OCP FP8): uses ds_read_tr8_b64 (returns 8 elements per +// thread) // //===----------------------------------------------------------------------===// @@ -47,6 +52,32 @@ namespace { bool archSupported(StringRef arch) { return arch.contains("gfx950"); } +// Check if element type is supported for LDS transpose load +// - f16, bf16: ds_read_tr16_b64 (4 elements) +// - f8E4M3FN, f8E5M2 (OCP FP8 for gfx950): ds_read_tr8_b64 (8 elements) +static bool isSupportedElementType(Type t) { + return t.isF16() || t.isBF16() || isa(t) || + isa(t); +} + +// Check if element type is 8-bit float (FP8 E4M3 or BF8 E5M2) +// Used for: +// 1. Selecting ds_read_tr8_b64 vs ds_read_tr16_b64 +// 2. Checking mixed-type compatibility (fp8+bf8 combinations are valid) +static bool isFp8Type(Type t) { + return isa(t) || isa(t); +} + +// Returns the number of elements returned by LDS transpose load instruction +static int64_t getTransposeLoadVectorLength(Type elemType) { + if (elemType.isF16() || elemType.isBF16()) { + return 4; // ds_read_tr16_b64 + } else if (isFp8Type(elemType)) { + return 8; // ds_read_tr8_b64 + } + llvm_unreachable("Unsupported element type for LDS transpose load"); +} + // Validates MFMA geometry for LDS transpose support. // Only specific combinations are supported: (16,16), (16,32), (32,8), (32,16) static bool isValidMfmaGeometry(int64_t dDim, int64_t kDim) { @@ -119,7 +150,16 @@ static Decision makeDecision(StringRef arch, Type elemTypeA, Type elemTypeB, return dec; } - if (elemTypeA != elemTypeB || !(elemTypeA.isF16() || elemTypeA.isBF16())) + // Check type compatibility: + // - Same types are always allowed (f16==f16, fp8==fp8, etc.) + // - Mixed FP8/BF8 combinations are allowed (hardware supports mixed fp8/bf8 + // MFMA: + // mfma_f32_16x16x32_fp8_fp8, mfma_f32_16x16x32_fp8_bf8, etc.) + // - Other mixed types are NOT allowed (e.g., f16 with fp8) + bool typesCompatible = (elemTypeA == elemTypeB) || + (isFp8Type(elemTypeA) && isFp8Type(elemTypeB)); + if (!typesCompatible || !isSupportedElementType(elemTypeA) || + !isSupportedElementType(elemTypeB)) return dec; // Validate MFMA geometry @@ -489,8 +529,9 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, // emitPanelLoad - Emit an LDS transpose load operation // // Computes the final LDS offset and emits a hardware LDS transpose load -// instruction (ds_read_tr16_b64). This instruction always returns vector<4xf16> -// regardless of the layout. +// instruction: +// - ds_read_tr16_b64 for f16/bf16: returns vector<4> +// - ds_read_tr8_b64 for fp8/bf8: returns vector<8> // // The final offset is computed as: final_offset = k_base * ldsStride + m_base // where ldsStride depends on the operand (mPerBlock for A, nPerBlock for B). @@ -503,10 +544,10 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, // kBase - K dimension base offset for this panel // mBase - M/N dimension base offset for this panel // ldsStride - Stride between K rows in LDS (mPerBlock or nPerBlock) -// panelVecType - Result type (always vector<4xf16> or vector<4xbf16>) +// panelVecType - Result type (vector<4> for f16/bf16, vector<8> for fp8/bf8) // // Returns: -// The loaded panel vector (vector<4xf16/bf16>) +// The loaded panel vector (vector<4> for f16/bf16, vector<8> for fp8/bf8) //===----------------------------------------------------------------------===// static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, Value kBase, Value mBase, Value ldsStride, @@ -515,7 +556,7 @@ static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, Value kOffset = arith::MulIOp::create(b, loc, kBase, ldsStride); Value finalOffset = arith::AddIOp::create(b, loc, mBase, kOffset); - // Emit hardware LDS transpose load: ds_read_tr16_b64 + // Emit hardware LDS transpose load auto loadOp = rock::LDSTransposeLoadOp::create(b, loc, panelVecType, rawSrc, ValueRange{finalOffset}); @@ -530,9 +571,9 @@ static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, // elements (ds_read_tr16_b64 always returns vector<4xf16>). // // Parameters: -// panelVectors - Array of loaded panel vectors (each is vector<4xf16>) -// dest - Destination memref (rank-1, scalar layout) -// targetElems - Maximum number of elements to write +// panelVectors - Array of loaded panel vectors (vector<4> for f16/bf16, +// vector<8> for fp8/bf8) dest - Destination memref (rank-1, scalar +// layout) targetElems - Maximum number of elements to write // // Returns: // success() if all target elements were written @@ -546,14 +587,16 @@ writePanelVectorsToDestination(PatternRewriter &b, Location loc, int64_t produced = 0; // Extract elements per vector from the actual vector type - // Hardware instruction ds_read_tr16_b64 always returns vector<4xf16> + // Hardware instructions: + // - ds_read_tr16_b64 returns vector<4xf16/bf16> + // - ds_read_tr8_b64 returns vector<8xfp8> assert(!panelVectors.empty() && "Panel vectors array must not be empty"); auto panelVecType = cast(panelVectors[0].getType()); int64_t elementsPerVector = panelVecType.getShape()[0]; - // Verify hardware constraint: ds_read_tr16_b64 returns exactly 4 elements - assert(elementsPerVector == 4 && - "LDS transpose load must produce vector<4xf16> per panel"); + // Verify hardware constraint: 4 elements for 16-bit, 8 elements for 8-bit + assert((elementsPerVector == 4 || elementsPerVector == 8) && + "LDS transpose load must produce vector<4> or vector<8> per panel"); LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Writing " << panelVectors.size() << " panel vectors (" @@ -609,60 +652,114 @@ writePanelVectorsToDestination(PatternRewriter &b, Location loc, //===----------------------------------------------------------------------===// static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, int64_t dDim, int64_t kDim, - Value lane) { - Value c16 = arith::ConstantIndexOp::create(b, loc, 16); - Value c4 = arith::ConstantIndexOp::create(b, loc, 4); + Value lane, Type elemType) { + // Common constants used by both FP8 and F16/BF16 Value c2 = arith::ConstantIndexOp::create(b, loc, 2); + Value c4 = arith::ConstantIndexOp::create(b, loc, 4); + Value c16 = arith::ConstantIndexOp::create(b, loc, 16); - Value blockId = arith::DivUIOp::create(b, loc, lane, c16); - Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); - - // Base offset calculations - Value mOffsetBase = arith::MulIOp::create( - b, loc, arith::RemUIOp::create(b, loc, laneInBlock, c4), c4); - Value kOffsetBase = arith::DivUIOp::create(b, loc, laneInBlock, c4); + Value kOffsetBase, mOffsetBase; + + if (isFp8Type(elemType)) { + Value c8 = arith::ConstantIndexOp::create(b, loc, 8); + + if (dDim == 16 && kDim == 32) { + // FP8/BF8 16x32: 4-block formula + // Block layout: + // Block 0 (T0-T15): K=0..7 + // Block 1 (T16-T31): K=8..15 + // Block 2 (T32-T47): K=16..23 + // Block 3 (T48-T63): K=24..31 + + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); + Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); + + // kOffsetBase = k_local + block_id * 8 + Value blockKOffset = arith::MulIOp::create(b, loc, blockId, c8); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, blockKOffset); + + // mOffsetBase = m_parity * 8 + mOffsetBase = arith::MulIOp::create(b, loc, mParity, c8); + + } else if (dDim == 32 && kDim == 16) { + // FP8 32x16: 4-block formula (VERIFIED from HIP testing) + // Block layout: + // Block 0 (T0-T15): M=0..15, K=0..7 + // Block 1 (T16-T31): M=16..31, K=0..7 + // Block 2 (T32-T47): M=0..15, K=8..15 + // Block 3 (T48-T63): M=16..31, K=8..15 + + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); + Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); + + // m_block = block_id % 2, k_block = block_id / 2 + Value mBlock = arith::RemUIOp::create(b, loc, blockId, c2); + Value kBlock = arith::DivUIOp::create(b, loc, blockId, c2); + + // kOffsetBase = k_local + k_block * 8 + Value kBlockOffset = arith::MulIOp::create(b, loc, kBlock, c8); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, kBlockOffset); + + // mOffsetBase = m_parity * 8 + m_block * 16 + Value mParityOffset = arith::MulIOp::create(b, loc, mParity, c8); + Value mBlockOffset = arith::MulIOp::create(b, loc, mBlock, c16); + mOffsetBase = arith::AddIOp::create(b, loc, mParityOffset, mBlockOffset); - SmallVector panelOffsets; + } else { + llvm_unreachable("Unsupported FP8 MFMA geometry in getBasePanelOffsets"); + } - if (dDim == 16 && kDim == 32) { - // 16x32 layout - panelOffsets = {kOffsetBase, mOffsetBase}; - } else if (dDim == 16 && kDim == 16) { - // 16x16 layout - // kbase = kOffsetBase + (blockId * 4) - Value kBase = arith::AddIOp::create( - b, loc, arith::MulIOp::create(b, loc, blockId, c4), kOffsetBase); - panelOffsets = {kBase, mOffsetBase}; - } else if (dDim == 32 && kDim == 16) { - // 32x16 layout - // mbase = mOffsetBase + (blockId % 2) * 16 - Value mBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::RemUIOp::create(b, loc, blockId, c2), c16), - mOffsetBase); - panelOffsets = {kOffsetBase, mBase}; - } else if (dDim == 32 && kDim == 8) { - // 32x8 layout - // k_base_local = kOffsetBase + (blockId / 2) * 4 - Value kBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::DivUIOp::create(b, loc, blockId, c2), c4), - kOffsetBase); - - // m_offset_base = mOffsetBase + (blockId % 2) * 16 - Value mBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::RemUIOp::create(b, loc, blockId, c2), c16), - mOffsetBase); - panelOffsets = {kBase, mBase}; } else { - llvm_unreachable("Unsupported MFMA geometry in getBasePanelOffsets"); + // F16/BF16 uses block-based lane mapping + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + + // Base calculations common to all F16/BF16 geometries + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c4); + Value mLocal = arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, laneInBlock, c4), c4); + + if (dDim == 16 && kDim == 32) { + // 16x32: direct mapping + kOffsetBase = kLocal; + mOffsetBase = mLocal; + + } else if (dDim == 16 && kDim == 16) { + // 16x16: k += blockId * 4 + kOffsetBase = arith::AddIOp::create( + b, loc, kLocal, arith::MulIOp::create(b, loc, blockId, c4)); + mOffsetBase = mLocal; + + } else if (dDim == 32 && kDim == 16) { + // 32x16: m += (blockId % 2) * 16 + kOffsetBase = kLocal; + mOffsetBase = arith::AddIOp::create( + b, loc, mLocal, + arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, blockId, c2), c16)); + + } else if (dDim == 32 && kDim == 8) { + // 32x8: k += (blockId / 2) * 4, m += (blockId % 2) * 16 + kOffsetBase = arith::AddIOp::create( + b, loc, kLocal, + arith::MulIOp::create( + b, loc, arith::DivUIOp::create(b, loc, blockId, c2), c4)); + mOffsetBase = arith::AddIOp::create( + b, loc, mLocal, + arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, blockId, c2), c16)); + + } else { + llvm_unreachable( + "Unsupported F16/BF16 MFMA geometry in getBasePanelOffsets"); + } } - return panelOffsets; + return {kOffsetBase, mOffsetBase}; } //===----------------------------------------------------------------------===// @@ -683,6 +780,7 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, // dDim - MFMA D dimension (M or N, 16 or 32) // kDim - MFMA K dimension (8, 16, or 32) // lane - Thread's lane ID within the workgroup +// elemType - Element type (f16, bf16, fp8, or bf8) for selecting lane mapping // // Returns: // std::pair: @@ -691,8 +789,10 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, //===----------------------------------------------------------------------===// static std::pair computeLDSBaseOffsets(PatternRewriter &b, Location loc, int64_t dDim, - int64_t kDim, Value lane) { - SmallVector offsets = getBasePanelOffsets(b, loc, dDim, kDim, lane); + int64_t kDim, Value lane, + Type elemType) { + SmallVector offsets = + getBasePanelOffsets(b, loc, dDim, kDim, lane, elemType); LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Computed LDS base offsets for " << dDim << "x" << kDim << ": " @@ -1098,8 +1198,9 @@ static Value computeFinalMNOffset(PatternRewriter &b, Location loc, // emitThreadwiseHWTranspose - Lower threadwise_read_into to HW transpose loads //===----------------------------------------------------------------------===// // Lowers threadwise_read_into with LDS transpose config into hardware transpose -// load instructions (ds_read_tr16_b64) that read from LDS in MFMA-friendly -// order. +// load instructions that read from LDS in MFMA-friendly order: +// - ds_read_tr16_b64 for f16/bf16 (returns vector<4>) +// - ds_read_tr8_b64 for fp8/bf8 (returns vector<8>) // // Algorithm: // 1. Extract config: MFMA geometry (dDim, kDim), tiling params, operand kind @@ -1112,8 +1213,8 @@ static Value computeFinalMNOffset(PatternRewriter &b, Location loc, // - Outer: M/N tiles (all at once for double-buffering, one at a time // otherwise) // - Inner: K tiles (1 load for single-rate, 2 loads for double-rate layouts) -// 6. For each iteration: compute final LDS offset, emit ds_read_tr16_b64 -// instruction (returns vector<4xf16>) +// 6. For each iteration: compute final LDS offset, emit LDS transpose load +// instruction (ds_read_tr16_b64 for f16/bf16, ds_read_tr8_b64 for fp8/bf8) // 7. Extract elements from panel vectors and write sequentially to destination // // Example: 16x32 layout, 1 M-tile, 2 K-tiles → 2 ds_read_tr16_b64 calls → 8 @@ -1167,23 +1268,27 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, int64_t ldsStride = (operand == OperandKind::A) ? mPerBlock : nPerBlock; // Determine if this is a double-rate instruction - // Double-rate ONLY for (32,16) and (16,32) MFMA - // (16,16) and (32,8) are SINGLE-RATE - bool isDoubleRate = - (dDim == 32 && instrK == 16) || (dDim == 16 && instrK == 32); - - // Each ds_read_tr16_b64 call ALWAYS returns vector<4xf16> - // For double-rate, we make 2 calls and store all 8 elements separately - VectorType panelVecType = VectorType::get({4}, elemType); + // Double-rate ONLY for (32,16) and (16,32) MFMA with F16/BF16 + // FP8/BF8 uses ds_read_tr8_b64 which returns 8 elements, so (16,32) and + // (32,16) are SINGLE-RATE for FP8/BF8 (16,16) and (32,8) are always + // SINGLE-RATE + bool isDoubleRate = !isFp8Type(elemType) && ((dDim == 32 && instrK == 16) || + (dDim == 16 && instrK == 32)); + + // Determine vector length based on element type: + // - f16/bf16: ds_read_tr16_b64 returns vector<4> + // - fp8/bf8: ds_read_tr8_b64 returns vector<8> + int64_t vecLen = getTransposeLoadVectorLength(elemType); + VectorType panelVecType = VectorType::get({vecLen}, elemType); // panelVectors will contain: - // - Single-rate: 1 vector<4xf16> per K tile - // - Double-rate: 2 vector<4xf16> per K tile (low + high) + // - Single-rate: 1 vector per K tile + // - Double-rate (f16/bf16 only): 2 vectors per K tile (low + high) SmallVector panelVectors; // Get base offsets using computeLDSBaseOffsets helper auto [k_base_local, m_offset_base] = - computeLDSBaseOffsets(b, loc, dDim, instrK, lane); + computeLDSBaseOffsets(b, loc, dDim, instrK, lane, elemType); // K stride per tile: instrK (MFMA K dimension) int64_t kTileStride = instrK; @@ -1238,20 +1343,21 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, useDynamicMnIndex, waveOffsetStrideVal, tileOffsetStrideVal); if (!isDoubleRate) { - // SINGLE-RATE (L32x8, L16x16): One load per K tile + // SINGLE-RATE (L32x8, L16x16, or FP8/BF8): One load per K tile Value k_base = computePanelFinalOffset(b, loc, isDoubleRate, k_base_local, kOffsetBase, kIdx, kTileStrideVal); - // Emit LDS transpose load for this K tile (single-rate: one per K tile) + // Emit LDS transpose load for this K tile Value panelVec = emitPanelLoad(b, loc, rawSrc, k_base, m_base, ldsStrideVal, panelVecType); panelVectors.push_back(panelVec); } else { - // DOUBLE-RATE (L32x16, L16x32): TWO loads per K tile - // Each load returns vector<4xf16>, total 8 elements per K tile - // Compute K offsets for low and high halves + // DOUBLE-RATE (L32x16, L16x32 with F16/BF16 only): TWO loads per K tile + // Each load returns vector<4> for f16/bf16, total 8 elements per K tile + // Note: FP8/BF8 is NEVER double-rate (ds_read_tr8_b64 returns 8 + // elements) Compute K offsets for low and high halves Value k_base_low = computePanelFinalOffset( b, loc, isDoubleRate, k_base_local, kOffsetBase, kIdx, kTileStrideVal, /*isHighHalf=*/false); @@ -1281,8 +1387,10 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, int64_t loadsPerKTile = isDoubleRate ? 2 : 1; int64_t expectedLoads = actualMnTiles * kPanels * loadsPerKTile; - // Each load ALWAYS produces 4 elements (ds_read_tr16_b64 → vector<4xf16>) - int64_t sliceElems = expectedLoads * 4; + // Each load produces vecLen elements: + // - f16/bf16: 4 elements (ds_read_tr16_b64) + // - fp8/bf8: 8 elements (ds_read_tr8_b64) + int64_t sliceElems = expectedLoads * vecLen; // Verify we generated the expected number of loads if (panelVectors.size() != (size_t)expectedLoads) { diff --git a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir index 29a2f190a06b..b6a64a6985f7 100644 --- a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir +++ b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir @@ -14,4 +14,18 @@ module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xbf16, #gpu.address_space> -> vector<4xbf16> return %v : vector<4xbf16> } + +// CHECK-LABEL: func @test_load_transpose_fp8_e4m3 + func.func @test_load_transpose_fp8_e4m3(%src: memref<128x256xf8E4M3FN, #gpu.address_space>, %i: index, %j: index) -> vector<8xf8E4M3FN> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<128x256xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + %v = rock.lds_transpose_load %src[%i, %j] : memref<128x256xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + return %v : vector<8xf8E4M3FN> + } + +// CHECK-LABEL: func @test_load_transpose_fp8_e5m2 + func.func @test_load_transpose_fp8_e5m2(%src: memref<64x128xf8E5M2, #gpu.address_space>, %i: index, %j: index) -> vector<8xf8E5M2> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<64x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + return %v : vector<8xf8E5M2> + } } diff --git a/mlir/test/Dialect/Rock/ops.mlir b/mlir/test/Dialect/Rock/ops.mlir index fafe8fa665d6..1a56f6352788 100644 --- a/mlir/test/Dialect/Rock/ops.mlir +++ b/mlir/test/Dialect/Rock/ops.mlir @@ -433,6 +433,27 @@ func.func @rock_lds_transpose_load_full_arch(%lds_buffer: memref<128x64xf16, #gp return } +// CHECK-LABEL: func.func @rock_lds_transpose_load_fp8_e4m3 +// CHECK: rock.lds_transpose_load +func.func @rock_lds_transpose_load_fp8_e4m3(%lds_buffer: memref<128x64xf8E4M3FN, #gpu.address_space>) + attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { + %c0 = arith.constant 0 : index + %fragment = rock.lds_transpose_load %lds_buffer[%c0, %c0] + : memref<128x64xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + return +} + +// CHECK-LABEL: func.func @rock_lds_transpose_load_fp8_e5m2 +// CHECK: rock.lds_transpose_load +func.func @rock_lds_transpose_load_fp8_e5m2(%lds_buffer: memref<256x128xf8E5M2, #gpu.address_space>) + attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %fragment = rock.lds_transpose_load %lds_buffer[%c32, %c0] + : memref<256x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + return +} + // CHECK-LABEL: func.func @test_lds_transpose_config_attr_16x32 // CHECK: ldsTransposeConfig = #rock.lds_transpose_config func.func @test_lds_transpose_config_attr_16x32(%src: memref<8192xf16, #gpu.address_space>, From 473c20550965a85e89b3db689f77951e8248dbdd Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Tue, 20 Jan 2026 08:34:42 -0600 Subject: [PATCH 2/4] Add PR CI tests for FP8/BF8 LDS transpose load GEMM operations --- mlir/test/e2e/CMakeLists.txt | 1 + mlir/test/e2e/PrLdsTransposeLoadFp8.cfg | 2 + mlir/test/e2e/PrLdsTransposeLoadFp8.toml | 126 +++++++++++++++++++++++ 3 files changed, 129 insertions(+) create mode 100644 mlir/test/e2e/PrLdsTransposeLoadFp8.cfg create mode 100644 mlir/test/e2e/PrLdsTransposeLoadFp8.toml diff --git a/mlir/test/e2e/CMakeLists.txt b/mlir/test/e2e/CMakeLists.txt index ca0d7c2473f3..1ffabef3538a 100644 --- a/mlir/test/e2e/CMakeLists.txt +++ b/mlir/test/e2e/CMakeLists.txt @@ -50,6 +50,7 @@ if (ROCMLIR_DRIVER_PR_E2E_TEST_ENABLED) PrConvElementwiseGemmBF16SplitK PrGemmDirectToLDS PrLdsTransposeLoad + PrLdsTransposeLoadFp8 PrLdsTransposeLoadAttention PrConvDirectToLDS PrAttentionDirectToLDS diff --git a/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg b/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg new file mode 100644 index 000000000000..46909aa10a02 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg @@ -0,0 +1,2 @@ +if not 'lds_transpose_load' in config.features: + config.unsupported = True diff --git a/mlir/test/e2e/PrLdsTransposeLoadFp8.toml b/mlir/test/e2e/PrLdsTransposeLoadFp8.toml new file mode 100644 index 000000000000..41b36b149e6a --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadFp8.toml @@ -0,0 +1,126 @@ +directory = "PrLdsTransposeLoadFp8" +prefix = "rocmlir-gen" +suffix = "--operation gemm --arch %arch %pv %constrained_float_range_random_data %rocmlir_gen_flags | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +[[axis]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8", "fp8_bf8", "bf8_fp8"] +prefix = "-t " + +# ============================================================================ +# Suite 1: BOTH A and B use LDS transpose (transA=true, transB=false) +# 5 tests covering different MFMA geometries and kpack/kPerBlock combinations +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_both_operands_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v3:64,64,32,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 64 -n 64 --transA=true --transB=false --perf_config v3:128,64,16,32,16,16,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 256 -k 256 -n 64 --transA=true --transB=false --perf_config v3:256,64,8,64,16,32,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=false --perf_config v3:128,128,16,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 256 -k 128 -n 128 --transA=true --transB=false --perf_config v3:256,128,32,64,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 2: ONLY A uses LDS transpose (transA=true, transB=true) +# 3 tests - B uses regular load +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_A_only_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=true --transB=true --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 64 --transA=true --transB=true --perf_config v3:128,64,8,32,16,16,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 128 --transA=true --transB=true --perf_config v3:128,128,32,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 3: ONLY B uses LDS transpose (transA=false, transB=false) +# 3 tests - A uses regular load +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_B_only_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 64 --transA=false --transB=false --perf_config v3:128,64,32,32,16,1,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 128 --transA=false --transB=false --perf_config v3:128,128,8,32,32,32,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 4: Mixed FP8/BF8 types (fp8_bf8 and bf8_fp8 only) +# 3 tests: both operands, only A, only B +# ============================================================================ +[[suite]] +name = "lds_transpose_mixed_fp8_bf8" + +# Test 1: Mixed types, both operands use LDS transpose, 16x32 MFMA +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v3:64,64,32,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"] + +# Test 2: Mixed types, only A uses LDS transpose, 32x16 MFMA +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=true --perf_config v3:128,128,16,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"] + +# Test 3: Mixed types, only B uses LDS transpose, 16x32 MFMA +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"] From e0bb0cddc207e37f5283b282ff7d6ff275dafb91 Mon Sep 17 00:00:00 2001 From: stefankoncarevic Date: Thu, 29 Jan 2026 08:07:35 -0600 Subject: [PATCH 3/4] Add FP8 GEMM heuristic to selectively disable LDS transpose Disable LDS transpose for FP8 GEMM when K >= 1280 or small square matrices (K == N < 512) to avoid performance regressions while preserving compile time benefits.Add FP8 GEMM heuristic to selectively disable LDS transpose --- .../Transforms/GridwiseGemmToBlockwise.cpp | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 201748cc206b..d7901b1135df 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -3377,6 +3377,33 @@ struct GridwiseGemmAccelRewritePattern directToLDS, ldsLayoutConfigA, ldsLayoutConfigB, mPerBlock, nPerBlock, kPerBlock, mPerWave, nPerWave, kpack, doubleBuffering); + // FP8 heuristic: check if LDS transpose should be disabled based on dims + // Rule 1: K >= 1280 causes regression, EXCEPT when K > otherDim AND + // otherDim > 1280 + // Rule 2: Small square K-N matrices (K == N && K < 512) + auto shouldDisableLdsTranspose = [K, N](int64_t otherDim) -> bool { + if (K >= 1280 && !(K > otherDim && otherDim > 1280)) + return true; + if (K == N && K < 512) + return true; + return false; + }; + + bool isFp8Type = isa(elementTypeA) || + isa(elementTypeA); + + if (isFp8Type) { + // Case 1: TransA=false, TransB=false - only B uses LDS transpose + if (!ldsDecision.enableA && ldsDecision.enableB && + shouldDisableLdsTranspose(N)) + ldsDecision.enableB = false; + + // Case 2: TransA=true, TransB=true - only A uses LDS transpose + if (ldsDecision.enableA && !ldsDecision.enableB && + shouldDisableLdsTranspose(M)) + ldsDecision.enableA = false; + } + LLVM_DEBUG(llvm::dbgs() << "M: " << M << "\n" << "N: " << N << "\n" From 68d1fdccd2d6f0d1f090087829c0db911d2e19e9 Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Wed, 25 Feb 2026 07:55:22 -0600 Subject: [PATCH 4/4] Add 32x64 scaled FP8 MFMA support for LDS transpose load - Add (32,64) geometry support in LdsTransposeLoad.cpp: - New getBasePanelOffsets() branch for 32x64 quad-rate formula - k_block = block_id / 2, m_block = block_id % 2 - kOffsetBase = k_local + k_block * 32 - mOffsetBase = m_parity * 8 + m_block * 16 - Update isQuadRate detection to include 32x64 - Add validation for (32,64) in RockDialect.cpp - Extend tuning ranges for scaled FP8 testing: - kPackPerBlock: added 64 - kPack: added 32 (for k_base=32) Co-authored-by: Cursor --- .../mlir/Dialect/Rock/IR/RockAttrDefs.td | 2 +- mlir/include/mlir/Dialect/Rock/IR/RockOps.td | 2 +- .../Dialect/Rock/utility/LdsTransposeLoad.h | 17 ++- mlir/lib/Dialect/Rock/IR/RockDialect.cpp | 11 +- .../Dialect/Rock/Tuning/RockTuningImpl.cpp | 4 +- .../Dialect/Rock/utility/LdsTransposeLoad.cpp | 127 +++++++++++++++--- 6 files changed, 135 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td index 796e6d6dae9c..1f682c78acd5 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td @@ -578,7 +578,7 @@ def Rock_LDSTransposeConfigAttr : Rock_Attr<"LDSTransposeConfig", []> { and tiling parameters. - DDim: Matrix-multiply accelerator instruction D dimension (M or N, typically 16 or 32) - - KDim: Matrix-multiply accelerator instruction K dimension (typically 8, 16, or 32) + - KDim: Matrix-multiply accelerator instruction K dimension (typically 8, 16, 32, or 128 for scaled FP8) - mPerBlock: M dimension size per block - nPerBlock: N dimension size per block - kPerBlock: K dimension size per block diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index a9fd3f28951f..0cbcd3a80049 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -1233,7 +1233,7 @@ def Rock_LDSTransposeLoadOp The tile dimensions match the selected matrix-multiply accelerator instruction geometry (`dDim × instrK`), where: - `dDim` is the accelerator M/N dimension (e.g., 16 or 32) - - `instrK` is the accelerator K dimension (e.g., 8, 16, or 32) + - `instrK` is the accelerator K dimension (e.g., 8, 16, 32, or 128 for scaled FP8) For 16-bit types (f16, bf16): - Uses ds_read_tr16_b64 instruction diff --git a/mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h b/mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h index 07d87d6aca33..c1f9c6cd15bc 100644 --- a/mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h +++ b/mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h @@ -21,8 +21,17 @@ // to the LDS transpose load operation in an accelerator-friendly layout. // // It is intended to simplify the IR generation logic and ensure -// consistent handling of f16/bf16 matrix accelerator tile loads from LDS -// memory. +// consistent handling of f16/bf16/fp8/bf8 matrix accelerator tile loads from +// LDS memory. +// +// Supported element types: +// - f16, bf16: uses ds_read_tr16_b64 (returns 4 elements per thread) +// - f8E4M3FN, f8E5M2 (OCP FP8): uses ds_read_tr8_b64 (returns 8 elements) +// +// Supported MFMA geometries: +// - Standard: (16,16), (16,32), (32,8), (32,16) - single-rate or double-rate +// - Scaled FP8: (16,128) - quad-rate (4 ds_read_tr8 calls per K tile) +// - Scaled FP8: (32,64) - quad-rate (4 ds_read_tr8 calls per K tile) // //===----------------------------------------------------------------------===// @@ -43,7 +52,7 @@ enum class OperandKind { A, B }; // Build LDS transpose config attribute from already-computed MFMA params. // Used in BlockwiseLoadTileToThreadwise when decision was made upstream. // Requires mfmaDDim > 0 and mfmaKDim > 0 (asserted). -// Valid combinations: (16,16), (16,32), (32,8), (32,16) +// Valid combinations: (16,16), (16,32), (16,128), (32,8), (32,16), (32,64) LDSTransposeConfigAttr buildTransposeAttrFromParams( PatternRewriter &rewriter, int64_t mfmaDDim, int64_t mfmaKDim, int64_t mPerBlock, int64_t nPerBlock, int64_t kPerBlock, int64_t mPerWave, @@ -60,7 +69,7 @@ struct LDSTransposeDecision { bool enableA{false}; // Enable for operand A bool enableB{false}; // Enable for operand B int64_t mfmaDDim{0}; // MFMA D dimension (M or N, 16 or 32) - int64_t mfmaKDim{0}; // MFMA K dimension (8, 16, or 32) + int64_t mfmaKDim{0}; // MFMA K dimension (8, 16, 32, 64, or 128) }; // Decides whether to enable LDS transpose for operands A and B diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 27cce5cc9681..4a4c5c6c0a14 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -560,8 +560,13 @@ LogicalResult TransformMapAttr::verify( // Helper function to check valid MFMA geometry for LDS transpose static bool isValidLdsTransposeMfmaGeometry(int64_t dDim, int64_t kDim) { - return (dDim == 16 && (kDim == 16 || kDim == 32)) || - (dDim == 32 && (kDim == 8 || kDim == 16)); + // Supported geometries: + // - (16,16), (16,32): standard FP16/BF16/FP8 MFMA + // - (16,128): scaled FP8 MFMA (mfma_scale_f32_16x16x128_f8f6f4) + // - (32,8), (32,16): standard FP16/BF16/FP8 MFMA + // - (32,64): scaled FP8 MFMA (mfma_scale_f32_32x32x64_f8f6f4) + return (dDim == 16 && (kDim == 16 || kDim == 32 || kDim == 128)) || + (dDim == 32 && (kDim == 8 || kDim == 16 || kDim == 64)); } LogicalResult LDSTransposeConfigAttr::verify( @@ -573,7 +578,7 @@ LogicalResult LDSTransposeConfigAttr::verify( if (!isValidLdsTransposeMfmaGeometry(dDim, kDim)) { return emitError() << "invalid MFMA geometry (" << dDim << "x" << kDim << ") for LDS transpose - valid combinations: " - "(16,16), (16,32), (32,8), (32,16)"; + "(16,16), (16,32), (16,128), (32,8), (32,16), (32,64)"; } // Validate positive dimensions diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index ca0afea45388..ba3e6d85f282 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -192,9 +192,9 @@ getAccelRangeGemm(RockGemmWrapperInterface gemmOp, TuningParamSetKind kind) { std::vector> validRangeAccelGemmParams8BitReduction = { dPerBlock, // M/block dPerBlock, // N/block - {4, 8, 16, 32}, // K/block + {4, 8, 16, 32, 64}, // K/block (added 64 for scaled FP8 MFMA) {16, 32}, // MnPerXdl - {1, 4, 8, 16}, // kPack + {1, 4, 8, 16, 32}, // kPack (added 32 for scaled FP8 k_base=32) getSchedules(gemmOp, kind), // scheduleVersion {0, 1}}; // forceUnroll diff --git a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp index 55f35c8b9084..47ee14b42d01 100644 --- a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp +++ b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp @@ -79,10 +79,14 @@ static int64_t getTransposeLoadVectorLength(Type elemType) { } // Validates MFMA geometry for LDS transpose support. -// Only specific combinations are supported: (16,16), (16,32), (32,8), (32,16) +// Supported combinations: +// - (16,16), (16,32): standard FP16/BF16/FP8 MFMA (single-rate) +// - (16,128): scaled FP8 MFMA (mfma_scale_f32_16x16x128_f8f6f4) (quad-rate) +// - (32,8), (32,16): standard FP16/BF16/FP8 MFMA (single-rate or double-rate) +// - (32,64): scaled FP8 MFMA (mfma_scale_f32_32x32x64_f8f6f4) (quad-rate) static bool isValidMfmaGeometry(int64_t dDim, int64_t kDim) { - return (dDim == 16 && (kDim == 16 || kDim == 32)) || - (dDim == 32 && (kDim == 8 || kDim == 16)); + return (dDim == 16 && (kDim == 16 || kDim == 32 || kDim == 128)) || + (dDim == 32 && (kDim == 8 || kDim == 16 || kDim == 64)); } // Shape of a single MFMA instruction (internal use only). @@ -346,7 +350,7 @@ LDSTransposeConfigAttr buildTransposeAttrFromParams( "MFMA geometry must be set when building transpose attributes"); assert(isValidMfmaGeometry(mfmaDDim, mfmaKDim) && "Invalid MFMA geometry for LDS transpose - valid: (16,16), (16,32), " - "(32,8), (32,16)"); + "(16,128), (32,8), (32,16), (32,64)"); // Create structured attribute with all parameters return LDSTransposeConfigAttr::get(rewriter.getContext(), mfmaDDim, mfmaKDim, @@ -459,23 +463,27 @@ static Value getDoubleRateKOffsetBase(PatternRewriter &b, Location loc, //===----------------------------------------------------------------------===// // computePanelFinalOffset - Compute final K offset for a specific K tile // -// This function centralizes the K offset computation logic for both single-rate -// and double-rate layouts. It handles the tile-based offset calculation and -// optional low/high half splitting for double-rate layouts. +// This function centralizes the K offset computation logic for single-rate, +// double-rate, and quad-rate layouts. It handles the tile-based offset +// calculation and optional low/high half splitting for double-rate layouts, +// as well as read index offset for quad-rate layouts. // // Formula: // Single-rate: k_final = k_base_local + (kTileIdx * kTileStride) // Double-rate: k_final = k_base_local + kOffsetBase + (kTileIdx * // kTileStride) + halfOffset // where halfOffset = 0 for low half, 4 for high half +// Quad-rate: k_final = k_base_local + (kTileIdx * kTileStride) + readIdx*8 +// where readIdx = 0, 1, 2, 3 for the 4 ds_read_tr8 calls per K tile // // Parameters: // isDoubleRate - Whether this is a double-rate layout (L32x16, L16x32) // kBaseLocal - Local K base offset from computeLDSBaseOffsets() // kOffsetBase - Double-rate K offset base (from getDoubleRateKOffsetBase) // kTileIdx - Current K tile index (0, 1, 2, ...) -// kTileStride - K stride per tile (instrK, e.g., 8 or 16) +// kTileStride - K stride per tile (instrK, e.g., 8, 16, 32, or 128) // isHighHalf - For double-rate: true = high half (+4), false = low half +// readIdx - For quad-rate: 0-3 index for consecutive 8-K chunks // // Returns: // Final K offset value to use for emitPanelLoad() @@ -483,8 +491,8 @@ static Value getDoubleRateKOffsetBase(PatternRewriter &b, Location loc, static Value computePanelFinalOffset(PatternRewriter &b, Location loc, bool isDoubleRate, Value kBaseLocal, Value kOffsetBase, int64_t kTileIdx, - Value kTileStride, - bool isHighHalf = false) { + Value kTileStride, bool isHighHalf = false, + int64_t readIdx = 0) { Value kBase = kBaseLocal; if (isDoubleRate) { @@ -510,17 +518,27 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, kBase = arith::AddIOp::create(b, loc, kBaseLocal, k_offset); } else { - // Single-rate: k_base = k_base_local + kTileIdx * kTileStride + // Single-rate or Quad-rate: k_base = k_base_local + kTileIdx * kTileStride if (kTileIdx > 0) { Value kIdxVal = arith::ConstantIndexOp::create(b, loc, kTileIdx); Value kOffsetAdd = arith::MulIOp::create(b, loc, kTileStride, kIdxVal); kBase = arith::AddIOp::create(b, loc, kBase, kOffsetAdd); } + + // Quad-rate: add readIdx * 8 for consecutive 8-K chunks within k_base=32 + // readIdx=0: K+0..7, readIdx=1: K+8..15, readIdx=2: K+16..23, readIdx=3: + // K+24..31 + if (readIdx > 0) { + Value c8 = arith::ConstantIndexOp::create(b, loc, 8); + Value readIdxVal = arith::ConstantIndexOp::create(b, loc, readIdx); + Value readOffset = arith::MulIOp::create(b, loc, readIdxVal, c8); + kBase = arith::AddIOp::create(b, loc, kBase, readOffset); + } } LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Computed panel K offset for tile " << kTileIdx << (isHighHalf ? " (high)" : " (low)") - << "\n"); + << ", readIdx=" << readIdx << "\n"); return kBase; } @@ -684,7 +702,7 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, mOffsetBase = arith::MulIOp::create(b, loc, mParity, c8); } else if (dDim == 32 && kDim == 16) { - // FP8 32x16: 4-block formula (VERIFIED from HIP testing) + // FP8 32x16: 4-block formula // Block layout: // Block 0 (T0-T15): M=0..15, K=0..7 // Block 1 (T16-T31): M=16..31, K=0..7 @@ -709,6 +727,60 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, Value mBlockOffset = arith::MulIOp::create(b, loc, mBlock, c16); mOffsetBase = arith::AddIOp::create(b, loc, mParityOffset, mBlockOffset); + } else if (dDim == 16 && kDim == 128) { + // FP8 Scaled 16x128: 4-block formula with k_base=32 (QUAD-RATE) + // Each thread provides 32 CONSECUTIVE K elements via 4 ds_read_tr8 calls. + // + // Thread mapping (consecutive K per thread group): + // Block 0 (T0-T15): M=0..15, K=0..31 + // Block 1 (T16-T31): M=0..15, K=32..63 + // Block 2 (T32-T47): M=0..15, K=64..95 + // Block 3 (T48-T63): M=0..15, K=96..127 + + Value c32 = arith::ConstantIndexOp::create(b, loc, 32); + + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); + Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); + + // kOffsetBase = k_local + block_id * 32 + Value blockKOffset = arith::MulIOp::create(b, loc, blockId, c32); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, blockKOffset); + + // mOffsetBase = m_parity * 8 + mOffsetBase = arith::MulIOp::create(b, loc, mParity, c8); + + } else if (dDim == 32 && kDim == 64) { + // FP8 Scaled 32x64: 4-block formula with k_base=32 (QUAD-RATE) + // Each thread provides 32 CONSECUTIVE K elements via 4 ds_read_tr8 calls. + // + // Thread mapping (same m_block/k_block split as 32x16): + // Block 0 (T0-T15): M=0..15, K=0..31 + // Block 1 (T16-T31): M=16..31, K=0..31 + // Block 2 (T32-T47): M=0..15, K=32..63 + // Block 3 (T48-T63): M=16..31, K=32..63 + + Value c32 = arith::ConstantIndexOp::create(b, loc, 32); + + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); + Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); + + // m_block = block_id % 2, k_block = block_id / 2 + Value mBlock = arith::RemUIOp::create(b, loc, blockId, c2); + Value kBlock = arith::DivUIOp::create(b, loc, blockId, c2); + + // kOffsetBase = k_local + k_block * 32 + Value kBlockOffset = arith::MulIOp::create(b, loc, kBlock, c32); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, kBlockOffset); + + // mOffsetBase = m_parity * 8 + m_block * 16 + Value mParityOffset = arith::MulIOp::create(b, loc, mParity, c8); + Value mBlockOffset = arith::MulIOp::create(b, loc, mBlock, c16); + mOffsetBase = arith::AddIOp::create(b, loc, mParityOffset, mBlockOffset); + } else { llvm_unreachable("Unsupported FP8 MFMA geometry in getBasePanelOffsets"); } @@ -1267,13 +1339,16 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, // Use mPerBlock as stride for operand A, nPerBlock for operand B int64_t ldsStride = (operand == OperandKind::A) ? mPerBlock : nPerBlock; - // Determine if this is a double-rate instruction + // Determine if this is a double-rate or quad-rate instruction // Double-rate ONLY for (32,16) and (16,32) MFMA with F16/BF16 // FP8/BF8 uses ds_read_tr8_b64 which returns 8 elements, so (16,32) and // (32,16) are SINGLE-RATE for FP8/BF8 (16,16) and (32,8) are always // SINGLE-RATE + // Quad-rate for FP8 scaled MFMA: 16x128 and 32x64 (k_base=32, 4 reads of 8) bool isDoubleRate = !isFp8Type(elemType) && ((dDim == 32 && instrK == 16) || (dDim == 16 && instrK == 32)); + bool isQuadRate = isFp8Type(elemType) && + ((dDim == 16 && instrK == 128) || (dDim == 32 && instrK == 64)); // Determine vector length based on element type: // - f16/bf16: ds_read_tr16_b64 returns vector<4> @@ -1284,6 +1359,7 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, // panelVectors will contain: // - Single-rate: 1 vector per K tile // - Double-rate (f16/bf16 only): 2 vectors per K tile (low + high) + // - Quad-rate (FP8 16x128 or 32x64): 4 vectors per K tile (readIdx 0-3) SmallVector panelVectors; // Get base offsets using computeLDSBaseOffsets helper @@ -1342,7 +1418,23 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, b, loc, m_offset_base, operand, waveM, waveN, mnTileIndex, mnIdxLocal, useDynamicMnIndex, waveOffsetStrideVal, tileOffsetStrideVal); - if (!isDoubleRate) { + if (isQuadRate) { + // QUAD-RATE (FP8 scaled MFMA 16x128 or 32x64): FOUR loads per K tile + // Each load returns vector<8> for fp8, total 32 elements per K tile + // k_base=32, so 4 reads of 8 elements each give consecutive K + // For 16x128: all blocks in K dimension (block_id * 32) + // For 32x64: m_block/k_block split (k_block = block_id / 2) * 32 + for (int64_t readIdx = 0; readIdx < 4; ++readIdx) { + Value k_base = computePanelFinalOffset( + b, loc, /*isDoubleRate=*/false, k_base_local, kOffsetBase, kIdx, + kTileStrideVal, /*isHighHalf=*/false, /*readIdx=*/readIdx); + + Value panelVec = emitPanelLoad(b, loc, rawSrc, k_base, m_base, + ldsStrideVal, panelVecType); + panelVectors.push_back(panelVec); + } + + } else if (!isDoubleRate) { // SINGLE-RATE (L32x8, L16x16, or FP8/BF8): One load per K tile Value k_base = computePanelFinalOffset(b, loc, isDoubleRate, k_base_local, @@ -1380,11 +1472,12 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, // Calculate expected number of loads // - For double buffering: we generate ALL M/N panels → endMnIdx panels × - // kPanels × (1 or 2 for rate) + // kPanels × (1, 2, or 4 for rate) // - Single-rate: 1 load per K tile → actualMnTiles × kPanels loads // - Double-rate: 2 loads per K tile → actualMnTiles × kPanels × 2 loads + // - Quad-rate: 4 loads per K tile → actualMnTiles × kPanels × 4 loads int64_t actualMnTiles = endMnIdx - startMnIdx; - int64_t loadsPerKTile = isDoubleRate ? 2 : 1; + int64_t loadsPerKTile = isQuadRate ? 4 : (isDoubleRate ? 2 : 1); int64_t expectedLoads = actualMnTiles * kPanels * loadsPerKTile; // Each load produces vecLen elements: