Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions mlir/include/mlir/Dialect/Rock/IR/RockOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1221,9 +1221,10 @@ defvar SameShapeVectorOfI1 = [{
def Rock_LDSTransposeLoadOp
: Rock_Op<"lds_transpose_load", [DeclareOpInterfaceMethods<
MemoryEffectsOpInterface>]>,
Arguments<(ins Arg<MemRefOf<[F16, BF16]>, "LDS source buffer">:$source,
Arguments<(ins Arg<MemRefOf<[F16, BF16, F8E4M3FN, F8E5M2]>,
"LDS source buffer">:$source,
Variadic<Index>:$indices)>,
Results<(outs VectorOfLengthAndType<[4], [F16, BF16]>:$result)> {
Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary =
"Hardware-assisted LDS transpose load for matrix accelerator tile";
let description = [{
Expand All @@ -1233,8 +1234,14 @@ def Rock_LDSTransposeLoadOp
instruction geometry (`dDim × instrK`), where:
- `dDim` is the accelerator M/N dimension (e.g., 16 or 32)
- `instrK` is the accelerator K dimension (e.g., 8, 16, or 32)
The operation returns a vector of 4 elements per thread containing
transposed elements in a layout suitable for matrix accelerator instructions.

For 16-bit types (f16, bf16):
- Uses ds_read_tr16_b64 instruction
- Returns vector<4xtype> (4 elements per thread)

For 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950):
- Uses ds_read_tr8_b64 instruction
- Returns vector<8xtype> (8 elements per thread)

Benefits:
- Reduces LDS bank conflicts through optimized access patterns
Expand Down
21 changes: 21 additions & 0 deletions mlir/lib/Dialect/Rock/IR/RockDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2161,6 +2161,27 @@ LogicalResult LDSTransposeLoadOp::verify() {
<< srcElemType << ")";
}

// Verify result vector length based on element type:
// - 16-bit types (f16, bf16): ds_read_tr16_b64 returns 4 elements
// - 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): ds_read_tr8_b64
// returns 8 elements
int64_t expectedVecLen;
if (srcElemType.isF16() || srcElemType.isBF16()) {
expectedVecLen = 4;
} else if (isa<Float8E4M3FNType>(srcElemType) ||
isa<Float8E5M2Type>(srcElemType)) {
expectedVecLen = 8;
} else {
return emitOpError("unsupported element type for LDS transpose load: ")
<< srcElemType;
}

if (resultType.getNumElements() != expectedVecLen) {
return emitOpError("expected result vector of ")
<< expectedVecLen << " elements for " << srcElemType
<< " type, but got " << resultType.getNumElements();
}

// Check hardware support using AmdArchDb
StringRef arch = rock::getArchValue(*this);
AmdArchInfo archInfo = rock::lookupArchInfo(arch);
Expand Down
27 changes: 27 additions & 0 deletions mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3377,6 +3377,33 @@ struct GridwiseGemmAccelRewritePattern
directToLDS, ldsLayoutConfigA, ldsLayoutConfigB, mPerBlock,
nPerBlock, kPerBlock, mPerWave, nPerWave, kpack, doubleBuffering);

// FP8 heuristic: given the captured problem sizes K and N, report whether LDS
// transpose should be switched off for an operand whose non-K dimension is
// `otherDim`.
//   Rule 1: K >= 1280 is treated as a regression case, unless both
//           K > otherDim and otherDim > 1280 hold.
//   Rule 2: small square K-N problems (K == N and K < 512).
// NOTE(review): Rule 2 always compares K against the captured N, even when
// the caller passes M as `otherDim` - confirm this is intended.
auto shouldDisableLdsTranspose = [K, N](int64_t otherDim) -> bool {
  const bool largeKException = K > otherDim && otherDim > 1280;
  const bool largeKRegression = K >= 1280 && !largeKException;
  const bool smallSquareKN = K == N && K < 512;
  return largeKRegression || smallSquareKN;
};

// The heuristic applies only when operand A's element type is one of the two
// FP8 types checked here.
bool isFp8Type = isa<Float8E4M3FNType>(elementTypeA) ||
                 isa<Float8E5M2Type>(elementTypeA);

if (isFp8Type) {
  // The two situations are mutually exclusive, so both can be evaluated
  // before either flag is cleared.
  const bool onlyBUsesTranspose = !ldsDecision.enableA && ldsDecision.enableB;
  const bool onlyAUsesTranspose = ldsDecision.enableA && !ldsDecision.enableB;

  // Case 1: TransA=false, TransB=false - only B uses LDS transpose; the
  // heuristic is queried with N.
  if (onlyBUsesTranspose && shouldDisableLdsTranspose(N))
    ldsDecision.enableB = false;

  // Case 2: TransA=true, TransB=true - only A uses LDS transpose; the
  // heuristic is queried with M.
  if (onlyAUsesTranspose && shouldDisableLdsTranspose(M))
    ldsDecision.enableA = false;
}

LLVM_DEBUG(llvm::dbgs()
<< "M: " << M << "\n"
<< "N: " << N << "\n"
Expand Down
Loading