diff --git a/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h b/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h index b37907d05c8d..5377cb6c2280 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h +++ b/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h @@ -31,7 +31,13 @@ enum class MfmaTypeId : uint32_t { Fp8Fp8TyId, Fp8Bf8TyId, Bf8Fp8TyId, - Bf8Bf8TyId + Bf8Bf8TyId, + // FP8 via scaled MFMA (uses mfma_scale_f32_16x16x128_f8f6f4 with cbsz=0) + // These provide larger K dimension (128 for 16x16, 64 for 32x32) + Fp8Fp8ScaledTyId, + Fp8Bf8ScaledTyId, + Bf8Fp8ScaledTyId, + Bf8Bf8ScaledTyId }; struct MfmaInsnInfo { @@ -155,6 +161,10 @@ class MfmaInsnGroup { bool isCoherentWithK(int64_t kPack, int64_t kPerBlock, int64_t scheduleVersion); SmallString<16> getROCDLIntrinsicName() { return groupAttr.insn; } + + // Check if this is FP8 using scaled MFMA (mfma_scale with cbsz=0, blgp=0) + // These instructions have larger K dimension (128 for 16x16, 64 for 32x32) + bool isScaledFp8() const; }; } // namespace rock diff --git a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp index f02ba3e9012f..19ebb4296bfe 100644 --- a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp +++ b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp @@ -112,7 +112,13 @@ static auto getMfmaInsnInfoMap = []() -> const llvm::StringMap<MfmaInsnInfo> & { {ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(), {MfmaTypeId::Bf8Bf8TyId, 16, 32, 1}}, - // fp4 + // Scaled MFMA instructions (FP4 and scaled FP8 types) + // Note: FP8 scaled types (Fp8Fp8ScaledTyId, Fp8Bf8ScaledTyId, etc.) + // use the same underlying instruction with identical (mfmaDDim, k, + // blocksMfma). + // Since deriveAttr only uses those fields (not MfmaTypeId), we only need + // one entry per instruction. The type differentiation happens elsewhere + // via cbsz/blgp parameters at code generation time. 
{ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), {MfmaTypeId::Fp4TyId, 16, 128, 1}}, {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), @@ -448,6 +454,25 @@ static auto getMfmaInsnGroupAttrMapGfx950 = []() { {{MfmaTypeId::Fp4TyId, 32, 32}, {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + // FP8 via scaled MFMA (cbsz=0, blgp=0 mode) + // 16x16 with K=128, 32x32 with K=64 + {{MfmaTypeId::Fp8Fp8ScaledTyId, 16, 16}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Fp8Fp8ScaledTyId, 32, 32}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Fp8Bf8ScaledTyId, 16, 16}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Fp8Bf8ScaledTyId, 32, 32}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Bf8Fp8ScaledTyId, 16, 16}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Bf8Fp8ScaledTyId, 32, 32}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Bf8Bf8ScaledTyId, 16, 16}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Bf8Bf8ScaledTyId, 32, 32}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + // i8 double rate {{MfmaTypeId::I8TyId, 32, 32}, {ROCDL::mfma_i32_32x32x32_i8::getOperationName()}}, @@ -583,6 +608,28 @@ static MfmaTypeId convertTypesToId(Type dataTypeA, Type dataTypeB) { llvm_unreachable("Unsupported input argument type."); } +// Convert native FP8 TypeId to scaled FP8 TypeId for gfx950 scaled MFMA +static std::optional<MfmaTypeId> getScaledFp8TypeId(MfmaTypeId nativeTypeId) { + switch (nativeTypeId) { + case MfmaTypeId::Fp8Fp8TyId: + return MfmaTypeId::Fp8Fp8ScaledTyId; + case MfmaTypeId::Fp8Bf8TyId: + return MfmaTypeId::Fp8Bf8ScaledTyId; + case MfmaTypeId::Bf8Fp8TyId: + return MfmaTypeId::Bf8Fp8ScaledTyId; + case MfmaTypeId::Bf8Bf8TyId: + return MfmaTypeId::Bf8Bf8ScaledTyId; + default: + return std::nullopt; 
+ } +} + +// Check if this is a native FP8 type (not scaled) +static bool isNativeFp8TypeId(MfmaTypeId typeId) { + return typeId == MfmaTypeId::Fp8Fp8TyId || typeId == MfmaTypeId::Fp8Bf8TyId || + typeId == MfmaTypeId::Bf8Fp8TyId || typeId == MfmaTypeId::Bf8Bf8TyId; +} + FailureOr<MfmaInsnGroup> MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch, int64_t mnPerXdl, int64_t kPack, int64_t kPackPerBlock, @@ -624,6 +671,44 @@ MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch, }; auto selectForGfx950 = [&]() { + int64_t kPerBlock = kPack * kPackPerBlock; + + // For FP8 types, try scaled MFMA first if kPerBlock is large enough + // Scaled MFMA has K=128 for 16x16 and K=64 for 32x32 + if (isNativeFp8TypeId(key.type)) { + int64_t scaledK = (mPerMfmaGroup == 16) ? 128 : 64; + if (kPerBlock >= scaledK) { + LLVM_DEBUG(llvm::dbgs() + << ">>> Trying scaled FP8: kPerBlock=" << kPerBlock + << " >= scaledK=" << scaledK << "\n"); + auto scaledTypeId = getScaledFp8TypeId(key.type); + if (scaledTypeId) { + MfmaInsnGroupSelectKey scaledKey = {*scaledTypeId, mPerMfmaGroup, + nPerMfmaGroup}; + const auto &gfx950Map = getMfmaInsnGroupAttrMapGfx950(); + auto it = gfx950Map.find(scaledKey); + if (it != gfx950Map.end()) { + MfmaInsnGroupAttr groupAttr = (*it).second; + auto maybeInsn = MfmaInsn::select(groupAttr.insn); + if (succeeded(maybeInsn)) { + auto scaledResult = MfmaInsnGroup(elementTypeA, elementTypeB, + *maybeInsn, groupAttr); + if (scaledResult.isCoherentWithK(kPack, kPackPerBlock, + scheduleVersion)) { + LLVM_DEBUG(llvm::dbgs() << ">>> SELECTED SCALED FP8 MFMA: K=" + << maybeInsn->getAttr().k << "\n"); + result = scaledResult; + return; + } + } + } + } + LLVM_DEBUG( + llvm::dbgs() + << ">>> Scaled FP8 MFMA not suitable, falling back to native\n"); + } + } + + // gfx950 has double rate instructions. Select from those first. 
selectFrom(getMfmaInsnGroupAttrMapGfx950()); if (succeeded(result)) { @@ -714,3 +799,27 @@ bool MfmaInsnGroup::isCoherentWithK(int64_t kpack, int64_t kPerBlock, int64_t scheduleVersion) { return insn.isCoherentWithK(kpack, kPerBlock, scheduleVersion); } + +bool MfmaInsnGroup::isScaledFp8() const { + // Check if the instruction is a scaled MFMA + // (rocdl.mfma.scale.f32.*x*x*.f8f6f4) + StringRef insnName = groupAttr.insn; + bool isScaledInsn = insnName.contains("mfma.scale.f32.16x16x128.f8f6f4") || + insnName.contains("mfma.scale.f32.32x32x64.f8f6f4"); + if (!isScaledInsn) + return false; + + // Check if the element type is FP8 (not FP4) + // FP8 types: Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ + // FP4 types: Float4E2M1FN + bool isFp8A = isa<Float8E4M3FNType>(elementTypeA) || + isa<Float8E4M3FNUZType>(elementTypeA) || + isa<Float8E5M2Type>(elementTypeA) || + isa<Float8E5M2FNUZType>(elementTypeA); + bool isFp8B = isa<Float8E4M3FNType>(elementTypeB) || + isa<Float8E4M3FNUZType>(elementTypeB) || + isa<Float8E5M2Type>(elementTypeB) || + isa<Float8E5M2FNUZType>(elementTypeB); + + return isFp8A && isFp8B; +} diff --git a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp index 258b5369d7bc..2621fe5e048f 100644 --- a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp +++ b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp @@ -191,6 +191,22 @@ void MfmaEmitter::emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, VectorType vectorType = mfmaGroup.getRetType(); auto outputOffset = llvm::to_vector(regCOffset); bool isScaled = scaleA && scaleB; + bool isScaledFp8 = mfmaGroup.isScaledFp8(); + + // For scaled FP8 MFMA without explicit scale buffers, create neutral scales. + // In cbsz=0, blgp=0 mode, a scale exponent value of 0 means no scaling + // because 2^0 = 1. For Float8E8M0FNU (an exponent-only format), the call + // getFloatAttr(scaleType, 0.0) is used to produce the encoding with + // exponent = 0 (all-zero bit pattern), which corresponds to a scale of 1. 
+ Value neutralScaleA, neutralScaleB; + if (isScaledFp8 && !isScaled) { + Type scaleType = b.getType<Float8E8M0FNUType>(); + auto neutralScaleAttr = b.getFloatAttr(scaleType, 0.0); + neutralScaleA = + arith::ConstantOp::create(b, loc, scaleType, neutralScaleAttr); + neutralScaleB = + arith::ConstantOp::create(b, loc, scaleType, neutralScaleAttr); + } for (int64_t i = 0; i < nResultVectors; ++i) { Value offset = b.createOrFold<arith::ConstantIndexOp>(loc, i); @@ -203,11 +219,21 @@ Value vectorD; if (isScaled) { + // Explicit scale buffers provided (FP4 or scaled FP8 with explicit scales) auto mfma = amdgpu::ScaledMFMAOp::create( b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k, argA, argB, vectorC, scaleA, scaleB, /*scalesIdxA=*/0, /*scalesIdxB=*/0); vectorD = mfma.getDestD(); + } else if (isScaledFp8) { + // Scaled FP8 MFMA (K=128 for 16x16, K=64 for 32x32) without explicit scales + // Use neutral scale values (0) which means 2^0 = 1 (no scaling) + auto mfma = amdgpu::ScaledMFMAOp::create( + b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k, argA, argB, + vectorC, neutralScaleA, neutralScaleB, + /*scalesIdxA=*/0, /*scalesIdxB=*/0); + vectorD = mfma.getDestD(); } else { + // Regular MFMA auto mfma = amdgpu::MFMAOp::create( b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k, mfmaAttr.blocksMfma, argA, argB, vectorC, /*cbsz=*/imms[i].cbsz, diff --git a/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir b/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir index cca87c21f4b8..1a3baeb845c0 100644 --- a/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir +++ b/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir @@ -649,3 +649,255 @@ func.func @accel_gemm_gfx950_f32_16x16x512_fp4_scaled_multi(%matrixA : memref<1x } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf4E2M1FN>, 5> scaled by memref<1x4xvector<32xf8E8M0FNU>, 5> * memref<1x4xvector<32xf4E2M1FN>, 5> scaled by memref<1x4xvector<32xf8E8M0FNU>, 5> return } + +func.func 
@accel_gemm_gfx950_scaled_fp8_bf8_32x32x64_single_v1(%matrixA : memref<1x2xvector<32xf8E4M3FN>, 5>, + %matrixB : memref<1x2xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<16xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_fp8_bf8_32x32x64_single_v1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 32x32x64 + // CHECK-SAME: vector<32xf8E4M3FN>{{.*}}vector<32xf8E5M2>{{.*}}vector<16xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 2, + mPerBlock = 32, + nPerBlock = 32, + kpack = 32, + mPerWave = 32, + nPerWave = 32, + mnPerXdl = 32, + splitKFactor = 1, + scheduleVersion = 1, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<16xf32>, 5> += memref<1x2xvector<32xf8E4M3FN>, 5> * memref<1x2xvector<32xf8E5M2>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_fp8_fp8_16x16x128_single_v1(%matrixA : memref<1x4xvector<32xf8E4M3FN>, 5>, + %matrixB : memref<1x4xvector<32xf8E4M3FN>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_fp8_fp8_16x16x128_single_v1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + // CHECK-SAME: vector<32xf8E4M3FN>{{.*}}vector<32xf8E4M3FN>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 4, + mPerBlock = 16, + nPerBlock = 16, + kpack = 32, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 1, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize 
= 0, + forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf8E4M3FN>, 5> * memref<1x4xvector<32xf8E4M3FN>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_single_v3_kpack1(%matrixA : memref<1x1xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x1xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_single_v3_kpack1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + // CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E5M2>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 128, + mPerBlock = 16, + nPerBlock = 16, + kpack = 1, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 3, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x1xvector<32xf8E5M2>, 5> * memref<1x1xvector<32xf8E5M2>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_fp8_32x32x64_single_v3_kpack1(%matrixA : memref<1x1xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x1xvector<32xf8E4M3FN>, 5>, + %matrixC : memref<1x1xvector<16xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_bf8_fp8_32x32x64_single_v3_kpack1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 32x32x64 + // CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E4M3FN>{{.*}}vector<16xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 64, + mPerBlock = 
32, + nPerBlock = 32, + kpack = 1, + mPerWave = 32, + nPerWave = 32, + mnPerXdl = 32, + splitKFactor = 1, + scheduleVersion = 3, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<16xf32>, 5> += memref<1x1xvector<32xf8E5M2>, 5> * memref<1x1xvector<32xf8E4M3FN>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_fp8_bf8_32x32x64_double_v2(%matrixA : memref<1x2xvector<32xf8E4M3FN>, 5>, + %matrixB : memref<1x2xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<16xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_fp8_bf8_32x32x64_double_v2 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 32x32x64 + // CHECK-SAME: vector<32xf8E4M3FN>{{.*}}vector<32xf8E5M2>{{.*}}vector<16xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 2, + mPerBlock = 32, + nPerBlock = 32, + kpack = 32, + mPerWave = 32, + nPerWave = 32, + mnPerXdl = 32, + splitKFactor = 1, + scheduleVersion = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<16xf32>, 5> += memref<1x2xvector<32xf8E4M3FN>, 5> * memref<1x2xvector<32xf8E5M2>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_double_v4(%matrixA : memref<1x4xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x4xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_double_v4 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + // CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E5M2>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * 
%matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 4, + mPerBlock = 16, + nPerBlock = 16, + kpack = 32, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 4, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf8E5M2>, 5> * memref<1x4xvector<32xf8E5M2>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_bf8_32x32x64_double_v2_kpack8(%matrixA : memref<1x2xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x2xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<16xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_bf8_bf8_32x32x64_double_v2_kpack8 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 32x32x64 + // CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E5M2>{{.*}}vector<16xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 8, + mPerBlock = 32, + nPerBlock = 32, + kpack = 8, + mPerWave = 32, + nPerWave = 32, + mnPerXdl = 32, + splitKFactor = 1, + scheduleVersion = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<16xf32>, 5> += memref<1x2xvector<32xf8E5M2>, 5> * memref<1x2xvector<32xf8E5M2>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_fp8_fp8_16x16x128_double_v4_kpack4(%matrixA : memref<1x4xvector<32xf8E4M3FN>, 5>, + %matrixB : memref<1x4xvector<32xf8E4M3FN>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_fp8_fp8_16x16x128_double_v4_kpack4 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + // CHECK-SAME: 
vector<32xf8E4M3FN>{{.*}}vector<32xf8E4M3FN>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 32, + mPerBlock = 16, + nPerBlock = 16, + kpack = 4, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 4, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf8E4M3FN>, 5> * memref<1x4xvector<32xf8E4M3FN>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_double_v2_kpack1(%matrixA : memref<1x4xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x4xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_double_v2_kpack1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + // CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E5M2>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 128, + mPerBlock = 16, + nPerBlock = 16, + kpack = 1, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf8E5M2>, 5> * memref<1x4xvector<32xf8E5M2>, 5> + return +}