Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ enum class MfmaTypeId : uint32_t {
Fp8Fp8TyId,
Fp8Bf8TyId,
Bf8Fp8TyId,
Bf8Bf8TyId
Bf8Bf8TyId,
// FP8 via scaled MFMA (uses mfma_scale_f32_16x16x128_f8f6f4 with cbsz=0)
// These provide larger K dimension (128 for 16x16, 64 for 32x32)
Fp8Fp8ScaledTyId,
Fp8Bf8ScaledTyId,
Bf8Fp8ScaledTyId,
Bf8Bf8ScaledTyId
};

struct MfmaInsnInfo {
Expand Down Expand Up @@ -155,6 +161,10 @@ class MfmaInsnGroup {
bool isCoherentWithK(int64_t kPack, int64_t kPerBlock,
int64_t scheduleVersion);
SmallString<16> getROCDLIntrinsicName() { return groupAttr.insn; }

// Check if this is FP8 using scaled MFMA (mfma_scale with cbsz=0, blgp=0)
// These instructions have larger K dimension (128 for 16x16, 64 for 32x32)
bool isScaledFp8() const;
};

} // namespace rock
Expand Down
111 changes: 110 additions & 1 deletion mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,13 @@ static auto getMfmaInsnInfoMap = []() -> const llvm::StringMap<MfmaInsnInfo> & {
{ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(),
{MfmaTypeId::Bf8Bf8TyId, 16, 32, 1}},

// fp4
// Scaled MFMA instructions (FP4 and scaled FP8 types)
// Note: FP8 scaled types (Fp8Fp8ScaledTyId, Fp8Bf8ScaledTyId, etc.)
// use the same underlying instruction with identical (mfmaDDim, k,
// blocksMfma).
// Since deriveAttr only uses those fields (not MfmaTypeId), we only need
// one entry per instruction. The type differentiation happens elsewhere
// via cbsz/blgp parameters at code generation time.
{ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(),
{MfmaTypeId::Fp4TyId, 16, 128, 1}},
{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
Expand Down Expand Up @@ -448,6 +454,25 @@ static auto getMfmaInsnGroupAttrMapGfx950 = []() {
{{MfmaTypeId::Fp4TyId, 32, 32},
{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}},

// FP8 via scaled MFMA (cbsz=0, blgp=0 mode)
// 16x16 with K=128, 32x32 with K=64
{{MfmaTypeId::Fp8Fp8ScaledTyId, 16, 16},
{ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}},
{{MfmaTypeId::Fp8Fp8ScaledTyId, 32, 32},
{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}},
{{MfmaTypeId::Fp8Bf8ScaledTyId, 16, 16},
{ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}},
{{MfmaTypeId::Fp8Bf8ScaledTyId, 32, 32},
{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}},
{{MfmaTypeId::Bf8Fp8ScaledTyId, 16, 16},
{ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}},
{{MfmaTypeId::Bf8Fp8ScaledTyId, 32, 32},
{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}},
{{MfmaTypeId::Bf8Bf8ScaledTyId, 16, 16},
{ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}},
{{MfmaTypeId::Bf8Bf8ScaledTyId, 32, 32},
{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}},

// i8 double rate
{{MfmaTypeId::I8TyId, 32, 32},
{ROCDL::mfma_i32_32x32x32_i8::getOperationName()}},
Expand Down Expand Up @@ -583,6 +608,28 @@ static MfmaTypeId convertTypesToId(Type dataTypeA, Type dataTypeB) {
llvm_unreachable("Unsupported input argument type.");
}

// Convert native FP8 TypeId to scaled FP8 TypeId for gfx950 scaled MFMA
static std::optional<MfmaTypeId> getScaledFp8TypeId(MfmaTypeId nativeTypeId) {
switch (nativeTypeId) {
case MfmaTypeId::Fp8Fp8TyId:
return MfmaTypeId::Fp8Fp8ScaledTyId;
case MfmaTypeId::Fp8Bf8TyId:
return MfmaTypeId::Fp8Bf8ScaledTyId;
case MfmaTypeId::Bf8Fp8TyId:
return MfmaTypeId::Bf8Fp8ScaledTyId;
case MfmaTypeId::Bf8Bf8TyId:
return MfmaTypeId::Bf8Bf8ScaledTyId;
default:
return std::nullopt;
}
}

// Check if this is a native FP8 type (not scaled)
// True iff typeId denotes one of the four native (non-scaled) FP8
// input-type combinations.
static bool isNativeFp8TypeId(MfmaTypeId typeId) {
  switch (typeId) {
  case MfmaTypeId::Fp8Fp8TyId:
  case MfmaTypeId::Fp8Bf8TyId:
  case MfmaTypeId::Bf8Fp8TyId:
  case MfmaTypeId::Bf8Bf8TyId:
    return true;
  default:
    return false;
  }
}

FailureOr<MfmaInsnGroup>
MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch,
int64_t mnPerXdl, int64_t kPack, int64_t kPackPerBlock,
Expand Down Expand Up @@ -624,6 +671,44 @@ MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch,
};

auto selectForGfx950 = [&]() {
int64_t kPerBlock = kPack * kPackPerBlock;

// For FP8 types, try scaled MFMA first if kPerBlock is large enough
// Scaled MFMA has K=128 for 16x16 and K=64 for 32x32
if (isNativeFp8TypeId(key.type)) {
int64_t scaledK = (mPerMfmaGroup == 16) ? 128 : 64;
if (kPerBlock >= scaledK) {
LLVM_DEBUG(llvm::dbgs()
<< ">>> Trying scaled FP8: kPerBlock=" << kPerBlock
<< " >= scaledK=" << scaledK << "\n");
auto scaledTypeId = getScaledFp8TypeId(key.type);
if (scaledTypeId) {
MfmaInsnGroupSelectKey scaledKey = {*scaledTypeId, mPerMfmaGroup,
nPerMfmaGroup};
const auto &gfx950Map = getMfmaInsnGroupAttrMapGfx950();
auto it = gfx950Map.find(scaledKey);
if (it != gfx950Map.end()) {
MfmaInsnGroupAttr groupAttr = (*it).second;
auto maybeInsn = MfmaInsn::select(groupAttr.insn);
if (succeeded(maybeInsn)) {
auto scaledResult = MfmaInsnGroup(elementTypeA, elementTypeB,
*maybeInsn, groupAttr);
if (scaledResult.isCoherentWithK(kPack, kPackPerBlock,
scheduleVersion)) {
LLVM_DEBUG(llvm::dbgs() << ">>> SELECTED SCALED FP8 MFMA: K="
<< maybeInsn->getAttr().k << "\n");
result = scaledResult;
return;
}
}
}
}
LLVM_DEBUG(
llvm::dbgs()
<< ">>> Scaled FP8 MFMA not suitable, falling back to native\n");
}
}

// gfx950 has double rate instructions. Select from those first.
selectFrom(getMfmaInsnGroupAttrMapGfx950());
if (succeeded(result)) {
Expand Down Expand Up @@ -714,3 +799,27 @@ bool MfmaInsnGroup::isCoherentWithK(int64_t kpack, int64_t kPerBlock,
int64_t scheduleVersion) {
return insn.isCoherentWithK(kpack, kPerBlock, scheduleVersion);
}

bool MfmaInsnGroup::isScaledFp8() const {
// Check if the instruction is a scaled MFMA
// (rocdl.mfma.scale.f32.*x*x*.f8f6f4)
StringRef insnName = groupAttr.insn;
bool isScaledInsn = insnName.contains("mfma.scale.f32.16x16x128.f8f6f4") ||
insnName.contains("mfma.scale.f32.32x32x64.f8f6f4");
if (!isScaledInsn)
return false;

// Check if the element type is FP8 (not FP4)
// FP8 types: Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
// FP4 types: Float4E2M1FN
bool isFp8A = isa<Float8E4M3FNType>(elementTypeA) ||
isa<Float8E4M3FNUZType>(elementTypeA) ||
isa<Float8E5M2Type>(elementTypeA) ||
isa<Float8E5M2FNUZType>(elementTypeA);
bool isFp8B = isa<Float8E4M3FNType>(elementTypeB) ||
isa<Float8E4M3FNUZType>(elementTypeB) ||
isa<Float8E5M2Type>(elementTypeB) ||
isa<Float8E5M2FNUZType>(elementTypeB);

return isFp8A && isFp8B;
}
26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,22 @@ void MfmaEmitter::emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA,
VectorType vectorType = mfmaGroup.getRetType();
auto outputOffset = llvm::to_vector(regCOffset);
bool isScaled = scaleA && scaleB;
bool isScaledFp8 = mfmaGroup.isScaledFp8();

// For scaled FP8 MFMA without explicit scale buffers, create neutral scales.
// In cbsz=0, blgp=0 mode, a scale exponent value of 0 means no scaling
// because 2^0 = 1. For Float8E8M0FNU (an exponent-only format), the call
// getFloatAttr(scaleType, 0.0) is used to produce the encoding with
// exponent = 0 (all-zero bit pattern), which corresponds to a scale of 1.
Value neutralScaleA, neutralScaleB;
if (isScaledFp8 && !isScaled) {
Type scaleType = b.getType<Float8E8M0FNUType>();
auto neutralScaleAttr = b.getFloatAttr(scaleType, 0.0);
neutralScaleA =
arith::ConstantOp::create(b, loc, scaleType, neutralScaleAttr);
neutralScaleB =
arith::ConstantOp::create(b, loc, scaleType, neutralScaleAttr);
}

for (int64_t i = 0; i < nResultVectors; ++i) {
Value offset = b.createOrFold<arith::ConstantIndexOp>(loc, i);
Expand All @@ -203,11 +219,21 @@ void MfmaEmitter::emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA,

Value vectorD;
if (isScaled) {
// Explicit scale buffers provided (FP4 or scaled FP8 with explicit scales)
auto mfma = amdgpu::ScaledMFMAOp::create(
b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k, argA, argB,
vectorC, scaleA, scaleB, /*scalesIdxA=*/0, /*scalesIdxB=*/0);
vectorD = mfma.getDestD();
} else if (isScaledFp8) {
// Scaled FP8 MFMA (K=128 for 16x16, K=64 for 32x32) without explicit scales
// Use neutral scale values (0) which means 2^0 = 1 (no scaling)
auto mfma = amdgpu::ScaledMFMAOp::create(
b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k, argA, argB,
vectorC, neutralScaleA, neutralScaleB,
/*scalesIdxA=*/0, /*scalesIdxB=*/0);
vectorD = mfma.getDestD();
} else {
// Regular MFMA
auto mfma = amdgpu::MFMAOp::create(
b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k,
mfmaAttr.blocksMfma, argA, argB, vectorC, /*cbsz=*/imms[i].cbsz,
Expand Down
Loading