From 0a3d5644209da2f33e03c2c9ee8235de37b07c5e Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Fri, 20 Feb 2026 04:25:32 -0600 Subject: [PATCH 1/4] WIP: Scaled FP8 MFMA support --- .../mlir/Dialect/Rock/IR/MfmaInsnGroup.h | 12 +- mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp | 127 +++++++++++++++++- .../lib/Dialect/Rock/utility/AccelEmitter.cpp | 24 ++++ 3 files changed, 161 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h b/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h index b37907d05c8d..5377cb6c2280 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h +++ b/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h @@ -31,7 +31,13 @@ enum class MfmaTypeId : uint32_t { Fp8Fp8TyId, Fp8Bf8TyId, Bf8Fp8TyId, - Bf8Bf8TyId + Bf8Bf8TyId, + // FP8 via scaled MFMA (uses mfma_scale_f32_16x16x128_f8f6f4 with cbsz=0) + // These provide larger K dimension (128 for 16x16, 64 for 32x32) + Fp8Fp8ScaledTyId, + Fp8Bf8ScaledTyId, + Bf8Fp8ScaledTyId, + Bf8Bf8ScaledTyId }; struct MfmaInsnInfo { @@ -155,6 +161,10 @@ class MfmaInsnGroup { bool isCoherentWithK(int64_t kPack, int64_t kPerBlock, int64_t scheduleVersion); SmallString<16> getROCDLIntrinsicName() { return groupAttr.insn; } + + // Check if this is FP8 using scaled MFMA (mfma_scale with cbsz=0, blgp=0) + // These instructions have larger K dimension (128 for 16x16, 64 for 32x32) + bool isScaledFp8() const; }; } // namespace rock diff --git a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp index f02ba3e9012f..aa65d2be6228 100644 --- a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp +++ b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp @@ -116,7 +116,27 @@ static auto getMfmaInsnInfoMap = []() -> const llvm::StringMap & { {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), {MfmaTypeId::Fp4TyId, 16, 128, 1}}, {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), - {MfmaTypeId::Fp4TyId, 32, 64, 1}}}; + {MfmaTypeId::Fp4TyId, 32, 64, 
1}}, + + // fp8 via scaled MFMA (cbsz=0, blgp=0 gives FP8 mode) + // Uses same instructions as FP4 but with different cbsz/blgp + // K dimension is same: 128 for 16x16, 64 for 32x32 + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), + {MfmaTypeId::Fp8Fp8ScaledTyId, 16, 128, 1}}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), + {MfmaTypeId::Fp8Fp8ScaledTyId, 32, 64, 1}}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), + {MfmaTypeId::Fp8Bf8ScaledTyId, 16, 128, 1}}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), + {MfmaTypeId::Fp8Bf8ScaledTyId, 32, 64, 1}}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), + {MfmaTypeId::Bf8Fp8ScaledTyId, 16, 128, 1}}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), + {MfmaTypeId::Bf8Fp8ScaledTyId, 32, 64, 1}}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), + {MfmaTypeId::Bf8Bf8ScaledTyId, 16, 128, 1}}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), + {MfmaTypeId::Bf8Bf8ScaledTyId, 32, 64, 1}}}; return insnInfo; }; @@ -448,6 +468,25 @@ static auto getMfmaInsnGroupAttrMapGfx950 = []() { {{MfmaTypeId::Fp4TyId, 32, 32}, {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + // FP8 via scaled MFMA (cbsz=0, blgp=0 mode) + // 16x16 with K=128, 32x32 with K=64 + {{MfmaTypeId::Fp8Fp8ScaledTyId, 16, 16}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Fp8Fp8ScaledTyId, 32, 32}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Fp8Bf8ScaledTyId, 16, 16}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Fp8Bf8ScaledTyId, 32, 32}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Bf8Fp8ScaledTyId, 16, 16}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Bf8Fp8ScaledTyId, 32, 32}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + 
{{MfmaTypeId::Bf8Bf8ScaledTyId, 16, 16}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Bf8Bf8ScaledTyId, 32, 32}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + // i8 double rate {{MfmaTypeId::I8TyId, 32, 32}, {ROCDL::mfma_i32_32x32x32_i8::getOperationName()}}, @@ -583,6 +622,28 @@ static MfmaTypeId convertTypesToId(Type dataTypeA, Type dataTypeB) { llvm_unreachable("Unsupported input argument type."); } +// Convert native FP8 TypeId to scaled FP8 TypeId for gfx950 scaled MFMA +static std::optional getScaledFp8TypeId(MfmaTypeId nativeTypeId) { + switch (nativeTypeId) { + case MfmaTypeId::Fp8Fp8TyId: + return MfmaTypeId::Fp8Fp8ScaledTyId; + case MfmaTypeId::Fp8Bf8TyId: + return MfmaTypeId::Fp8Bf8ScaledTyId; + case MfmaTypeId::Bf8Fp8TyId: + return MfmaTypeId::Bf8Fp8ScaledTyId; + case MfmaTypeId::Bf8Bf8TyId: + return MfmaTypeId::Bf8Bf8ScaledTyId; + default: + return std::nullopt; + } +} + +// Check if this is a native FP8 type (not scaled) +static bool isNativeFp8TypeId(MfmaTypeId typeId) { + return typeId == MfmaTypeId::Fp8Fp8TyId || typeId == MfmaTypeId::Fp8Bf8TyId || + typeId == MfmaTypeId::Bf8Fp8TyId || typeId == MfmaTypeId::Bf8Bf8TyId; +} + FailureOr MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch, int64_t mnPerXdl, int64_t kPack, int64_t kPackPerBlock, @@ -624,6 +685,43 @@ MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch, }; auto selectForGfx950 = [&]() { + int64_t kPerBlock = kPack * kPackPerBlock; + + // For FP8 types, try scaled MFMA first if kPerBlock is large enough + // Scaled MFMA has K=128 for 16x16 and K=64 for 32x32 + if (isNativeFp8TypeId(key.type)) { + int64_t scaledK = (mPerMfmaGroup == 16) ? 
128 : 64; + if (kPerBlock >= scaledK) { + LLVM_DEBUG(llvm::dbgs() << ">>> Trying scaled FP8: kPerBlock=" + << kPerBlock << " >= scaledK=" << scaledK + << "\n"); + auto scaledTypeId = getScaledFp8TypeId(key.type); + if (scaledTypeId) { + MfmaInsnGroupSelectKey scaledKey = {*scaledTypeId, mPerMfmaGroup, + nPerMfmaGroup}; + const auto &gfx950Map = getMfmaInsnGroupAttrMapGfx950(); + auto it = gfx950Map.find(scaledKey); + if (it != gfx950Map.end()) { + MfmaInsnGroupAttr groupAttr = (*it).second; + auto maybeInsn = MfmaInsn::select(groupAttr.insn); + if (succeeded(maybeInsn)) { + auto scaledResult = + MfmaInsnGroup(elementTypeA, elementTypeB, *maybeInsn, groupAttr); + if (scaledResult.isCoherentWithK(kPack, kPackPerBlock)) { + LLVM_DEBUG(llvm::dbgs() + << ">>> SELECTED SCALED FP8 MFMA: K=" + << maybeInsn->getAttr().k << "\n"); + result = scaledResult; + return; + } + } + } + } + LLVM_DEBUG(llvm::dbgs() + << ">>> Scaled FP8 MFMA not suitable, falling back to native\n"); + } + } + // gfx950 has double rate instructions. Select from those first. 
selectFrom(getMfmaInsnGroupAttrMapGfx950()); if (succeeded(result)) { @@ -714,3 +812,30 @@ bool MfmaInsnGroup::isCoherentWithK(int64_t kpack, int64_t kPerBlock, int64_t scheduleVersion) { return insn.isCoherentWithK(kpack, kPerBlock, scheduleVersion); } + +bool MfmaInsnGroup::isScaledFp8() const { + // Check if the instruction is a scaled MFMA (rocdl.mfma.scale.f32.*x*x*.f8f6f4) + StringRef insnName = groupAttr.insn; + llvm::errs() << "[isScaledFp8] insnName: " << insnName << "\n"; + bool isScaledInsn = insnName.contains("mfma.scale.f32.16x16x128.f8f6f4") || + insnName.contains("mfma.scale.f32.32x32x64.f8f6f4"); + llvm::errs() << "[isScaledFp8] isScaledInsn: " << isScaledInsn << "\n"; + if (!isScaledInsn) + return false; + + // Check if the element type is FP8 (not FP4) + // FP8 types: Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ + // FP4 types: Float4E2M1FN + bool isFp8A = isa(elementTypeA) || + isa(elementTypeA) || + isa(elementTypeA) || + isa(elementTypeA); + bool isFp8B = isa(elementTypeB) || + isa(elementTypeB) || + isa(elementTypeB) || + isa(elementTypeB); + + llvm::errs() << "[isScaledFp8] isFp8A: " << isFp8A << ", isFp8B: " << isFp8B << "\n"; + llvm::errs() << "[isScaledFp8] returning: " << (isFp8A && isFp8B) << "\n"; + return isFp8A && isFp8B; +} diff --git a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp index 258b5369d7bc..8f40a418da9a 100644 --- a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp +++ b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp @@ -191,6 +191,20 @@ void MfmaEmitter::emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, VectorType vectorType = mfmaGroup.getRetType(); auto outputOffset = llvm::to_vector(regCOffset); bool isScaled = scaleA && scaleB; + bool isScaledFp8 = mfmaGroup.isScaledFp8(); + llvm::errs() << "[emitThreadwiseLoop] isScaled: " << isScaled + << ", isScaledFp8: " << isScaledFp8 << "\n"; + + // For scaled FP8 MFMA without explicit scale buffers, create 
neutral scales + // cbsz=0, blgp=0 mode: scale value of 0 means no scaling (2^0 = 1) + Value neutralScaleA, neutralScaleB; + if (isScaledFp8 && !isScaled) { + // Scale type for scaled MFMA is f8E8M0FNU + Type scaleType = b.getType(); + auto zeroAttr = b.getFloatAttr(scaleType, 0.0); + neutralScaleA = arith::ConstantOp::create(b, loc, scaleType, zeroAttr); + neutralScaleB = arith::ConstantOp::create(b, loc, scaleType, zeroAttr); + } for (int64_t i = 0; i < nResultVectors; ++i) { Value offset = b.createOrFold(loc, i); @@ -203,11 +217,21 @@ void MfmaEmitter::emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, Value vectorD; if (isScaled) { + // Explicit scale buffers provided (FP4 or scaled FP8 with explicit scales) auto mfma = amdgpu::ScaledMFMAOp::create( b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k, argA, argB, vectorC, scaleA, scaleB, /*scalesIdxA=*/0, /*scalesIdxB=*/0); vectorD = mfma.getDestD(); + } else if (isScaledFp8) { + // Scaled FP8 MFMA (K=128 for 16x16, K=64 for 32x32) without explicit scales + // Pass the neutral scales built from 0.0 above (intended: no scaling) + auto mfma = amdgpu::ScaledMFMAOp::create( + b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k, argA, argB, + vectorC, neutralScaleA, neutralScaleB, + /*scalesIdxA=*/0, /*scalesIdxB=*/0); + vectorD = mfma.getDestD(); } else { + // Regular MFMA auto mfma = amdgpu::MFMAOp::create( b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k, mfmaAttr.blocksMfma, argA, argB, vectorC, /*cbsz=*/imms[i].cbsz, From 6b5b14bd648093fcc703c450d36e1acddf5e9d0d Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Fri, 20 Feb 2026 07:45:26 -0600 Subject: [PATCH 2/4] Add comprehensive tests for scaled FP8 MFMA instructions (32x32x64 and 16x16x128) on gfx950 architecture.
These tests cover: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Single buffering (scheduleVersion 1, 3) with kpack=32 and kpack=1 - Double buffering (scheduleVersion 2, 4) with kpack=32 - Double buffering with kpack < k_base (kpack=1, 4, 8) - All FP8 type combinations: FP8×FP8, BF8×BF8, FP8×BF8, BF8×FP8 The tests verify that amdgpu.scaled_mfma operations are correctly generated for OCP FP8 types (f8E4M3FN, f8E5M2) with implicit scale factors. --- mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp | 6 +- .../lib/Dialect/Rock/utility/AccelEmitter.cpp | 2 - .../Dialect/Rock/lowering_xdlops_gemm.mlir | 252 ++++++++++++++++++ 3 files changed, 253 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp index aa65d2be6228..7115e2e770b3 100644 --- a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp +++ b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp @@ -707,7 +707,7 @@ MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch, if (succeeded(maybeInsn)) { auto scaledResult = MfmaInsnGroup(elementTypeA, elementTypeB, *maybeInsn, groupAttr); - if (scaledResult.isCoherentWithK(kPack, kPackPerBlock)) { + if (scaledResult.isCoherentWithK(kPack, kPackPerBlock, scheduleVersion)) { LLVM_DEBUG(llvm::dbgs() << ">>> SELECTED SCALED FP8 MFMA: K=" << maybeInsn->getAttr().k << "\n"); @@ -816,10 +816,8 @@ bool MfmaInsnGroup::isCoherentWithK(int64_t kpack, int64_t kPerBlock, bool MfmaInsnGroup::isScaledFp8() const { // Check if the instruction is a scaled MFMA (rocdl.mfma.scale.f32.*x*x*.f8f6f4) StringRef insnName = groupAttr.insn; - llvm::errs() << "[isScaledFp8] insnName: " << insnName << "\n"; bool isScaledInsn = insnName.contains("mfma.scale.f32.16x16x128.f8f6f4") || insnName.contains("mfma.scale.f32.32x32x64.f8f6f4"); - llvm::errs() << "[isScaledFp8] isScaledInsn: " << isScaledInsn << "\n"; if (!isScaledInsn) return false; @@ -835,7 +833,5 @@ bool 
MfmaInsnGroup::isScaledFp8() const { isa(elementTypeB) || isa(elementTypeB); - llvm::errs() << "[isScaledFp8] isFp8A: " << isFp8A << ", isFp8B: " << isFp8B << "\n"; - llvm::errs() << "[isScaledFp8] returning: " << (isFp8A && isFp8B) << "\n"; return isFp8A && isFp8B; } diff --git a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp index 8f40a418da9a..b1d93626d917 100644 --- a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp +++ b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp @@ -192,8 +192,6 @@ void MfmaEmitter::emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, auto outputOffset = llvm::to_vector(regCOffset); bool isScaled = scaleA && scaleB; bool isScaledFp8 = mfmaGroup.isScaledFp8(); - llvm::errs() << "[emitThreadwiseLoop] isScaled: " << isScaled - << ", isScaledFp8: " << isScaledFp8 << "\n"; // For scaled FP8 MFMA without explicit scale buffers, create neutral scales // cbsz=0, blgp=0 mode: scale value of 0 means no scaling (2^0 = 1) diff --git a/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir b/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir index cca87c21f4b8..1a3baeb845c0 100644 --- a/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir +++ b/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir @@ -649,3 +649,255 @@ func.func @accel_gemm_gfx950_f32_16x16x512_fp4_scaled_multi(%matrixA : memref<1x } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf4E2M1FN>, 5> scaled by memref<1x4xvector<32xf8E8M0FNU>, 5> * memref<1x4xvector<32xf4E2M1FN>, 5> scaled by memref<1x4xvector<32xf8E8M0FNU>, 5> return } + +func.func @accel_gemm_gfx950_scaled_fp8_bf8_32x32x64_single_v1(%matrixA : memref<1x2xvector<32xf8E4M3FN>, 5>, + %matrixB : memref<1x2xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<16xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_fp8_bf8_32x32x64_single_v1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 32x32x64 + // CHECK-SAME: 
vector<32xf8E4M3FN>{{.*}}vector<32xf8E5M2>{{.*}}vector<16xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 2, + mPerBlock = 32, + nPerBlock = 32, + kpack = 32, + mPerWave = 32, + nPerWave = 32, + mnPerXdl = 32, + splitKFactor = 1, + scheduleVersion = 1, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<16xf32>, 5> += memref<1x2xvector<32xf8E4M3FN>, 5> * memref<1x2xvector<32xf8E5M2>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_fp8_fp8_16x16x128_single_v1(%matrixA : memref<1x4xvector<32xf8E4M3FN>, 5>, + %matrixB : memref<1x4xvector<32xf8E4M3FN>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_fp8_fp8_16x16x128_single_v1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + // CHECK-SAME: vector<32xf8E4M3FN>{{.*}}vector<32xf8E4M3FN>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 4, + mPerBlock = 16, + nPerBlock = 16, + kpack = 32, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 1, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf8E4M3FN>, 5> * memref<1x4xvector<32xf8E4M3FN>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_single_v3_kpack1(%matrixA : memref<1x1xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x1xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func 
@accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_single_v3_kpack1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + // CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E5M2>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 128, + mPerBlock = 16, + nPerBlock = 16, + kpack = 1, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 3, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x1xvector<32xf8E5M2>, 5> * memref<1x1xvector<32xf8E5M2>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_fp8_32x32x64_single_v3_kpack1(%matrixA : memref<1x1xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x1xvector<32xf8E4M3FN>, 5>, + %matrixC : memref<1x1xvector<16xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_bf8_fp8_32x32x64_single_v3_kpack1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 32x32x64 + // CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E4M3FN>{{.*}}vector<16xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 64, + mPerBlock = 32, + nPerBlock = 32, + kpack = 1, + mPerWave = 32, + nPerWave = 32, + mnPerXdl = 32, + splitKFactor = 1, + scheduleVersion = 3, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<16xf32>, 5> += memref<1x1xvector<32xf8E5M2>, 5> * memref<1x1xvector<32xf8E4M3FN>, 5> + return +} + +func.func 
@accel_gemm_gfx950_scaled_fp8_bf8_32x32x64_double_v2(%matrixA : memref<1x2xvector<32xf8E4M3FN>, 5>, + %matrixB : memref<1x2xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<16xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_fp8_bf8_32x32x64_double_v2 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 32x32x64 + // CHECK-SAME: vector<32xf8E4M3FN>{{.*}}vector<32xf8E5M2>{{.*}}vector<16xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 2, + mPerBlock = 32, + nPerBlock = 32, + kpack = 32, + mPerWave = 32, + nPerWave = 32, + mnPerXdl = 32, + splitKFactor = 1, + scheduleVersion = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<16xf32>, 5> += memref<1x2xvector<32xf8E4M3FN>, 5> * memref<1x2xvector<32xf8E5M2>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_double_v4(%matrixA : memref<1x4xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x4xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_double_v4 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + // CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E5M2>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 4, + mPerBlock = 16, + nPerBlock = 16, + kpack = 32, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 4, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + 
forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf8E5M2>, 5> * memref<1x4xvector<32xf8E5M2>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_bf8_32x32x64_double_v2_kpack8(%matrixA : memref<1x2xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x2xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<16xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_bf8_bf8_32x32x64_double_v2_kpack8 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 32x32x64 + // CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E5M2>{{.*}}vector<16xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 8, + mPerBlock = 32, + nPerBlock = 32, + kpack = 8, + mPerWave = 32, + nPerWave = 32, + mnPerXdl = 32, + splitKFactor = 1, + scheduleVersion = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<16xf32>, 5> += memref<1x2xvector<32xf8E5M2>, 5> * memref<1x2xvector<32xf8E5M2>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_fp8_fp8_16x16x128_double_v4_kpack4(%matrixA : memref<1x4xvector<32xf8E4M3FN>, 5>, + %matrixB : memref<1x4xvector<32xf8E4M3FN>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_fp8_fp8_16x16x128_double_v4_kpack4 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + // CHECK-SAME: vector<32xf8E4M3FN>{{.*}}vector<32xf8E4M3FN>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 32, + mPerBlock = 16, + 
nPerBlock = 16, + kpack = 4, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 4, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf8E4M3FN>, 5> * memref<1x4xvector<32xf8E4M3FN>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_double_v2_kpack1(%matrixA : memref<1x4xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x4xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_double_v2_kpack1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + // CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E5M2>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 128, + mPerBlock = 16, + nPerBlock = 16, + kpack = 1, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf8E5M2>, 5> * memref<1x4xvector<32xf8E5M2>, 5> + return +} From d9b4eac0bbf274f19ee5ef70d5ab38343ba1597c Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Mon, 23 Feb 2026 05:58:28 -0600 Subject: [PATCH 3/4] Clean up scaled FP8 MFMA code based on review feedback - Remove duplicate entries in getMfmaInsnInfoMap - Clarify neutral scale creation comment in AccelEmitter.cpp - Rename zeroAttr to neutralScaleAttr for clarity --- mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp | 29 +++++-------------- .../lib/Dialect/Rock/utility/AccelEmitter.cpp | 16 ++++++---- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git 
a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp index 7115e2e770b3..1676bcdab80c 100644 --- a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp +++ b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp @@ -112,31 +112,16 @@ static auto getMfmaInsnInfoMap = []() -> const llvm::StringMap & { {ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(), {MfmaTypeId::Bf8Bf8TyId, 16, 32, 1}}, - // fp4 + // Scaled MFMA instructions (FP4 and scaled FP8 types) + // Note: FP8 scaled types (Fp8Fp8ScaledTyId, Fp8Bf8ScaledTyId, etc.) + // use the same underlying instruction with identical (mfmaDDim, k, blocksMfma). + // Since deriveAttr only uses those fields (not MfmaTypeId), we only need + // one entry per instruction. The type differentiation happens elsewhere + // via cbsz/blgp parameters at code generation time. {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), {MfmaTypeId::Fp4TyId, 16, 128, 1}}, {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), - {MfmaTypeId::Fp4TyId, 32, 64, 1}}, - - // fp8 via scaled MFMA (cbsz=0, blgp=0 gives FP8 mode) - // Uses same instructions as FP4 but with different cbsz/blgp - // K dimension is same: 128 for 16x16, 64 for 32x32 - {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), - {MfmaTypeId::Fp8Fp8ScaledTyId, 16, 128, 1}}, - {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), - {MfmaTypeId::Fp8Fp8ScaledTyId, 32, 64, 1}}, - {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), - {MfmaTypeId::Fp8Bf8ScaledTyId, 16, 128, 1}}, - {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), - {MfmaTypeId::Fp8Bf8ScaledTyId, 32, 64, 1}}, - {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), - {MfmaTypeId::Bf8Fp8ScaledTyId, 16, 128, 1}}, - {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), - {MfmaTypeId::Bf8Fp8ScaledTyId, 32, 64, 1}}, - {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), - {MfmaTypeId::Bf8Bf8ScaledTyId, 16, 128, 1}}, - 
{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), - {MfmaTypeId::Bf8Bf8ScaledTyId, 32, 64, 1}}}; + {MfmaTypeId::Fp4TyId, 32, 64, 1}}}; return insnInfo; }; diff --git a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp index b1d93626d917..2621fe5e048f 100644 --- a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp +++ b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp @@ -193,15 +193,19 @@ void MfmaEmitter::emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, bool isScaled = scaleA && scaleB; bool isScaledFp8 = mfmaGroup.isScaledFp8(); - // For scaled FP8 MFMA without explicit scale buffers, create neutral scales - // cbsz=0, blgp=0 mode: scale value of 0 means no scaling (2^0 = 1) + // For scaled FP8 MFMA without explicit scale buffers, create neutral scales. + // The constant is built with getFloatAttr(scaleType, 0.0). NOTE(review): + // Float8E8M0FNU is exponent-only with bias 127 and has no zero encoding, so + // a scale of 1 (2^0) is byte 0x7F while all-zero bits denote 2^-127; confirm + // the emitted encoding is what the scaled MFMA lowering treats as unscaled.
Value neutralScaleA, neutralScaleB; if (isScaledFp8 && !isScaled) { - // Scale type for scaled MFMA is f8E8M0FNU Type scaleType = b.getType(); - auto zeroAttr = b.getFloatAttr(scaleType, 0.0); - neutralScaleA = arith::ConstantOp::create(b, loc, scaleType, zeroAttr); - neutralScaleB = arith::ConstantOp::create(b, loc, scaleType, zeroAttr); + auto neutralScaleAttr = b.getFloatAttr(scaleType, 0.0); + neutralScaleA = + arith::ConstantOp::create(b, loc, scaleType, neutralScaleAttr); + neutralScaleB = + arith::ConstantOp::create(b, loc, scaleType, neutralScaleAttr); } for (int64_t i = 0; i < nResultVectors; ++i) { From ff1d1c93d77f86b1a6b6d90fe579fedea41ba44b Mon Sep 17 00:00:00 2001 From: stefan koncarevic Date: Mon, 23 Feb 2026 06:49:03 -0600 Subject: [PATCH 4/4] Clang format --- mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp | 29 ++++++++++++---------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp index 1676bcdab80c..19ebb4296bfe 100644 --- a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp +++ b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp @@ -114,7 +114,8 @@ static auto getMfmaInsnInfoMap = []() -> const llvm::StringMap & { // Scaled MFMA instructions (FP4 and scaled FP8 types) // Note: FP8 scaled types (Fp8Fp8ScaledTyId, Fp8Bf8ScaledTyId, etc.) - // use the same underlying instruction with identical (mfmaDDim, k, blocksMfma). + // use the same underlying instruction with identical (mfmaDDim, k, + // blocksMfma). // Since deriveAttr only uses those fields (not MfmaTypeId), we only need // one entry per instruction. The type differentiation happens elsewhere // via cbsz/blgp parameters at code generation time. @@ -677,9 +678,9 @@ MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch, if (isNativeFp8TypeId(key.type)) { int64_t scaledK = (mPerMfmaGroup == 16) ? 
128 : 64; if (kPerBlock >= scaledK) { - LLVM_DEBUG(llvm::dbgs() << ">>> Trying scaled FP8: kPerBlock=" - << kPerBlock << " >= scaledK=" << scaledK - << "\n"); + LLVM_DEBUG(llvm::dbgs() + << ">>> Trying scaled FP8: kPerBlock=" << kPerBlock + << " >= scaledK=" << scaledK << "\n"); auto scaledTypeId = getScaledFp8TypeId(key.type); if (scaledTypeId) { MfmaInsnGroupSelectKey scaledKey = {*scaledTypeId, mPerMfmaGroup, @@ -690,20 +691,21 @@ MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch, MfmaInsnGroupAttr groupAttr = (*it).second; auto maybeInsn = MfmaInsn::select(groupAttr.insn); if (succeeded(maybeInsn)) { - auto scaledResult = - MfmaInsnGroup(elementTypeA, elementTypeB, *maybeInsn, groupAttr); - if (scaledResult.isCoherentWithK(kPack, kPackPerBlock, scheduleVersion)) { - LLVM_DEBUG(llvm::dbgs() - << ">>> SELECTED SCALED FP8 MFMA: K=" - << maybeInsn->getAttr().k << "\n"); + auto scaledResult = MfmaInsnGroup(elementTypeA, elementTypeB, + *maybeInsn, groupAttr); + if (scaledResult.isCoherentWithK(kPack, kPackPerBlock, + scheduleVersion)) { + LLVM_DEBUG(llvm::dbgs() << ">>> SELECTED SCALED FP8 MFMA: K=" + << maybeInsn->getAttr().k << "\n"); result = scaledResult; return; } } } } - LLVM_DEBUG(llvm::dbgs() - << ">>> Scaled FP8 MFMA not suitable, falling back to native\n"); + LLVM_DEBUG( + llvm::dbgs() + << ">>> Scaled FP8 MFMA not suitable, falling back to native\n"); } } @@ -799,7 +801,8 @@ bool MfmaInsnGroup::isCoherentWithK(int64_t kpack, int64_t kPerBlock, } bool MfmaInsnGroup::isScaledFp8() const { - // Check if the instruction is a scaled MFMA (rocdl.mfma.scale.f32.*x*x*.f8f6f4) + // Check if the instruction is a scaled MFMA + // (rocdl.mfma.scale.f32.*x*x*.f8f6f4) StringRef insnName = groupAttr.insn; bool isScaledInsn = insnName.contains("mfma.scale.f32.16x16x128.f8f6f4") || insnName.contains("mfma.scale.f32.32x32x64.f8f6f4");