diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index 8d70078b06f0..a9fd3f28951f 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -1221,9 +1221,10 @@ defvar SameShapeVectorOfI1 = [{ def Rock_LDSTransposeLoadOp : Rock_Op<"lds_transpose_load", [DeclareOpInterfaceMethods< MemoryEffectsOpInterface>]>, - Arguments<(ins Arg, "LDS source buffer">:$source, + Arguments<(ins Arg, + "LDS source buffer">:$source, Variadic:$indices)>, - Results<(outs VectorOfLengthAndType<[4], [F16, BF16]>:$result)> { + Results<(outs AnyVectorOfNonZeroRank:$result)> { let summary = "Hardware-assisted LDS transpose load for matrix accelerator tile"; let description = [{ @@ -1233,8 +1234,14 @@ def Rock_LDSTransposeLoadOp instruction geometry (`dDim × instrK`), where: - `dDim` is the accelerator M/N dimension (e.g., 16 or 32) - `instrK` is the accelerator K dimension (e.g., 8, 16, or 32) - The operation returns a vector of 4 elements per thread containing - transposed elements in a layout suitable for matrix accelerator instructions. + + For 16-bit types (f16, bf16): + - Uses ds_read_tr16_b64 instruction + - Returns vector<4xtype> (4 elements per thread) + + For 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): + - Uses ds_read_tr8_b64 instruction + - Returns vector<8xtype> (8 elements per thread) Benefits: - Reduces LDS bank conflicts through optimized access patterns diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 5b7c3131817e..27cce5cc9681 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -2161,6 +2161,27 @@ LogicalResult LDSTransposeLoadOp::verify() { << srcElemType << ")"; } + // Verify result vector length based on element type: + // - 16-bit types (f16, bf16): ds_read_tr16_b64 returns 4 elements + // - 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): ds_read_tr8_b64 + // returns 8 elements + int64_t expectedVecLen; + if (srcElemType.isF16() || srcElemType.isBF16()) { + expectedVecLen = 4; + } else if (isa(srcElemType) || + isa(srcElemType)) { + expectedVecLen = 8; + } else { + return emitOpError("unsupported element type for LDS transpose load: ") + << srcElemType; + } + + if (resultType.getNumElements() != expectedVecLen) { + return emitOpError("expected result vector of ") + << expectedVecLen << " elements for " << srcElemType + << " type, but got " << resultType.getNumElements(); + } + // Check hardware support using AmdArchDb StringRef arch = rock::getArchValue(*this); AmdArchInfo archInfo = rock::lookupArchInfo(arch); diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 201748cc206b..d7901b1135df 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -3377,6 +3377,33 @@ struct GridwiseGemmAccelRewritePattern directToLDS, ldsLayoutConfigA, ldsLayoutConfigB, mPerBlock, nPerBlock, kPerBlock, mPerWave, nPerWave, kpack, doubleBuffering); + // FP8 heuristic: check if LDS transpose should be disabled based on dims + // Rule 1: K >= 1280 causes regression, EXCEPT when K > otherDim AND + // otherDim > 1280 + // Rule 2: Small square K-N matrices (K == N && K < 512) + auto shouldDisableLdsTranspose = [K, N](int64_t otherDim) -> bool { + if (K >= 1280 && !(K > otherDim && otherDim > 1280)) + return true; + if (K == N && K < 512) + return true; + return false; + }; + + bool isFp8Type = isa(elementTypeA) || + isa(elementTypeA); + + if (isFp8Type) { + // Case 1: TransA=false, TransB=false - only B uses LDS transpose + if (!ldsDecision.enableA && ldsDecision.enableB && + shouldDisableLdsTranspose(N)) + ldsDecision.enableB = false; + + // Case 2: TransA=true, TransB=true - only A uses LDS transpose + if (ldsDecision.enableA && !ldsDecision.enableB && + shouldDisableLdsTranspose(M)) + ldsDecision.enableA = false; + } + LLVM_DEBUG(llvm::dbgs() << "M: " << M << "\n" << "N: " << N << "\n" diff --git a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp index f65e1487c847..55f35c8b9084 100644 --- a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp +++ b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp @@ -21,8 +21,13 @@ // to the LDS transpose load operation in an accelerator-friendly layout. // // It is intended to simplify the IR generation logic and ensure -// consistent handling of f16/bf16 matrix accelerator tile loads from LDS -// memory. +// consistent handling of f16/bf16/fp8/bf8 matrix accelerator tile loads from +// LDS memory. +// +// Supported element types: +// - f16, bf16: uses ds_read_tr16_b64 (returns 4 elements per thread) +// - f8E4M3FN, f8E5M2 (OCP FP8): uses ds_read_tr8_b64 (returns 8 elements per +// thread) // //===----------------------------------------------------------------------===// @@ -47,6 +52,32 @@ namespace { bool archSupported(StringRef arch) { return arch.contains("gfx950"); } +// Check if element type is supported for LDS transpose load +// - f16, bf16: ds_read_tr16_b64 (4 elements) +// - f8E4M3FN, f8E5M2 (OCP FP8 for gfx950): ds_read_tr8_b64 (8 elements) +static bool isSupportedElementType(Type t) { + return t.isF16() || t.isBF16() || isa(t) || + isa(t); +} + +// Check if element type is 8-bit float (FP8 E4M3 or BF8 E5M2) +// Used for: +// 1. Selecting ds_read_tr8_b64 vs ds_read_tr16_b64 +// 2. Checking mixed-type compatibility (fp8+bf8 combinations are valid) +static bool isFp8Type(Type t) { + return isa(t) || isa(t); +} + +// Returns the number of elements returned by LDS transpose load instruction +static int64_t getTransposeLoadVectorLength(Type elemType) { + if (elemType.isF16() || elemType.isBF16()) { + return 4; // ds_read_tr16_b64 + } else if (isFp8Type(elemType)) { + return 8; // ds_read_tr8_b64 + } + llvm_unreachable("Unsupported element type for LDS transpose load"); +} + // Validates MFMA geometry for LDS transpose support. // Only specific combinations are supported: (16,16), (16,32), (32,8), (32,16) static bool isValidMfmaGeometry(int64_t dDim, int64_t kDim) { @@ -119,7 +150,16 @@ static Decision makeDecision(StringRef arch, Type elemTypeA, Type elemTypeB, return dec; } - if (elemTypeA != elemTypeB || !(elemTypeA.isF16() || elemTypeA.isBF16())) + // Check type compatibility: + // - Same types are always allowed (f16==f16, fp8==fp8, etc.) + // - Mixed FP8/BF8 combinations are allowed (hardware supports mixed fp8/bf8 + // MFMA: + // mfma_f32_16x16x32_fp8_fp8, mfma_f32_16x16x32_fp8_bf8, etc.) + // - Other mixed types are NOT allowed (e.g., f16 with fp8) + bool typesCompatible = (elemTypeA == elemTypeB) || + (isFp8Type(elemTypeA) && isFp8Type(elemTypeB)); + if (!typesCompatible || !isSupportedElementType(elemTypeA) || + !isSupportedElementType(elemTypeB)) return dec; // Validate MFMA geometry @@ -489,8 +529,9 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, // emitPanelLoad - Emit an LDS transpose load operation // // Computes the final LDS offset and emits a hardware LDS transpose load -// instruction (ds_read_tr16_b64). This instruction always returns vector<4xf16> -// regardless of the layout. +// instruction: +// - ds_read_tr16_b64 for f16/bf16: returns vector<4> +// - ds_read_tr8_b64 for fp8/bf8: returns vector<8> // // The final offset is computed as: final_offset = k_base * ldsStride + m_base // where ldsStride depends on the operand (mPerBlock for A, nPerBlock for B). @@ -503,10 +544,10 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, // kBase - K dimension base offset for this panel // mBase - M/N dimension base offset for this panel // ldsStride - Stride between K rows in LDS (mPerBlock or nPerBlock) -// panelVecType - Result type (always vector<4xf16> or vector<4xbf16>) +// panelVecType - Result type (vector<4> for f16/bf16, vector<8> for fp8/bf8) // // Returns: -// The loaded panel vector (vector<4xf16/bf16>) +// The loaded panel vector (vector<4> for f16/bf16, vector<8> for fp8/bf8) //===----------------------------------------------------------------------===// static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, Value kBase, Value mBase, Value ldsStride, @@ -515,7 +556,7 @@ static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, Value kOffset = arith::MulIOp::create(b, loc, kBase, ldsStride); Value finalOffset = arith::AddIOp::create(b, loc, mBase, kOffset); - // Emit hardware LDS transpose load: ds_read_tr16_b64 + // Emit hardware LDS transpose load auto loadOp = rock::LDSTransposeLoadOp::create(b, loc, panelVecType, rawSrc, ValueRange{finalOffset}); @@ -530,9 +571,9 @@ static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, // elements (ds_read_tr16_b64 always returns vector<4xf16>). // // Parameters: -// panelVectors - Array of loaded panel vectors (each is vector<4xf16>) -// dest - Destination memref (rank-1, scalar layout) -// targetElems - Maximum number of elements to write +// panelVectors - Array of loaded panel vectors (vector<4> for f16/bf16, +// vector<8> for fp8/bf8) dest - Destination memref (rank-1, scalar +// layout) targetElems - Maximum number of elements to write // // Returns: // success() if all target elements were written @@ -546,14 +587,16 @@ writePanelVectorsToDestination(PatternRewriter &b, Location loc, int64_t produced = 0; // Extract elements per vector from the actual vector type - // Hardware instruction ds_read_tr16_b64 always returns vector<4xf16> + // Hardware instructions: + // - ds_read_tr16_b64 returns vector<4xf16/bf16> + // - ds_read_tr8_b64 returns vector<8xfp8> assert(!panelVectors.empty() && "Panel vectors array must not be empty"); auto panelVecType = cast(panelVectors[0].getType()); int64_t elementsPerVector = panelVecType.getShape()[0]; - // Verify hardware constraint: ds_read_tr16_b64 returns exactly 4 elements - assert(elementsPerVector == 4 && - "LDS transpose load must produce vector<4xf16> per panel"); + // Verify hardware constraint: 4 elements for 16-bit, 8 elements for 8-bit + assert((elementsPerVector == 4 || elementsPerVector == 8) && + "LDS transpose load must produce vector<4> or vector<8> per panel"); LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Writing " << panelVectors.size() << " panel vectors (" @@ -609,60 +652,114 @@ writePanelVectorsToDestination(PatternRewriter &b, Location loc, //===----------------------------------------------------------------------===// static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, int64_t dDim, int64_t kDim, - Value lane) { - Value c16 = arith::ConstantIndexOp::create(b, loc, 16); - Value c4 = arith::ConstantIndexOp::create(b, loc, 4); + Value lane, Type elemType) { + // Common constants used by both FP8 and F16/BF16 Value c2 = arith::ConstantIndexOp::create(b, loc, 2); + Value c4 = arith::ConstantIndexOp::create(b, loc, 4); + Value c16 = arith::ConstantIndexOp::create(b, loc, 16); - Value blockId = arith::DivUIOp::create(b, loc, lane, c16); - Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); - - // Base offset calculations - Value mOffsetBase = arith::MulIOp::create( - b, loc, arith::RemUIOp::create(b, loc, laneInBlock, c4), c4); - Value kOffsetBase = arith::DivUIOp::create(b, loc, laneInBlock, c4); + Value kOffsetBase, mOffsetBase; + + if (isFp8Type(elemType)) { + Value c8 = arith::ConstantIndexOp::create(b, loc, 8); + + if (dDim == 16 && kDim == 32) { + // FP8/BF8 16x32: 4-block formula + // Block layout: + // Block 0 (T0-T15): K=0..7 + // Block 1 (T16-T31): K=8..15 + // Block 2 (T32-T47): K=16..23 + // Block 3 (T48-T63): K=24..31 + + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); + Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); + + // kOffsetBase = k_local + block_id * 8 + Value blockKOffset = arith::MulIOp::create(b, loc, blockId, c8); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, blockKOffset); + + // mOffsetBase = m_parity * 8 + mOffsetBase = arith::MulIOp::create(b, loc, mParity, c8); + + } else if (dDim == 32 && kDim == 16) { + // FP8 32x16: 4-block formula (VERIFIED from HIP testing) + // Block layout: + // Block 0 (T0-T15): M=0..15, K=0..7 + // Block 1 (T16-T31): M=16..31, K=0..7 + // Block 2 (T32-T47): M=0..15, K=8..15 + // Block 3 (T48-T63): M=16..31, K=8..15 + + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); + Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); + + // m_block = block_id % 2, k_block = block_id / 2 + Value mBlock = arith::RemUIOp::create(b, loc, blockId, c2); + Value kBlock = arith::DivUIOp::create(b, loc, blockId, c2); + + // kOffsetBase = k_local + k_block * 8 + Value kBlockOffset = arith::MulIOp::create(b, loc, kBlock, c8); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, kBlockOffset); + + // mOffsetBase = m_parity * 8 + m_block * 16 + Value mParityOffset = arith::MulIOp::create(b, loc, mParity, c8); + Value mBlockOffset = arith::MulIOp::create(b, loc, mBlock, c16); + mOffsetBase = arith::AddIOp::create(b, loc, mParityOffset, mBlockOffset); - SmallVector panelOffsets; + } else { + llvm_unreachable("Unsupported FP8 MFMA geometry in getBasePanelOffsets"); + } - if (dDim == 16 && kDim == 32) { - // 16x32 layout - panelOffsets = {kOffsetBase, mOffsetBase}; - } else if (dDim == 16 && kDim == 16) { - // 16x16 layout - // kbase = kOffsetBase + (blockId * 4) - Value kBase = arith::AddIOp::create( - b, loc, arith::MulIOp::create(b, loc, blockId, c4), kOffsetBase); - panelOffsets = {kBase, mOffsetBase}; - } else if (dDim == 32 && kDim == 16) { - // 32x16 layout - // mbase = mOffsetBase + (blockId % 2) * 16 - Value mBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::RemUIOp::create(b, loc, blockId, c2), c16), - mOffsetBase); - panelOffsets = {kOffsetBase, mBase}; - } else if (dDim == 32 && kDim == 8) { - // 32x8 layout - // k_base_local = kOffsetBase + (blockId / 2) * 4 - Value kBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::DivUIOp::create(b, loc, blockId, c2), c4), - kOffsetBase); - - // m_offset_base = mOffsetBase + (blockId % 2) * 16 - Value mBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::RemUIOp::create(b, loc, blockId, c2), c16), - mOffsetBase); - panelOffsets = {kBase, mBase}; } else { - llvm_unreachable("Unsupported MFMA geometry in getBasePanelOffsets"); + // F16/BF16 uses block-based lane mapping + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + + // Base calculations common to all F16/BF16 geometries + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c4); + Value mLocal = arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, laneInBlock, c4), c4); + + if (dDim == 16 && kDim == 32) { + // 16x32: direct mapping + kOffsetBase = kLocal; + mOffsetBase = mLocal; + + } else if (dDim == 16 && kDim == 16) { + // 16x16: k += blockId * 4 + kOffsetBase = arith::AddIOp::create( + b, loc, kLocal, arith::MulIOp::create(b, loc, blockId, c4)); + mOffsetBase = mLocal; + + } else if (dDim == 32 && kDim == 16) { + // 32x16: m += (blockId % 2) * 16 + kOffsetBase = kLocal; + mOffsetBase = arith::AddIOp::create( + b, loc, mLocal, + arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, blockId, c2), c16)); + + } else if (dDim == 32 && kDim == 8) { + // 32x8: k += (blockId / 2) * 4, m += (blockId % 2) * 16 + kOffsetBase = arith::AddIOp::create( + b, loc, kLocal, + arith::MulIOp::create( + b, loc, arith::DivUIOp::create(b, loc, blockId, c2), c4)); + mOffsetBase = arith::AddIOp::create( + b, loc, mLocal, + arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, blockId, c2), c16)); + + } else { + llvm_unreachable( + "Unsupported F16/BF16 MFMA geometry in getBasePanelOffsets"); + } } - return panelOffsets; + return {kOffsetBase, mOffsetBase}; } //===----------------------------------------------------------------------===// @@ -683,6 +780,7 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, // dDim - MFMA D dimension (M or N, 16 or 32) // kDim - MFMA K dimension (8, 16, or 32) // lane - Thread's lane ID within the workgroup +// elemType - Element type (f16, bf16, fp8, or bf8) for selecting lane mapping // // Returns: // std::pair: @@ -691,8 +789,10 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, //===----------------------------------------------------------------------===// static std::pair computeLDSBaseOffsets(PatternRewriter &b, Location loc, int64_t dDim, - int64_t kDim, Value lane) { - SmallVector offsets = getBasePanelOffsets(b, loc, dDim, kDim, lane); + int64_t kDim, Value lane, + Type elemType) { + SmallVector offsets = + getBasePanelOffsets(b, loc, dDim, kDim, lane, elemType); LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Computed LDS base offsets for " << dDim << "x" << kDim << ": " @@ -1098,8 +1198,9 @@ static Value computeFinalMNOffset(PatternRewriter &b, Location loc, // emitThreadwiseHWTranspose - Lower threadwise_read_into to HW transpose loads //===----------------------------------------------------------------------===// // Lowers threadwise_read_into with LDS transpose config into hardware transpose -// load instructions (ds_read_tr16_b64) that read from LDS in MFMA-friendly -// order. +// load instructions that read from LDS in MFMA-friendly order: +// - ds_read_tr16_b64 for f16/bf16 (returns vector<4>) +// - ds_read_tr8_b64 for fp8/bf8 (returns vector<8>) // // Algorithm: // 1. Extract config: MFMA geometry (dDim, kDim), tiling params, operand kind @@ -1112,8 +1213,8 @@ static Value computeFinalMNOffset(PatternRewriter &b, Location loc, // - Outer: M/N tiles (all at once for double-buffering, one at a time // otherwise) // - Inner: K tiles (1 load for single-rate, 2 loads for double-rate layouts) -// 6. For each iteration: compute final LDS offset, emit ds_read_tr16_b64 -// instruction (returns vector<4xf16>) +// 6. For each iteration: compute final LDS offset, emit LDS transpose load +// instruction (ds_read_tr16_b64 for f16/bf16, ds_read_tr8_b64 for fp8/bf8) // 7. Extract elements from panel vectors and write sequentially to destination // // Example: 16x32 layout, 1 M-tile, 2 K-tiles → 2 ds_read_tr16_b64 calls → 8 @@ -1167,23 +1268,27 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, int64_t ldsStride = (operand == OperandKind::A) ? mPerBlock : nPerBlock; // Determine if this is a double-rate instruction - // Double-rate ONLY for (32,16) and (16,32) MFMA - // (16,16) and (32,8) are SINGLE-RATE - bool isDoubleRate = - (dDim == 32 && instrK == 16) || (dDim == 16 && instrK == 32); - - // Each ds_read_tr16_b64 call ALWAYS returns vector<4xf16> - // For double-rate, we make 2 calls and store all 8 elements separately - VectorType panelVecType = VectorType::get({4}, elemType); + // Double-rate ONLY for (32,16) and (16,32) MFMA with F16/BF16 + // FP8/BF8 uses ds_read_tr8_b64 which returns 8 elements, so (16,32) and + // (32,16) are SINGLE-RATE for FP8/BF8 (16,16) and (32,8) are always + // SINGLE-RATE + bool isDoubleRate = !isFp8Type(elemType) && ((dDim == 32 && instrK == 16) || + (dDim == 16 && instrK == 32)); + + // Determine vector length based on element type: + // - f16/bf16: ds_read_tr16_b64 returns vector<4> + // - fp8/bf8: ds_read_tr8_b64 returns vector<8> + int64_t vecLen = getTransposeLoadVectorLength(elemType); + VectorType panelVecType = VectorType::get({vecLen}, elemType); // panelVectors will contain: - // - Single-rate: 1 vector<4xf16> per K tile - // - Double-rate: 2 vector<4xf16> per K tile (low + high) + // - Single-rate: 1 vector per K tile + // - Double-rate (f16/bf16 only): 2 vectors per K tile (low + high) SmallVector panelVectors; // Get base offsets using computeLDSBaseOffsets helper auto [k_base_local, m_offset_base] = - computeLDSBaseOffsets(b, loc, dDim, instrK, lane); + computeLDSBaseOffsets(b, loc, dDim, instrK, lane, elemType); // K stride per tile: instrK (MFMA K dimension) int64_t kTileStride = instrK; @@ -1238,20 +1343,21 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, useDynamicMnIndex, waveOffsetStrideVal, tileOffsetStrideVal); if (!isDoubleRate) { - // SINGLE-RATE (L32x8, L16x16): One load per K tile + // SINGLE-RATE (L32x8, L16x16, or FP8/BF8): One load per K tile Value k_base = computePanelFinalOffset(b, loc, isDoubleRate, k_base_local, kOffsetBase, kIdx, kTileStrideVal); - // Emit LDS transpose load for this K tile (single-rate: one per K tile) + // Emit LDS transpose load for this K tile Value panelVec = emitPanelLoad(b, loc, rawSrc, k_base, m_base, ldsStrideVal, panelVecType); panelVectors.push_back(panelVec); } else { - // DOUBLE-RATE (L32x16, L16x32): TWO loads per K tile - // Each load returns vector<4xf16>, total 8 elements per K tile - // Compute K offsets for low and high halves + // DOUBLE-RATE (L32x16, L16x32 with F16/BF16 only): TWO loads per K tile + // Each load returns vector<4> for f16/bf16, total 8 elements per K tile + // Note: FP8/BF8 is NEVER double-rate (ds_read_tr8_b64 returns 8 + // elements) Compute K offsets for low and high halves Value k_base_low = computePanelFinalOffset( b, loc, isDoubleRate, k_base_local, kOffsetBase, kIdx, kTileStrideVal, /*isHighHalf=*/false); @@ -1281,8 +1387,10 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, int64_t loadsPerKTile = isDoubleRate ? 2 : 1; int64_t expectedLoads = actualMnTiles * kPanels * loadsPerKTile; - // Each load ALWAYS produces 4 elements (ds_read_tr16_b64 → vector<4xf16>) - int64_t sliceElems = expectedLoads * 4; + // Each load produces vecLen elements: + // - f16/bf16: 4 elements (ds_read_tr16_b64) + // - fp8/bf8: 8 elements (ds_read_tr8_b64) + int64_t sliceElems = expectedLoads * vecLen; // Verify we generated the expected number of loads if (panelVectors.size() != (size_t)expectedLoads) { diff --git a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir index 29a2f190a06b..b6a64a6985f7 100644 --- a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir +++ b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir @@ -14,4 +14,18 @@ module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xbf16, #gpu.address_space> -> vector<4xbf16> return %v : vector<4xbf16> } + +// CHECK-LABEL: func @test_load_transpose_fp8_e4m3 + func.func @test_load_transpose_fp8_e4m3(%src: memref<128x256xf8E4M3FN, #gpu.address_space>, %i: index, %j: index) -> vector<8xf8E4M3FN> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<128x256xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + %v = rock.lds_transpose_load %src[%i, %j] : memref<128x256xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + return %v : vector<8xf8E4M3FN> + } + +// CHECK-LABEL: func @test_load_transpose_fp8_e5m2 + func.func @test_load_transpose_fp8_e5m2(%src: memref<64x128xf8E5M2, #gpu.address_space>, %i: index, %j: index) -> vector<8xf8E5M2> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<64x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + return %v : vector<8xf8E5M2> + } } diff --git a/mlir/test/Dialect/Rock/ops.mlir b/mlir/test/Dialect/Rock/ops.mlir index fafe8fa665d6..1a56f6352788 100644 --- a/mlir/test/Dialect/Rock/ops.mlir +++ b/mlir/test/Dialect/Rock/ops.mlir @@ -433,6 +433,27 @@ func.func @rock_lds_transpose_load_full_arch(%lds_buffer: memref<128x64xf16, #gp return } +// CHECK-LABEL: func.func @rock_lds_transpose_load_fp8_e4m3 +// CHECK: rock.lds_transpose_load +func.func @rock_lds_transpose_load_fp8_e4m3(%lds_buffer: memref<128x64xf8E4M3FN, #gpu.address_space>) + attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { + %c0 = arith.constant 0 : index + %fragment = rock.lds_transpose_load %lds_buffer[%c0, %c0] + : memref<128x64xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + return +} + +// CHECK-LABEL: func.func @rock_lds_transpose_load_fp8_e5m2 +// CHECK: rock.lds_transpose_load +func.func @rock_lds_transpose_load_fp8_e5m2(%lds_buffer: memref<256x128xf8E5M2, #gpu.address_space>) + attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %fragment = rock.lds_transpose_load %lds_buffer[%c32, %c0] + : memref<256x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + return +} + // CHECK-LABEL: func.func @test_lds_transpose_config_attr_16x32 // CHECK: ldsTransposeConfig = #rock.lds_transpose_config func.func @test_lds_transpose_config_attr_16x32(%src: memref<8192xf16, #gpu.address_space>, diff --git a/mlir/test/e2e/CMakeLists.txt b/mlir/test/e2e/CMakeLists.txt index ca0d7c2473f3..1ffabef3538a 100644 --- a/mlir/test/e2e/CMakeLists.txt +++ b/mlir/test/e2e/CMakeLists.txt @@ -50,6 +50,7 @@ if (ROCMLIR_DRIVER_PR_E2E_TEST_ENABLED) PrConvElementwiseGemmBF16SplitK PrGemmDirectToLDS PrLdsTransposeLoad + PrLdsTransposeLoadFp8 PrLdsTransposeLoadAttention PrConvDirectToLDS PrAttentionDirectToLDS diff --git a/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg b/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg new file mode 100644 index 000000000000..46909aa10a02 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg @@ -0,0 +1,2 @@ +if not 'lds_transpose_load' in config.features: + config.unsupported = True diff --git a/mlir/test/e2e/PrLdsTransposeLoadFp8.toml b/mlir/test/e2e/PrLdsTransposeLoadFp8.toml new file mode 100644 index 000000000000..41b36b149e6a --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadFp8.toml @@ -0,0 +1,126 @@ +directory = "PrLdsTransposeLoadFp8" +prefix = "rocmlir-gen" +suffix = "--operation gemm --arch %arch %pv %constrained_float_range_random_data %rocmlir_gen_flags | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +[[axis]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8", "fp8_bf8", "bf8_fp8"] +prefix = "-t " + +# ============================================================================ +# Suite 1: BOTH A and B use LDS transpose (transA=true, transB=false) +# 5 tests covering different MFMA geometries and kpack/kPerBlock combinations +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_both_operands_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v3:64,64,32,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 64 -n 64 --transA=true --transB=false --perf_config v3:128,64,16,32,16,16,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 256 -k 256 -n 64 --transA=true --transB=false --perf_config v3:256,64,8,64,16,32,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=false --perf_config v3:128,128,16,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 256 -k 128 -n 128 --transA=true --transB=false --perf_config v3:256,128,32,64,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 2: ONLY A uses LDS transpose (transA=true, transB=true) +# 3 tests - B uses regular load +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_A_only_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=true --transB=true --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 64 --transA=true --transB=true --perf_config v3:128,64,8,32,16,16,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 128 --transA=true --transB=true --perf_config v3:128,128,32,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 3: ONLY B uses LDS transpose (transA=false, transB=false) +# 3 tests - A uses regular load +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_B_only_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 64 --transA=false --transB=false --perf_config v3:128,64,32,32,16,1,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 128 --transA=false --transB=false --perf_config v3:128,128,8,32,32,32,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 4: Mixed FP8/BF8 types (fp8_bf8 and bf8_fp8 only) +# 3 tests: both operands, only A, only B +# ============================================================================ +[[suite]] +name = "lds_transpose_mixed_fp8_bf8" + +# Test 1: Mixed types, both operands use LDS transpose, 16x32 MFMA +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v3:64,64,32,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"] + +# Test 2: Mixed types, only A uses LDS transpose, 32x16 MFMA +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=true --perf_config v3:128,128,16,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"] + +# Test 3: Mixed types, only B uses LDS transpose, 16x32 MFMA +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"]