diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index 8d70078b06f0..988f980eb3d0 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -1221,9 +1221,10 @@ defvar SameShapeVectorOfI1 = [{ def Rock_LDSTransposeLoadOp : Rock_Op<"lds_transpose_load", [DeclareOpInterfaceMethods< MemoryEffectsOpInterface>]>, - Arguments<(ins Arg, "LDS source buffer">:$source, + Arguments<(ins Arg, + "LDS source buffer">:$source, Variadic:$indices)>, - Results<(outs VectorOfLengthAndType<[4], [F16, BF16]>:$result)> { + Results<(outs AnyVectorOfNonZeroRank:$result)> { let summary = "Hardware-assisted LDS transpose load for matrix accelerator tile"; let description = [{ @@ -1232,9 +1233,15 @@ def Rock_LDSTransposeLoadOp The tile dimensions match the selected matrix-multiply accelerator instruction geometry (`dDim × instrK`), where: - `dDim` is the accelerator M/N dimension (e.g., 16 or 32) - - `instrK` is the accelerator K dimension (e.g., 8, 16, or 32) - The operation returns a vector of 4 elements per thread containing - transposed elements in a layout suitable for matrix accelerator instructions. 
+ - `instrK` is the accelerator K dimension (e.g., 8, 16, 32, or 64) + + For 16-bit types (f16, bf16): + - Uses ds_read_tr16_b64 instruction + - Returns vector<4xtype> (4 elements per thread) + + For 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8, i8 - INT8 for gfx950): + - Uses ds_read_tr8_b64 instruction + - Returns vector<8xtype> (8 elements per thread) Benefits: - Reduces LDS bank conflicts through optimized access patterns diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 5b7c3131817e..bfa4a2d312a6 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -559,9 +559,13 @@ LogicalResult TransformMapAttr::verify( } // Helper function to check valid MFMA geometry for LDS transpose +// Supported geometries: +// - F16/BF16: (16,16), (16,32), (32,8), (32,16) +// - FP8/BF8: (16,32), (32,16) +// - INT8: (16,32), (16,64), (32,16), (32,32) static bool isValidLdsTransposeMfmaGeometry(int64_t dDim, int64_t kDim) { - return (dDim == 16 && (kDim == 16 || kDim == 32)) || - (dDim == 32 && (kDim == 8 || kDim == 16)); + return (dDim == 16 && (kDim == 16 || kDim == 32 || kDim == 64)) || + (dDim == 32 && (kDim == 8 || kDim == 16 || kDim == 32)); } LogicalResult LDSTransposeConfigAttr::verify( @@ -573,7 +577,7 @@ LogicalResult LDSTransposeConfigAttr::verify( if (!isValidLdsTransposeMfmaGeometry(dDim, kDim)) { return emitError() << "invalid MFMA geometry (" << dDim << "x" << kDim << ") for LDS transpose - valid combinations: " - "(16,16), (16,32), (32,8), (32,16)"; + "(16,16), (16,32), (16,64), (32,8), (32,16), (32,32)"; } // Validate positive dimensions @@ -2161,6 +2165,29 @@ LogicalResult LDSTransposeLoadOp::verify() { << srcElemType << ")"; } + // Verify result vector length based on element type: + // - 16-bit types (f16, bf16): ds_read_tr16_b64 returns 4 elements + // - 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): ds_read_tr8_b64 + // returns 8 elements + int64_t 
expectedVecLen; + if (srcElemType.isF16() || srcElemType.isBF16()) { + expectedVecLen = 4; + } else if (isa(srcElemType) || + isa(srcElemType)) { + expectedVecLen = 8; + } else if (srcElemType.isInteger(8)) { + expectedVecLen = 8; + } else { + return emitOpError("unsupported element type for LDS transpose load: ") + << srcElemType; + } + + if (resultType.getNumElements() != expectedVecLen) { + return emitOpError("expected result vector of ") + << expectedVecLen << " elements for " << srcElemType + << " type, but got " << resultType.getNumElements(); + } + // Check hardware support using AmdArchDb StringRef arch = rock::getArchValue(*this); AmdArchInfo archInfo = rock::lookupArchInfo(arch); diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 8c91c4083bb2..4aede40e591a 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -2489,7 +2489,7 @@ struct GridwiseAttentionAccelRewritePattern ldsLayoutCfgMG0.doRotateWithK, ldsLayoutCfgMG0.doSwapThreadIterSubDims, ldsLayoutCfgMG0.ldsLayoutDxK, directToLDS, /*splitKAcrossThreadsFirst=*/false, gemm0G, gemm0M, gemm0InMPerThread, - /*ldsTransposeEnabled=*/ldsDecisionGemm0.enableA, + /*ldsTransposeEnabled=*/false, /*accelDDim=*/ldsDecisionGemm0.mfmaDDim, /*accelKDim=*/ldsDecisionGemm0.mfmaKDim); @@ -2498,7 +2498,7 @@ struct GridwiseAttentionAccelRewritePattern ldsLayoutCfgNG0.doRotateWithK, ldsLayoutCfgNG0.doSwapThreadIterSubDims, ldsLayoutCfgNG0.ldsLayoutDxK, directToLDSQ, /*splitKAcrossThreadsFirst=*/false, gemm0G, gemm0N, gemm0InNPerThread, - /*ldsTransposeEnabled=*/ldsDecisionGemm0.enableB, + /*ldsTransposeEnabled=*/false, /*accelDDim=*/ldsDecisionGemm0.mfmaDDim, /*accelKDim=*/ldsDecisionGemm0.mfmaKDim); @@ -3384,6 +3384,41 @@ struct GridwiseGemmAccelRewritePattern directToLDS, ldsLayoutConfigA, ldsLayoutConfigB, mPerBlock, nPerBlock, kPerBlock, 
mPerWave, nPerWave, kpack, doubleBuffering); + // FP8 heuristic: check if LDS transpose should be disabled based on dims + // Rule 1: K >= 1280 causes regression, EXCEPT when K > otherDim AND + // otherDim > 1280 + // Rule 2: Small square K-N matrices (K == N && K < 512) + auto shouldDisableLdsTranspose = [K, N](int64_t otherDim) -> bool { + if (K >= 1280 && !(K > otherDim && otherDim > 1280)) + return true; + if (K == N && K < 512) + return true; + return false; + }; + + bool isFp8Type = isa(elementTypeA) || + isa(elementTypeA); + + if (isFp8Type) { + // Case 1: TransA=false, TransB=false - only B uses LDS transpose + if (!ldsDecision.enableA && ldsDecision.enableB && + shouldDisableLdsTranspose(N)) + ldsDecision.enableB = false; + + // Case 2: TransA=true, TransB=true - only A uses LDS transpose + if (ldsDecision.enableA && !ldsDecision.enableB && + shouldDisableLdsTranspose(M)) + ldsDecision.enableA = false; + } + + // INT8 CONV heuristic: disable LDS transpose for N=1600 (40x40 spatial) + // when K<=M or K>2*M. 
+ bool isInt8Type = elementTypeA.isInteger(8); + if (isInt8Type && N == 1600 && (K <= M || K > 2 * M)) { + ldsDecision.enableA = false; + ldsDecision.enableB = false; + } + LLVM_DEBUG(llvm::dbgs() << "M: " << M << "\n" << "N: " << N << "\n" diff --git a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp index 711d4a7b4edf..6ab4fce4425b 100644 --- a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp +++ b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp @@ -594,6 +594,64 @@ Value MfmaEmitter::wrapLDSBufferForLoad( toLDSRowCol.unmerge("k", 1, {"k_iter", "k_mfma", "blk_k", "k_base"}, {kIter, numMfmaPerKVec, numBlksInK, kBase}); + } else if (useLdsTransposeLoad && kBase == 16) { + // This handles INT8 32x32 (instrK=32, kBase=16) and INT8 16x64 + // (instrK=64, kBase=16) when kpack=1 + int64_t instrK = mfmaAttr.k; + int64_t numBlksInK = instrK / kBase; + int64_t numBlksInD = (waveSize / inputSpanLen) / numBlksInK; + + // Split blk_id into blk_d (for D dimension) and blk_k (for K dimension) + TopDownTMBuilder splitBlkId = + TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr); + splitBlkId.passThrough({"wave_m", "wave_n"}, {0, 1}, + {"wave_m", "wave_n"}); + splitBlkId.merge({"blk_d", "blk_k"}, {2, 3}, "blk_id", + {numBlksInD, numBlksInK}); + splitBlkId.passThrough({"blk_td", "d_iter", "k_iter", "k_vec"}, + {4, 5, 6, 7}, + {"blk_td", "d_iter", "k_iter", "k_vec"}); + TransformMapAttr splitBlkIdAttr = splitBlkId.get(); + transformAttrs.push_back(splitBlkIdAttr); + + // For kVec < kBase, we split k_iter into outer_k (MFMA iterations) + // and inner_k (within kBase for each blk_k) + // Total K per thread = kIter * kVec + // Each thread covers K range: blk_k * kBase to blk_k * kBase + kBase - 1 + // Number of MFMA iterations = (kIter * kVec) / kBase (iterations per + // blk_k) K iterations within each kBase = kBase / kVec Constraint: + // numMfmaIters * kIterPerBlk = kIter + int64_t kIterPerBlk = kBase / kVec; // iterations to cover 
kBase elements + int64_t numMfmaIters = kIter / kIterPerBlk; // number of outer iterations + + TopDownTMBuilder splitKIter = + TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr); + splitKIter.passThrough({"wave_m", "wave_n"}, {0, 1}, + {"wave_m", "wave_n"}); + splitKIter.passThrough({"blk_d", "blk_k", "blk_td", "d_iter"}, + {2, 3, 4, 5}, + {"blk_d", "blk_k", "blk_td", "d_iter"}); + splitKIter.merge({"outer_k", "inner_k"}, {6, 7}, "k_iter", + {numMfmaIters, kIterPerBlk}); + splitKIter.passThrough({"k_vec"}, {8}, {"k_vec"}); + TransformMapAttr splitKIterAttr = splitKIter.get(); + transformAttrs.push_back(splitKIterAttr); + + toLDSRowCol = TopDownTMBuilder::below(splitKIter, splitKIterAttr); + + // d = d_iter * dWaves * numBlksInD * inputSpanLen + wave_d * numBlksInD * + // inputSpanLen + blk_d * inputSpanLen + blk_td + toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_d", "blk_td"}, + {dRepeats, dWaves, numBlksInD, inputSpanLen}); + + // k = outer_k * instrK + blk_k * kBase + inner_k * kVec + k_vec + // This ensures K access pattern matches transpose load: + // - outer_k selects which MFMA iteration + // - blk_k selects which K block (0..kBase-1 or kBase..instrK-1) + // - inner_k * kVec + k_vec gives position within the K block + toLDSRowCol.unmerge("k", 1, {"outer_k", "blk_k", "inner_k", "k_vec"}, + {numMfmaIters, numBlksInK, kIterPerBlk, kVec}); + } else { // Standard formula for regular load scenarios toLDSRowCol = TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr); diff --git a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp index d4f021dc2715..18297c496dd9 100644 --- a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp +++ b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp @@ -21,8 +21,14 @@ // to the LDS transpose load operation in an accelerator-friendly layout. 
// // It is intended to simplify the IR generation logic and ensure -// consistent handling of f16/bf16 matrix accelerator tile loads from LDS -// memory. +// consistent handling of f16/bf16/fp8/bf8/i8 matrix accelerator tile loads from +// LDS memory. +// +// Supported element types: +// - f16, bf16: uses ds_read_tr16_b64 (returns 4 elements per thread) +// - f8E4M3FN, f8E5M2 (OCP FP8): uses ds_read_tr8_b64 (returns 8 elements per +// thread) +// - i8 (INT8): uses ds_read_tr8_b64 (returns 8 elements per thread) // //===----------------------------------------------------------------------===// @@ -47,11 +53,51 @@ namespace { bool archSupported(StringRef arch) { return arch.contains("gfx950"); } +// Check if element type is supported for LDS transpose load +// - f16, bf16: ds_read_tr16_b64 (4 elements) +// - f8E4M3FN, f8E5M2 (OCP FP8 for gfx950): ds_read_tr8_b64 (8 elements) +// - i8 (INT8 for gfx950): ds_read_tr8_b64 (8 elements) +static bool isSupportedElementType(Type t) { + return t.isF16() || t.isBF16() || isa(t) || + isa(t) || t.isInteger(8); +} + +// Check if element type is 8-bit float (FP8 E4M3 or BF8 E5M2) +// Used for: +// 1. Selecting ds_read_tr8_b64 vs ds_read_tr16_b64 +// 2. 
Checking mixed-type compatibility (fp8+bf8 combinations are valid) +static bool isFp8Type(Type t) { + return isa(t) || isa(t); +} + +// Check if element type is INT8 (i8) +// INT8 uses ds_read_tr8_b64 instruction like FP8 +static bool isInt8Type(Type t) { return t.isInteger(8); } + +// Check if element type uses 8-bit transpose load (FP8, BF8, or INT8) +// All these types use ds_read_tr8_b64 which returns 8 elements per thread +static bool uses8BitTransposeLoad(Type t) { + return isFp8Type(t) || isInt8Type(t); +} + +// Returns the number of elements returned by LDS transpose load instruction +static int64_t getTransposeLoadVectorLength(Type elemType) { + if (elemType.isF16() || elemType.isBF16()) { + return 4; // ds_read_tr16_b64 + } else if (uses8BitTransposeLoad(elemType)) { + return 8; // ds_read_tr8_b64 (FP8, BF8, INT8) + } + llvm_unreachable("Unsupported element type for LDS transpose load"); +} + // Validates MFMA geometry for LDS transpose support. -// Only specific combinations are supported: (16,16), (16,32), (32,8), (32,16) +// Supported combinations: +// - F16/BF16: (16,16), (16,32), (32,8), (32,16) +// - FP8/BF8: (16,32), (32,16) +// - INT8: (16,32), (16,64), (32,16), (32,32) static bool isValidMfmaGeometry(int64_t dDim, int64_t kDim) { - return (dDim == 16 && (kDim == 16 || kDim == 32)) || - (dDim == 32 && (kDim == 8 || kDim == 16)); + return (dDim == 16 && (kDim == 16 || kDim == 32 || kDim == 64)) || + (dDim == 32 && (kDim == 8 || kDim == 16 || kDim == 32)); } // Shape of a single MFMA instruction (internal use only). @@ -119,7 +165,16 @@ static Decision makeDecision(StringRef arch, Type elemTypeA, Type elemTypeB, return dec; } - if (elemTypeA != elemTypeB || !(elemTypeA.isF16() || elemTypeA.isBF16())) + // Check type compatibility: + // - Same types are always allowed (f16==f16, fp8==fp8, etc.) + // - Mixed FP8/BF8 combinations are allowed (hardware supports mixed fp8/bf8 + // MFMA: + // mfma_f32_16x16x32_fp8_fp8, mfma_f32_16x16x32_fp8_bf8, etc.) 
+ // - Other mixed types are NOT allowed (e.g., f16 with fp8) + bool typesCompatible = (elemTypeA == elemTypeB) || + (isFp8Type(elemTypeA) && isFp8Type(elemTypeB)); + if (!typesCompatible || !isSupportedElementType(elemTypeA) || + !isSupportedElementType(elemTypeB)) return dec; // Validate MFMA geometry @@ -294,7 +349,7 @@ LDSTransposeConfigAttr buildTransposeAttrFromParams( "MFMA geometry must be set when building transpose attributes"); assert(isValidMfmaGeometry(mfmaDDim, mfmaKDim) && "Invalid MFMA geometry for LDS transpose - valid: (16,16), (16,32), " - "(32,8), (32,16)"); + "(16,64), (32,8), (32,16), (32,32)"); // Create structured attribute with all parameters return LDSTransposeConfigAttr::get(rewriter.getContext(), mfmaDDim, mfmaKDim, @@ -391,14 +446,29 @@ static Value getDoubleRateKOffsetBase(PatternRewriter &b, Location loc, Value kOffsetBase; if (dDim == 32 && kDim == 16) { - // 32x16 layout + // 32x16 layout (F16/BF16) kOffsetBase = arith::MulIOp::create( b, loc, arith::DivUIOp::create(b, loc, blockId, c2), c8); } else if (dDim == 16 && kDim == 32) { - // 16x32 layout + // 16x32 layout (F16/BF16) kOffsetBase = arith::MulIOp::create(b, loc, blockId, c8); + } else if (dDim == 16 && kDim == 64) { + // 16x64 layout (INT8): k_offset_base = block_id * 16 + // Each block covers 16 K values, with first/second call covering + // K=0..7/8..15 within each block's 16-K range + Value c16val = arith::ConstantIndexOp::create(b, loc, 16); + kOffsetBase = arith::MulIOp::create(b, loc, blockId, c16val); + } else if (dDim == 32 && kDim == 32) { + // 32x32 layout (INT8): k_offset_base = k_block * 16 + // where k_block = block_id / 2 + // Block 0,1 (k_block=0) cover K=0..15, Block 2,3 (k_block=1) cover K=16..31 + // First/second call covers K+0..7 / K+8..15 within each range + Value c16val = arith::ConstantIndexOp::create(b, loc, 16); + Value kBlock = arith::DivUIOp::create(b, loc, blockId, c2); + kOffsetBase = arith::MulIOp::create(b, loc, kBlock, c16val); } else { - 
llvm_unreachable("Invalid double-rate geometry - must be 32x16 or 16x32"); + llvm_unreachable( + "Invalid double-rate geometry - must be 32x16, 16x32, 16x64, or 32x32"); } return kOffsetBase; @@ -415,15 +485,17 @@ static Value getDoubleRateKOffsetBase(PatternRewriter &b, Location loc, // Single-rate: k_final = k_base_local + (kTileIdx * kTileStride) // Double-rate: k_final = k_base_local + kOffsetBase + (kTileIdx * // kTileStride) + halfOffset -// where halfOffset = 0 for low half, 4 for high half +// where halfOffset = 0 for low half, 4 for F16/BF16 high half, +// 8 for INT8 high half // // Parameters: -// isDoubleRate - Whether this is a double-rate layout (L32x16, L16x32) -// kBaseLocal - Local K base offset from computeLDSBaseOffsets() +// isDoubleRate - Whether this is a double-rate layout (L32x16, L16x32, +// L16x64) kBaseLocal - Local K base offset from computeLDSBaseOffsets() // kOffsetBase - Double-rate K offset base (from getDoubleRateKOffsetBase) // kTileIdx - Current K tile index (0, 1, 2, ...) 
-// kTileStride - K stride per tile (instrK, e.g., 8 or 16) -// isHighHalf - For double-rate: true = high half (+4), false = low half +// kTileStride - K stride per tile (instrK, e.g., 8, 16, or 64) +// isHighHalf - For double-rate: true = high half, false = low half +// halfOffset - Offset to add for high half (4 for F16/BF16, 8 for INT8) // // Returns: // Final K offset value to use for emitPanelLoad() @@ -431,13 +503,13 @@ static Value getDoubleRateKOffsetBase(PatternRewriter &b, Location loc, static Value computePanelFinalOffset(PatternRewriter &b, Location loc, bool isDoubleRate, Value kBaseLocal, Value kOffsetBase, int64_t kTileIdx, - Value kTileStride, - bool isHighHalf = false) { + Value kTileStride, bool isHighHalf = false, + int64_t halfOffset = 4) { Value kBase = kBaseLocal; if (isDoubleRate) { - // Double-rate: k_offset = kOffsetBase + kTileIdx * kTileStride [+ 4 for - // high] + // Double-rate: k_offset = kOffsetBase + kTileIdx * kTileStride [+ + // halfOffset] Value kTileOffset; if (kTileIdx > 0) { Value kIdxVal = arith::ConstantIndexOp::create(b, loc, kTileIdx); @@ -448,10 +520,10 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, Value k_offset = arith::AddIOp::create(b, loc, kOffsetBase, kTileOffset); - // For high half, add 4 + // For high half, add halfOffset (4 for F16/BF16, 8 for INT8) if (isHighHalf) { - Value c4 = arith::ConstantIndexOp::create(b, loc, 4); - k_offset = arith::AddIOp::create(b, loc, k_offset, c4); + Value cHalfOffset = arith::ConstantIndexOp::create(b, loc, halfOffset); + k_offset = arith::AddIOp::create(b, loc, k_offset, cHalfOffset); } // k_base = k_base_local + k_offset @@ -477,8 +549,9 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, // emitPanelLoad - Emit an LDS transpose load operation // // Computes the final LDS offset and emits a hardware LDS transpose load -// instruction (ds_read_tr16_b64). This instruction always returns vector<4xf16> -// regardless of the layout. 
+// instruction: +// - ds_read_tr16_b64 for f16/bf16: returns vector<4> +// - ds_read_tr8_b64 for fp8/bf8: returns vector<8> // // The final offset is computed as: final_offset = k_base * ldsStride + m_base // where ldsStride depends on the operand (mPerBlock for A, nPerBlock for B). @@ -491,10 +564,10 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, // kBase - K dimension base offset for this panel // mBase - M/N dimension base offset for this panel // ldsStride - Stride between K rows in LDS (mPerBlock or nPerBlock) -// panelVecType - Result type (always vector<4xf16> or vector<4xbf16>) +// panelVecType - Result type (vector<4> for f16/bf16, vector<8> for fp8/bf8) // // Returns: -// The loaded panel vector (vector<4xf16/bf16>) +// The loaded panel vector (vector<4> for f16/bf16, vector<8> for fp8/bf8) //===----------------------------------------------------------------------===// static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, Value kBase, Value mBase, Value ldsStride, @@ -503,7 +576,7 @@ static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, Value kOffset = arith::MulIOp::create(b, loc, kBase, ldsStride); Value finalOffset = arith::AddIOp::create(b, loc, mBase, kOffset); - // Emit hardware LDS transpose load: ds_read_tr16_b64 + // Emit hardware LDS transpose load auto loadOp = rock::LDSTransposeLoadOp::create(b, loc, panelVecType, rawSrc, ValueRange{finalOffset}); @@ -518,9 +591,9 @@ static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, // elements (ds_read_tr16_b64 always returns vector<4xf16>). 
// // Parameters: -// panelVectors - Array of loaded panel vectors (each is vector<4xf16>) -// dest - Destination memref (rank-1, scalar layout) -// targetElems - Maximum number of elements to write +// panelVectors - Array of loaded panel vectors (vector<4> for f16/bf16, +// vector<8> for fp8/bf8) dest - Destination memref (rank-1, scalar +// layout) targetElems - Maximum number of elements to write // // Returns: // success() if all target elements were written @@ -534,14 +607,16 @@ writePanelVectorsToDestination(PatternRewriter &b, Location loc, int64_t produced = 0; // Extract elements per vector from the actual vector type - // Hardware instruction ds_read_tr16_b64 always returns vector<4xf16> + // Hardware instructions: + // - ds_read_tr16_b64 returns vector<4xf16/bf16> + // - ds_read_tr8_b64 returns vector<8xfp8> assert(!panelVectors.empty() && "Panel vectors array must not be empty"); auto panelVecType = cast(panelVectors[0].getType()); int64_t elementsPerVector = panelVecType.getShape()[0]; - // Verify hardware constraint: ds_read_tr16_b64 returns exactly 4 elements - assert(elementsPerVector == 4 && - "LDS transpose load must produce vector<4xf16> per panel"); + // Verify hardware constraint: 4 elements for 16-bit, 8 elements for 8-bit + assert((elementsPerVector == 4 || elementsPerVector == 8) && + "LDS transpose load must produce vector<4> or vector<8> per panel"); LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Writing " << panelVectors.size() << " panel vectors (" @@ -597,60 +672,133 @@ writePanelVectorsToDestination(PatternRewriter &b, Location loc, //===----------------------------------------------------------------------===// static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, int64_t dDim, int64_t kDim, - Value lane) { - Value c16 = arith::ConstantIndexOp::create(b, loc, 16); - Value c4 = arith::ConstantIndexOp::create(b, loc, 4); + Value lane, Type elemType) { + // Common constants used by both FP8 and F16/BF16 Value c2 = 
arith::ConstantIndexOp::create(b, loc, 2); + Value c4 = arith::ConstantIndexOp::create(b, loc, 4); + Value c16 = arith::ConstantIndexOp::create(b, loc, 16); + + Value kOffsetBase, mOffsetBase; + + // 8-bit types (FP8, BF8, INT8) use ds_read_tr8_b64 + // They share the same offset formulas for 16x32 and 32x16 geometries + if (uses8BitTransposeLoad(elemType)) { + Value c8 = arith::ConstantIndexOp::create(b, loc, 8); + + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); + Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); + + if (dDim == 16 && kDim == 32) { + // 16x32: 4-block formula (FP8/BF8/INT8) + // Block layout: + // Block 0 (T0-T15): K=0..7 + // Block 1 (T16-T31): K=8..15 + // Block 2 (T32-T47): K=16..23 + // Block 3 (T48-T63): K=24..31 + + // kOffsetBase = k_local + block_id * 8 + Value blockKOffset = arith::MulIOp::create(b, loc, blockId, c8); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, blockKOffset); + + // mOffsetBase = m_parity * 8 + mOffsetBase = arith::MulIOp::create(b, loc, mParity, c8); + + } else if (dDim == 32 && kDim == 16) { + // 32x16: 4-block formula (FP8/BF8/INT8) + // Block layout: + // Block 0 (T0-T15): M=0..15, K=0..7 + // Block 1 (T16-T31): M=16..31, K=0..7 + // Block 2 (T32-T47): M=0..15, K=8..15 + // Block 3 (T48-T63): M=16..31, K=8..15 + + // m_block = block_id % 2, k_block = block_id / 2 + Value mBlock = arith::RemUIOp::create(b, loc, blockId, c2); + Value kBlock = arith::DivUIOp::create(b, loc, blockId, c2); + + // kOffsetBase = k_local + k_block * 8 + Value kBlockOffset = arith::MulIOp::create(b, loc, kBlock, c8); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, kBlockOffset); + + // mOffsetBase = m_parity * 8 + m_block * 16 + Value mParityOffset = arith::MulIOp::create(b, loc, mParity, c8); + Value mBlockOffset = arith::MulIOp::create(b, loc, mBlock, c16); + 
mOffsetBase = arith::AddIOp::create(b, loc, mParityOffset, mBlockOffset); + + } else if (isInt8Type(elemType) && dDim == 16 && kDim == 64) { + // INT8 16x64: Double-rate layout (INT8 only) + // k_base_local = k_local (block offset comes from + // getDoubleRateKOffsetBase) m_base = m_parity * 8 + + kOffsetBase = kLocal; + mOffsetBase = arith::MulIOp::create(b, loc, mParity, c8); + + } else if (isInt8Type(elemType) && dDim == 32 && kDim == 32) { + // INT8 32x32: Double-rate layout (INT8 only) + // k_base_local = k_local (k_block * 16 comes from + // getDoubleRateKOffsetBase) m_base = m_parity * 8 + m_block * 16 + + Value mBlock = arith::RemUIOp::create(b, loc, blockId, c2); + + kOffsetBase = kLocal; + + Value mParityOffset = arith::MulIOp::create(b, loc, mParity, c8); + Value mBlockOffset = arith::MulIOp::create(b, loc, mBlock, c16); + mOffsetBase = arith::AddIOp::create(b, loc, mParityOffset, mBlockOffset); + + } else { + llvm_unreachable( + "Unsupported 8-bit type MFMA geometry in getBasePanelOffsets"); + } - Value blockId = arith::DivUIOp::create(b, loc, lane, c16); - Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); - - // Base offset calculations - Value mOffsetBase = arith::MulIOp::create( - b, loc, arith::RemUIOp::create(b, loc, laneInBlock, c4), c4); - Value kOffsetBase = arith::DivUIOp::create(b, loc, laneInBlock, c4); - - SmallVector panelOffsets; - - if (dDim == 16 && kDim == 32) { - // 16x32 layout - panelOffsets = {kOffsetBase, mOffsetBase}; - } else if (dDim == 16 && kDim == 16) { - // 16x16 layout - // kbase = kOffsetBase + (blockId * 4) - Value kBase = arith::AddIOp::create( - b, loc, arith::MulIOp::create(b, loc, blockId, c4), kOffsetBase); - panelOffsets = {kBase, mOffsetBase}; - } else if (dDim == 32 && kDim == 16) { - // 32x16 layout - // mbase = mOffsetBase + (blockId % 2) * 16 - Value mBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::RemUIOp::create(b, loc, blockId, c2), c16), - mOffsetBase); - 
panelOffsets = {kOffsetBase, mBase}; - } else if (dDim == 32 && kDim == 8) { - // 32x8 layout - // k_base_local = kOffsetBase + (blockId / 2) * 4 - Value kBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::DivUIOp::create(b, loc, blockId, c2), c4), - kOffsetBase); - - // m_offset_base = mOffsetBase + (blockId % 2) * 16 - Value mBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::RemUIOp::create(b, loc, blockId, c2), c16), - mOffsetBase); - panelOffsets = {kBase, mBase}; } else { - llvm_unreachable("Unsupported MFMA geometry in getBasePanelOffsets"); + // F16/BF16 uses block-based lane mapping + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + + // Base calculations common to all F16/BF16 geometries + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c4); + Value mLocal = arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, laneInBlock, c4), c4); + + if (dDim == 16 && kDim == 32) { + // 16x32: direct mapping + kOffsetBase = kLocal; + mOffsetBase = mLocal; + + } else if (dDim == 16 && kDim == 16) { + // 16x16: k += blockId * 4 + kOffsetBase = arith::AddIOp::create( + b, loc, kLocal, arith::MulIOp::create(b, loc, blockId, c4)); + mOffsetBase = mLocal; + + } else if (dDim == 32 && kDim == 16) { + // 32x16: m += (blockId % 2) * 16 + kOffsetBase = kLocal; + mOffsetBase = arith::AddIOp::create( + b, loc, mLocal, + arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, blockId, c2), c16)); + + } else if (dDim == 32 && kDim == 8) { + // 32x8: k += (blockId / 2) * 4, m += (blockId % 2) * 16 + kOffsetBase = arith::AddIOp::create( + b, loc, kLocal, + arith::MulIOp::create( + b, loc, arith::DivUIOp::create(b, loc, blockId, c2), c4)); + mOffsetBase = arith::AddIOp::create( + b, loc, mLocal, + arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, blockId, c2), c16)); + + } else { + llvm_unreachable( + 
"Unsupported F16/BF16 MFMA geometry in getBasePanelOffsets"); + } } - return panelOffsets; + return {kOffsetBase, mOffsetBase}; } //===----------------------------------------------------------------------===// @@ -671,6 +819,7 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, // dDim - MFMA D dimension (M or N, 16 or 32) // kDim - MFMA K dimension (8, 16, or 32) // lane - Thread's lane ID within the workgroup +// elemType - Element type (f16, bf16, fp8, or bf8) for selecting lane mapping // // Returns: // std::pair: @@ -679,8 +828,10 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, //===----------------------------------------------------------------------===// static std::pair computeLDSBaseOffsets(PatternRewriter &b, Location loc, int64_t dDim, - int64_t kDim, Value lane) { - SmallVector offsets = getBasePanelOffsets(b, loc, dDim, kDim, lane); + int64_t kDim, Value lane, + Type elemType) { + SmallVector offsets = + getBasePanelOffsets(b, loc, dDim, kDim, lane, elemType); LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Computed LDS base offsets for " << dDim << "x" << kDim << ": " @@ -1086,8 +1237,9 @@ static Value computeFinalMNOffset(PatternRewriter &b, Location loc, // emitThreadwiseHWTranspose - Lower threadwise_read_into to HW transpose loads //===----------------------------------------------------------------------===// // Lowers threadwise_read_into with LDS transpose config into hardware transpose -// load instructions (ds_read_tr16_b64) that read from LDS in MFMA-friendly -// order. +// load instructions that read from LDS in MFMA-friendly order: +// - ds_read_tr16_b64 for f16/bf16 (returns vector<4>) +// - ds_read_tr8_b64 for fp8/bf8 (returns vector<8>) // // Algorithm: // 1. 
Extract config: MFMA geometry (dDim, kDim), tiling params, operand kind @@ -1100,8 +1252,8 @@ static Value computeFinalMNOffset(PatternRewriter &b, Location loc, // - Outer: M/N tiles (all at once for double-buffering, one at a time // otherwise) // - Inner: K tiles (1 load for single-rate, 2 loads for double-rate layouts) -// 6. For each iteration: compute final LDS offset, emit ds_read_tr16_b64 -// instruction (returns vector<4xf16>) +// 6. For each iteration: compute final LDS offset, emit LDS transpose load +// instruction (ds_read_tr16_b64 for f16/bf16, ds_read_tr8_b64 for fp8/bf8) // 7. Extract elements from panel vectors and write sequentially to destination // // Example: 16x32 layout, 1 M-tile, 2 K-tiles → 2 ds_read_tr16_b64 calls → 8 @@ -1155,23 +1307,35 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, int64_t ldsStride = (operand == OperandKind::A) ? mPerBlock : nPerBlock; // Determine if this is a double-rate instruction - // Double-rate ONLY for (32,16) and (16,32) MFMA - // (16,16) and (32,8) are SINGLE-RATE + // Double-rate layouts require TWO ds_read calls per K tile: + // - F16/BF16 (32,16) and (16,32): ds_read_tr16_b64 returns 4 elements, + // need 2 calls for 8 elements total + // - INT8 (16,64) and (32,32): ds_read_tr8_b64 returns 8 elements, need 2 + // calls for 16 i8 values (packed into v4i32 for mfma_i32_16x16x64_i8) + // Single-rate layouts: + // - FP8/BF8/INT8 (16,32), (32,16): ds_read_tr8_b64 returns 8 elements + // directly + // - F16/BF16 (16,16), (32,8): single ds_read_tr16_b64 per K tile bool isDoubleRate = - (dDim == 32 && instrK == 16) || (dDim == 16 && instrK == 32); + (!uses8BitTransposeLoad(elemType) && + ((dDim == 32 && instrK == 16) || (dDim == 16 && instrK == 32))) || + (isInt8Type(elemType) && + ((dDim == 16 && instrK == 64) || (dDim == 32 && instrK == 32))); - // Each ds_read_tr16_b64 call ALWAYS returns vector<4xf16> - // For double-rate, we make 2 calls and store all 8 elements separately - VectorType 
panelVecType = VectorType::get({4}, elemType); + // Determine vector length based on element type: + // - f16/bf16: ds_read_tr16_b64 returns vector<4> + // - fp8/bf8: ds_read_tr8_b64 returns vector<8> + int64_t vecLen = getTransposeLoadVectorLength(elemType); + VectorType panelVecType = VectorType::get({vecLen}, elemType); // panelVectors will contain: - // - Single-rate: 1 vector<4xf16> per K tile - // - Double-rate: 2 vector<4xf16> per K tile (low + high) + // - Single-rate: 1 vector per K tile + // - Double-rate (f16/bf16 only): 2 vectors per K tile (low + high) SmallVector panelVectors; // Get base offsets using computeLDSBaseOffsets helper auto [k_base_local, m_offset_base] = - computeLDSBaseOffsets(b, loc, dDim, instrK, lane); + computeLDSBaseOffsets(b, loc, dDim, instrK, lane, elemType); // K stride per tile: instrK (MFMA K dimension) int64_t kTileStride = instrK; @@ -1226,26 +1390,33 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, useDynamicMnIndex, waveOffsetStrideVal, tileOffsetStrideVal); if (!isDoubleRate) { - // SINGLE-RATE (L32x8, L16x16): One load per K tile + // SINGLE-RATE (L32x8, L16x16, or FP8/BF8): One load per K tile Value k_base = computePanelFinalOffset(b, loc, isDoubleRate, k_base_local, kOffsetBase, kIdx, kTileStrideVal); - // Emit LDS transpose load for this K tile (single-rate: one per K tile) + // Emit LDS transpose load for this K tile Value panelVec = emitPanelLoad(b, loc, rawSrc, k_base, m_base, ldsStrideVal, panelVecType); panelVectors.push_back(panelVec); } else { - // DOUBLE-RATE (L32x16, L16x32): TWO loads per K tile - // Each load returns vector<4xf16>, total 8 elements per K tile - // Compute K offsets for low and high halves + // DOUBLE-RATE: TWO loads per K tile + // - F16/BF16 (L32x16, L16x32): vector<4> per load, halfOffset=4 + // - INT8 (L16x64): vector<8> per load, halfOffset=8 + // FP8/BF8 is NEVER double-rate (single ds_read_tr8_b64 returns 8 elems) + + // Determine halfOffset based on element type: + // - 
F16/BF16: 4 (each ds_read_tr16_b64 returns 4 elements) + // - INT8: 8 (each ds_read_tr8_b64 returns 8 elements) + int64_t halfOffsetVal = isInt8Type(elemType) ? 8 : 4; + Value k_base_low = computePanelFinalOffset( b, loc, isDoubleRate, k_base_local, kOffsetBase, kIdx, - kTileStrideVal, /*isHighHalf=*/false); + kTileStrideVal, /*isHighHalf=*/false, halfOffsetVal); Value k_base_high = computePanelFinalOffset( b, loc, isDoubleRate, k_base_local, kOffsetBase, kIdx, - kTileStrideVal, /*isHighHalf=*/true); + kTileStrideVal, /*isHighHalf=*/true, halfOffsetVal); // Emit low half load Value panelVecLow = emitPanelLoad(b, loc, rawSrc, k_base_low, m_base, @@ -1269,8 +1440,10 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, int64_t loadsPerKTile = isDoubleRate ? 2 : 1; int64_t expectedLoads = actualMnTiles * kPanels * loadsPerKTile; - // Each load ALWAYS produces 4 elements (ds_read_tr16_b64 → vector<4xf16>) - int64_t sliceElems = expectedLoads * 4; + // Each load produces vecLen elements: + // - f16/bf16: 4 elements (ds_read_tr16_b64) + // - fp8/bf8: 8 elements (ds_read_tr8_b64) + int64_t sliceElems = expectedLoads * vecLen; // Verify we generated the expected number of loads if (panelVectors.size() != (size_t)expectedLoads) { diff --git a/mlir/test/Dialect/Rock/lds_transpose_attributes.mlir b/mlir/test/Dialect/Rock/lds_transpose_attributes.mlir index 5dfc22d8ac21..3d1a2d96ced4 100644 --- a/mlir/test/Dialect/Rock/lds_transpose_attributes.mlir +++ b/mlir/test/Dialect/Rock/lds_transpose_attributes.mlir @@ -65,3 +65,71 @@ module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { // CHECK-SAME: memref<256x8xf16, #gpu.address_space> -> memref<8xf16, #gpu.address_space> // CHECK: rock.threadwise_read_into {forceUnroll, ldsTransposeConfig = #rock.lds_transpose_config, useIndexDiffs} // CHECK-SAME: memref<256x32xf16, #gpu.address_space> -> memref<32xf16, #gpu.address_space> + +// ----- + +// Test INT8 32x32 MFMA (mfma_i32_32x32x32_i8) with LDS transpose attributes 
+#params_int8_32x32 = #rock.accel_gemm_params< + kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, + kpack = 1, mPerWave = 32, nPerWave = 32, + mnPerXdl = 32, splitKFactor = 1, scheduleVersion = 3, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, forceUnroll = true> + +module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { + // CHECK-LABEL: func.func @test_lds_transpose_int8_32x32 + func.func @test_lds_transpose_int8_32x32( + %arg0: memref<1024xi8>, + %arg1: memref<1024xi8>, + %arg2: memref<1024xi32>) + attributes {block_size = 64 : i32, grid_size = 1 : i32, + enable_splitk_for_tuning, kernel, + num_cu = 256 : i64} { + %a = rock.transform %arg0 by (d1 * 32 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 32, 32] -> [1024]> : memref<1024xi8> to memref<1x32x32xi8> + %b = rock.transform %arg1 by (d1 * 32 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 32, 32] -> [1024]> : memref<1024xi8> to memref<1x32x32xi8> + %c = rock.transform %arg2 by (d1 * 32 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 32, 32] -> [1024]> : memref<1024xi32> to memref<1x32x32xi32> + + rock.gridwise_gemm_accel(%a, %b, %c) + storeMethod(set) + features = mfma|dot|atomic_add|atomic_add_bf16|atomic_add_f16|direct_to_lds_32b|direct_to_lds_128b + {blockSize = 64 : i32, gridSize = 1 : i32, params = #params_int8_32x32} + : memref<1x32x32xi8>, memref<1x32x32xi8>, memref<1x32x32xi32> + return + } +} + +// CHECK: rock.threadwise_read_into {forceUnroll, ldsTransposeConfig = #rock.lds_transpose_config, useIndexDiffs} +// CHECK: rock.threadwise_read_into {forceUnroll, ldsTransposeConfig = #rock.lds_transpose_config, useIndexDiffs} + +// ----- + +// Test INT8 16x64 MFMA (mfma_i32_16x16x64_i8) with LDS transpose attributes +#params_int8_16x64 = #rock.accel_gemm_params< + kpackPerBlock = 64, mPerBlock = 16, nPerBlock = 16, + kpack = 1, mPerWave = 16, nPerWave = 16, + mnPerXdl = 16, splitKFactor = 1, scheduleVersion = 3, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, 
forceUnroll = true> + +module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { + // CHECK-LABEL: func.func @test_lds_transpose_int8_16x64 + func.func @test_lds_transpose_int8_16x64( + %arg0: memref<1024xi8>, + %arg1: memref<1024xi8>, + %arg2: memref<256xi32>) + attributes {block_size = 64 : i32, grid_size = 1 : i32, + enable_splitk_for_tuning, kernel, + num_cu = 256 : i64} { + %a = rock.transform %arg0 by (d1 * 16 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 64, 16] -> [1024]> : memref<1024xi8> to memref<1x64x16xi8> + %b = rock.transform %arg1 by (d1 * 16 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 64, 16] -> [1024]> : memref<1024xi8> to memref<1x64x16xi8> + %c = rock.transform %arg2 by (d1 * 16 + d2)> by [ ["raw"] at [0]>, [] at []>] bounds = [1, 16, 16] -> [256]> : memref<256xi32> to memref<1x16x16xi32> + + rock.gridwise_gemm_accel(%a, %b, %c) + storeMethod(set) + features = mfma|dot|atomic_add|atomic_add_bf16|atomic_add_f16|direct_to_lds_32b|direct_to_lds_128b + {blockSize = 64 : i32, gridSize = 1 : i32, params = #params_int8_16x64} + : memref<1x64x16xi8>, memref<1x64x16xi8>, memref<1x16x16xi32> + return + } +} + +// CHECK: rock.threadwise_read_into {forceUnroll, ldsTransposeConfig = #rock.lds_transpose_config, useIndexDiffs} +// CHECK: rock.threadwise_read_into {forceUnroll, ldsTransposeConfig = #rock.lds_transpose_config, useIndexDiffs} diff --git a/mlir/test/Dialect/Rock/lds_transpose_error.mlir b/mlir/test/Dialect/Rock/lds_transpose_error.mlir index efc7cfc42fa4..61e0195fdecc 100644 --- a/mlir/test/Dialect/Rock/lds_transpose_error.mlir +++ b/mlir/test/Dialect/Rock/lds_transpose_error.mlir @@ -2,13 +2,13 @@ // Error case: Invalid MFMA geometry (16x8 is not valid) // This tests that LDSTransposeConfigAttr::verify() catches invalid MFMA -// geometry combinations. Valid combinations are: (16,16), (16,32), (32,8), (32,16) +// geometry combinations. 
Valid combinations are: (16,16), (16,32), (16,64), (32,8), (32,16), (32,32) func.func @threadwise_read_into_invalid_mfma_geometry_16x8( %source: memref<128xf16, #gpu.address_space>, %dest: memref<8xf16, #gpu.address_space>) attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { rock.threadwise_read_into { - // expected-error @+1 {{invalid MFMA geometry (16x8) for LDS transpose - valid combinations: (16,16), (16,32), (32,8), (32,16)}} + // expected-error @+1 {{invalid MFMA geometry (16x8) for LDS transpose - valid combinations: (16,16), (16,32), (16,64), (32,8), (32,16), (32,32)}} ldsTransposeConfig = #rock.lds_transpose_config< dDim = 16, kDim = 8, mPerBlock = 128, nPerBlock = 128, kPerBlock = 32, @@ -21,15 +21,15 @@ func.func @threadwise_read_into_invalid_mfma_geometry_16x8( // ----- -// Error case: Invalid MFMA geometry (32x32 is not valid) -func.func @threadwise_read_into_invalid_mfma_geometry_32x32( +// Error case: Invalid MFMA geometry (64x16 is not valid) +func.func @threadwise_read_into_invalid_mfma_geometry_64x16( %source: memref<128xf16, #gpu.address_space>, %dest: memref<8xf16, #gpu.address_space>) attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { rock.threadwise_read_into { - // expected-error @+1 {{invalid MFMA geometry (32x32) for LDS transpose - valid combinations: (16,16), (16,32), (32,8), (32,16)}} + // expected-error @+1 {{invalid MFMA geometry (64x16) for LDS transpose - valid combinations: (16,16), (16,32), (16,64), (32,8), (32,16), (32,32)}} ldsTransposeConfig = #rock.lds_transpose_config< - dDim = 32, kDim = 32, + dDim = 64, kDim = 16, mPerBlock = 128, nPerBlock = 128, kPerBlock = 32, mPerWave = 64, nPerWave = 64, doubleBuffering = false, isOperandA = true @@ -46,7 +46,7 @@ func.func @threadwise_read_into_invalid_mfma_geometry_8x8( %dest: memref<8xf16, #gpu.address_space>) attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { rock.threadwise_read_into { - // expected-error @+1 {{invalid MFMA geometry (8x8) for LDS transpose - valid combinations: (16,16), 
(16,32), (32,8), (32,16)}} + // expected-error @+1 {{invalid MFMA geometry (8x8) for LDS transpose - valid combinations: (16,16), (16,32), (16,64), (32,8), (32,16), (32,32)}} ldsTransposeConfig = #rock.lds_transpose_config< dDim = 8, kDim = 8, mPerBlock = 128, nPerBlock = 128, kPerBlock = 32, diff --git a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir index 29a2f190a06b..20521a23a52b 100644 --- a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir +++ b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir @@ -14,4 +14,25 @@ module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xbf16, #gpu.address_space> -> vector<4xbf16> return %v : vector<4xbf16> } + +// CHECK-LABEL: func @test_load_transpose_fp8_e4m3 + func.func @test_load_transpose_fp8_e4m3(%src: memref<128x256xf8E4M3FN, #gpu.address_space>, %i: index, %j: index) -> vector<8xf8E4M3FN> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<128x256xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + %v = rock.lds_transpose_load %src[%i, %j] : memref<128x256xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + return %v : vector<8xf8E4M3FN> + } + +// CHECK-LABEL: func @test_load_transpose_fp8_e5m2 + func.func @test_load_transpose_fp8_e5m2(%src: memref<64x128xf8E5M2, #gpu.address_space>, %i: index, %j: index) -> vector<8xf8E5M2> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<64x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + return %v : vector<8xf8E5M2> + } + +// CHECK-LABEL: func @test_load_transpose_i8 + func.func @test_load_transpose_i8(%src: memref<128x256xi8, #gpu.address_space>, %i: index, %j: index) -> vector<8xi8> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<128x256xi8, #gpu.address_space> -> vector<8xi8> + %v 
= rock.lds_transpose_load %src[%i, %j] : memref<128x256xi8, #gpu.address_space> -> vector<8xi8> + return %v : vector<8xi8> + } } diff --git a/mlir/test/Dialect/Rock/ops.mlir b/mlir/test/Dialect/Rock/ops.mlir index fafe8fa665d6..1a56f6352788 100644 --- a/mlir/test/Dialect/Rock/ops.mlir +++ b/mlir/test/Dialect/Rock/ops.mlir @@ -433,6 +433,27 @@ func.func @rock_lds_transpose_load_full_arch(%lds_buffer: memref<128x64xf16, #gp return } +// CHECK-LABEL: func.func @rock_lds_transpose_load_fp8_e4m3 +// CHECK: rock.lds_transpose_load +func.func @rock_lds_transpose_load_fp8_e4m3(%lds_buffer: memref<128x64xf8E4M3FN, #gpu.address_space>) + attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { + %c0 = arith.constant 0 : index + %fragment = rock.lds_transpose_load %lds_buffer[%c0, %c0] + : memref<128x64xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + return +} + +// CHECK-LABEL: func.func @rock_lds_transpose_load_fp8_e5m2 +// CHECK: rock.lds_transpose_load +func.func @rock_lds_transpose_load_fp8_e5m2(%lds_buffer: memref<256x128xf8E5M2, #gpu.address_space>) + attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %fragment = rock.lds_transpose_load %lds_buffer[%c32, %c0] + : memref<256x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + return +} + // CHECK-LABEL: func.func @test_lds_transpose_config_attr_16x32 // CHECK: ldsTransposeConfig = #rock.lds_transpose_config func.func @test_lds_transpose_config_attr_16x32(%src: memref<8192xf16, #gpu.address_space>, diff --git a/mlir/test/e2e/CMakeLists.txt b/mlir/test/e2e/CMakeLists.txt index 317918b9cc68..59e644b38f03 100644 --- a/mlir/test/e2e/CMakeLists.txt +++ b/mlir/test/e2e/CMakeLists.txt @@ -50,7 +50,10 @@ if (ROCMLIR_DRIVER_PR_E2E_TEST_ENABLED) PrConvElementwiseGemmBF16SplitK PrGemmDirectToLDS PrLdsTransposeLoad + PrLdsTransposeLoadFp8 + PrLdsTransposeLoadI8 PrLdsTransposeLoadAttention + PrLdsTransposeLoadAttentionI8 PrConvDirectToLDS PrAttentionDirectToLDS ) 
diff --git a/mlir/test/e2e/LdsTransposeLoadAttention.toml b/mlir/test/e2e/LdsTransposeLoadAttention.toml index b2bd4f102575..181eaa5d8d1a 100644 --- a/mlir/test/e2e/LdsTransposeLoadAttention.toml +++ b/mlir/test/e2e/LdsTransposeLoadAttention.toml @@ -1,6 +1,6 @@ directory = "LdsTransposeLoadAttention" prefix = "rocmlir-gen" -suffix = "--operation attention --arch %arch -pv %constrained_float_range_random_data %rocmlir_gen_flags -relDiff_threshold 0.3 -absDiff_threshold 0.3 -RMS_threshold 0.15 | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" +suffix = "--operation attention --arch %arch %pv %constrained_float_range_random_data %rocmlir_gen_flags -relDiff_threshold 0.3 -absDiff_threshold 0.3 -RMS_threshold 0.15 | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" [[axis]] name = "data type" diff --git a/mlir/test/e2e/PrLdsTransposeLoadAttentionI8.cfg b/mlir/test/e2e/PrLdsTransposeLoadAttentionI8.cfg new file mode 100644 index 000000000000..46909aa10a02 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadAttentionI8.cfg @@ -0,0 +1,2 @@ +if not 'lds_transpose_load' in config.features: + config.unsupported = True diff --git a/mlir/test/e2e/PrLdsTransposeLoadAttentionI8.toml b/mlir/test/e2e/PrLdsTransposeLoadAttentionI8.toml new file mode 100644 index 000000000000..5ac7f718778c --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadAttentionI8.toml @@ -0,0 +1,52 @@ +directory = "PrLdsTransposeLoadAttentionI8" 
+prefix = "rocmlir-gen" +suffix = "--operation attention --arch %arch %pv %random_data %rocmlir_gen_flags -RMS_threshold 0.01 | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +[[axis]] +name = "data type" +values = ["i8"] +prefix = "-t " + +# ============================================================================ +# Suite 1: Both K and Q use LDS transpose (transQ=true, transK=false) +# 1 test per MFMA = 4 tests +# ============================================================================ +[[suite]] +name = "lds_transpose_both_i8" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transQ=true --transK=false -perf_config attn:v3:32,32,32,32,16,16,16,8,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 128 -head_dim_v 64 --transQ=true --transK=false -perf_config attn:v3:32,32,32,64,16,16,16,8,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transQ=true --transK=false -perf_config attn:v3:32,32,32,16,32,32,32,8,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transQ=true --transK=false -perf_config attn:v3:32,32,32,32,32,32,32,8,1,3,2,0,1" + +# ============================================================================ +# Suite 2: Only K uses LDS transpose (transQ=false, transK=false) +# 6 tests +# ============================================================================ +[[suite]] +name = "lds_transpose_k_only_i8" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transQ=false --transK=false -perf_config attn:v3:32,32,32,32,16,16,16,8,1,3,2,0,1" 
+ +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 64 -head_dim_v 64 --transQ=false --transK=false -perf_config attn:v3:64,64,64,32,16,16,16,16,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 128 -head_dim_v 64 --transQ=false --transK=false -perf_config attn:v3:32,32,32,64,16,16,16,8,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transQ=false --transK=false -perf_config attn:v3:32,32,32,16,32,32,32,8,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 64 -head_dim_v 64 --transQ=false --transK=false -perf_config attn:v3:64,64,64,16,32,32,32,16,1,3,2,0,1" + +[[suite.test]] +config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transQ=false --transK=false -perf_config attn:v3:32,32,32,32,32,32,32,8,1,3,2,0,1" diff --git a/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg b/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg new file mode 100644 index 000000000000..46909aa10a02 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg @@ -0,0 +1,2 @@ +if not 'lds_transpose_load' in config.features: + config.unsupported = True diff --git a/mlir/test/e2e/PrLdsTransposeLoadFp8.toml b/mlir/test/e2e/PrLdsTransposeLoadFp8.toml new file mode 100644 index 000000000000..41b36b149e6a --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadFp8.toml @@ -0,0 +1,126 @@ +directory = "PrLdsTransposeLoadFp8" +prefix = "rocmlir-gen" +suffix = "--operation gemm --arch %arch %pv %constrained_float_range_random_data %rocmlir_gen_flags | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +[[axis]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8", "fp8_bf8", "bf8_fp8"] +prefix 
= "-t " + +# ============================================================================ +# Suite 1: BOTH A and B use LDS transpose (transA=true, transB=false) +# 5 tests covering different MFMA geometries and kpack/kPerBlock combinations +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_both_operands_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v3:64,64,32,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 64 -n 64 --transA=true --transB=false --perf_config v3:128,64,16,32,16,16,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 256 -k 256 -n 64 --transA=true --transB=false --perf_config v3:256,64,8,64,16,32,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=false --perf_config v3:128,128,16,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 256 -k 128 -n 128 --transA=true --transB=false --perf_config v3:256,128,32,64,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 2: ONLY A uses LDS transpose (transA=true, transB=true) +# 3 tests - B uses regular load +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_A_only_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=true --transB=true --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = 
"-g 1 -m 128 -k 128 -n 64 --transA=true --transB=true --perf_config v3:128,64,8,32,16,16,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 128 --transA=true --transB=true --perf_config v3:128,128,32,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 3: ONLY B uses LDS transpose (transA=false, transB=false) +# 3 tests - A uses regular load +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_B_only_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 64 --transA=false --transB=false --perf_config v3:128,64,32,32,16,1,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 128 --transA=false --transB=false --perf_config v3:128,128,8,32,32,32,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 4: Mixed FP8/BF8 types (fp8_bf8 and bf8_fp8 only) +# 3 tests: both operands, only A, only B +# ============================================================================ +[[suite]] +name = "lds_transpose_mixed_fp8_bf8" + +# Test 1: Mixed types, both operands use LDS transpose, 16x32 MFMA +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v3:64,64,32,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"] + +# Test 2: Mixed types, only A uses LDS transpose, 32x16 MFMA 
+[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=true --perf_config v3:128,128,16,32,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"] + +# Test 3: Mixed types, only B uses LDS transpose, 16x32 MFMA +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8"] diff --git a/mlir/test/e2e/PrLdsTransposeLoadI8.cfg b/mlir/test/e2e/PrLdsTransposeLoadI8.cfg new file mode 100644 index 000000000000..46909aa10a02 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadI8.cfg @@ -0,0 +1,2 @@ +if not 'lds_transpose_load' in config.features: + config.unsupported = True diff --git a/mlir/test/e2e/PrLdsTransposeLoadI8.toml b/mlir/test/e2e/PrLdsTransposeLoadI8.toml new file mode 100644 index 000000000000..19bbc999efd2 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadI8.toml @@ -0,0 +1,86 @@ +directory = "PrLdsTransposeLoadI8" +prefix = "rocmlir-gen" +suffix = "--operation gemm --arch %arch %random_data %pv %rocmlir_gen_flags | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +[[axis]] +name = "data type" +values = ["i8"] +prefix = "-t " + +# ============================================================================ +# Suite 1: BOTH A and B use LDS transpose (transA=true, transB=false) +# 2 tests per MFMA = 8 tests total +# ============================================================================ +[[suite]] +name = "lds_transpose_both_operands_i8" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v4:64,64,32,32,16,16,1,1,4,2,0,0,1,1" + 
+[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 64 --transA=true --transB=false --perf_config v3:128,64,64,16,16,8,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v4:64,64,64,16,16,16,1,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 64 --transA=true --transB=false --perf_config v3:128,32,128,16,16,1,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=true --transB=false --perf_config v4:32,128,16,32,32,32,8,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=false --perf_config v3:256,64,32,32,32,8,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v4:64,64,32,32,32,32,1,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 128 --transA=true --transB=false --perf_config v3:128,128,16,32,32,16,1,3,2,1,1" + +# ============================================================================ +# Suite 2: ONLY A uses LDS transpose (transA=true, transB=true) +# 5 tests total +# ============================================================================ +[[suite]] +name = "lds_transpose_A_only_i8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=true --transB=true --perf_config v4:64,64,32,32,16,16,1,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 64 --transA=true --transB=true --perf_config v3:128,64,64,16,16,8,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=true --perf_config v4:64,64,64,16,16,16,1,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=true --transB=true --perf_config v4:64,64,16,32,32,32,16,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=true --perf_config v3:128,128,32,32,32,8,1,3,2,1,1" + +# ============================================================================ +# Suite 3: ONLY B uses LDS 
transpose (transA=false, transB=false) +# 6 tests total +# ============================================================================ +[[suite]] +name = "lds_transpose_B_only_i8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v4:64,64,32,16,16,16,8,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 64 --transA=false --transB=false --perf_config v3:128,64,64,16,16,8,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=false --transB=false --perf_config v4:64,64,64,16,16,16,1,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=false --transB=false --perf_config v3:128,128,32,32,32,16,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v4:64,64,32,32,32,32,1,1,4,2,0,0,1,1" + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=false --transB=false --perf_config v3:128,128,64,32,32,8,1,3,2,1,1"