Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
035e880
Add LDS transpose load support for attention kernel
stefankoncarevic Dec 23, 2025
ce7f0c9
Extend LDS transpose load support for 8 and 16 wave configurations
stefankoncarevic Dec 23, 2025
5a8c960
Merge branch 'develop' into lds-transpose-load-attention
stefankoncarevic Dec 23, 2025
d28193c
Fix K-access formula for hybrid LDS transpose load scenario
stefankoncarevic Dec 26, 2025
295334b
Merge branch 'develop' into lds-transpose-load-attention
stefankoncarevic Dec 26, 2025
2c4f149
Fix clang format
stefankoncarevic Dec 26, 2025
e0b7855
Add nightly e2e tests for LDS transpose load (GEMM)
stefankoncarevic Dec 26, 2025
2819548
Address review comments for LDS transpose load
stefankoncarevic Dec 29, 2025
f042d38
Add LDS transpose E2E tests and addressed review comment
stefankoncarevic Jan 2, 2026
65cc88b
Merge branch 'develop' into lds-transpose-load-attention
stefankoncarevic Jan 2, 2026
36facb7
Merge branch 'develop' into lds-transpose-load-attention
stefankoncarevic Jan 12, 2026
9027654
Add FP8/BF8 support for LDS transpose load
stefankoncarevic Jan 20, 2026
84c9425
Add PR CI tests for FP8/BF8 LDS transpose load GEMM operations
stefankoncarevic Jan 20, 2026
e3ecb26
Add INT8 LDS transpose load support for GEMM and Attention
stefankoncarevic Jan 23, 2026
9170e81
Add FP8/BF8 support for LDS transpose load
stefankoncarevic Jan 20, 2026
149c441
Add PR CI tests for FP8/BF8 LDS transpose load GEMM operations
stefankoncarevic Jan 20, 2026
a75ab7a
Add FP8 GEMM heuristic to selectively disable LDS transpose
stefankoncarevic Jan 29, 2026
dfd0b86
Merge lds-transpose-load-fp8 into lds-transpose-load-int8
stefankoncarevic Jan 29, 2026
da7db78
Add INT8 CONV heuristic to disable LDS transpose for N=1600 patterns
stefankoncarevic Jan 30, 2026
f8b7266
Merge branch 'develop' into lds-transpose-load-fp8
stefankoncarevic Feb 3, 2026
b07acac
Merge branch 'develop' into lds-transpose-load-fp8
stefankoncarevic Feb 4, 2026
4daa86e
Merge branch 'develop' into lds-transpose-load-fp8
stefankoncarevic Feb 5, 2026
43a6db4
Merge branch 'develop' into lds-transpose-load-fp8
stefankoncarevic Feb 5, 2026
3ccbf35
Merge branch 'lds-transpose-load-fp8' into lds-transpose-load-int8
stefankoncarevic Feb 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions mlir/include/mlir/Dialect/Rock/IR/RockOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1221,9 +1221,10 @@ defvar SameShapeVectorOfI1 = [{
def Rock_LDSTransposeLoadOp
: Rock_Op<"lds_transpose_load", [DeclareOpInterfaceMethods<
MemoryEffectsOpInterface>]>,
Arguments<(ins Arg<MemRefOf<[F16, BF16]>, "LDS source buffer">:$source,
Arguments<(ins Arg<MemRefOf<[F16, BF16, F8E4M3FN, F8E5M2, I8]>,
"LDS source buffer">:$source,
Variadic<Index>:$indices)>,
Results<(outs VectorOfLengthAndType<[4], [F16, BF16]>:$result)> {
Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary =
"Hardware-assisted LDS transpose load for matrix accelerator tile";
let description = [{
Expand All @@ -1232,9 +1233,15 @@ def Rock_LDSTransposeLoadOp
The tile dimensions match the selected matrix-multiply accelerator
instruction geometry (`dDim × instrK`), where:
- `dDim` is the accelerator M/N dimension (e.g., 16 or 32)
- `instrK` is the accelerator K dimension (e.g., 8, 16, or 32)
The operation returns a vector of 4 elements per thread containing
transposed elements in a layout suitable for matrix accelerator instructions.
- `instrK` is the accelerator K dimension (e.g., 8, 16, 32, or 64)

For 16-bit types (f16, bf16):
- Uses ds_read_tr16_b64 instruction
- Returns vector<4xtype> (4 elements per thread)

For 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8, i8 - INT8 for gfx950):
- Uses ds_read_tr8_b64 instruction
- Returns vector<8xtype> (8 elements per thread)

Benefits:
- Reduces LDS bank conflicts through optimized access patterns
Expand Down
33 changes: 30 additions & 3 deletions mlir/lib/Dialect/Rock/IR/RockDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,9 +559,13 @@ LogicalResult TransformMapAttr::verify(
}

// Returns true when (dDim, kDim) is an MFMA instruction geometry for which
// LDS transpose load is supported. The accepted set is the union over all
// supported element types:
//   - F16/BF16: (16,16), (16,32), (32,8), (32,16)
//   - FP8/BF8:  (16,32), (32,16)
//   - INT8:     (16,32), (16,64), (32,16), (32,32)
// The element type is not a parameter, so the per-type subsets listed above
// are not enforced here; only the geometry itself is checked.
static bool isValidLdsTransposeMfmaGeometry(int64_t dDim, int64_t kDim) {
  // A stale early return previously shadowed the extended check, which made
  // the (16,64) and (32,32) geometries unreachable.
  if (dDim == 16)
    return kDim == 16 || kDim == 32 || kDim == 64;
  if (dDim == 32)
    return kDim == 8 || kDim == 16 || kDim == 32;
  return false;
}

LogicalResult LDSTransposeConfigAttr::verify(
Expand All @@ -573,7 +577,7 @@ LogicalResult LDSTransposeConfigAttr::verify(
if (!isValidLdsTransposeMfmaGeometry(dDim, kDim)) {
return emitError() << "invalid MFMA geometry (" << dDim << "x" << kDim
<< ") for LDS transpose - valid combinations: "
"(16,16), (16,32), (32,8), (32,16)";
"(16,16), (16,32), (16,64), (32,8), (32,16), (32,32)";
}

// Validate positive dimensions
Expand Down Expand Up @@ -2161,6 +2165,29 @@ LogicalResult LDSTransposeLoadOp::verify() {
<< srcElemType << ")";
}

// Verify result vector length based on element type:
// - 16-bit types (f16, bf16): ds_read_tr16_b64 returns 4 elements
// - 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8, i8 - INT8 for gfx950):
// ds_read_tr8_b64 returns 8 elements
int64_t expectedVecLen;
if (srcElemType.isF16() || srcElemType.isBF16()) {
expectedVecLen = 4;
} else if (isa<Float8E4M3FNType>(srcElemType) ||
isa<Float8E5M2Type>(srcElemType)) {
expectedVecLen = 8;
} else if (srcElemType.isInteger(8)) {
expectedVecLen = 8;
} else {
return emitOpError("unsupported element type for LDS transpose load: ")
<< srcElemType;
}

if (resultType.getNumElements() != expectedVecLen) {
return emitOpError("expected result vector of ")
<< expectedVecLen << " elements for " << srcElemType
<< " type, but got " << resultType.getNumElements();
}

// Check hardware support using AmdArchDb
StringRef arch = rock::getArchValue(*this);
AmdArchInfo archInfo = rock::lookupArchInfo(arch);
Expand Down
39 changes: 37 additions & 2 deletions mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2489,7 +2489,7 @@ struct GridwiseAttentionAccelRewritePattern
ldsLayoutCfgMG0.doRotateWithK, ldsLayoutCfgMG0.doSwapThreadIterSubDims,
ldsLayoutCfgMG0.ldsLayoutDxK, directToLDS,
/*splitKAcrossThreadsFirst=*/false, gemm0G, gemm0M, gemm0InMPerThread,
/*ldsTransposeEnabled=*/ldsDecisionGemm0.enableA,
/*ldsTransposeEnabled=*/false,
/*accelDDim=*/ldsDecisionGemm0.mfmaDDim,
/*accelKDim=*/ldsDecisionGemm0.mfmaKDim);

Expand All @@ -2498,7 +2498,7 @@ struct GridwiseAttentionAccelRewritePattern
ldsLayoutCfgNG0.doRotateWithK, ldsLayoutCfgNG0.doSwapThreadIterSubDims,
ldsLayoutCfgNG0.ldsLayoutDxK, directToLDSQ,
/*splitKAcrossThreadsFirst=*/false, gemm0G, gemm0N, gemm0InNPerThread,
/*ldsTransposeEnabled=*/ldsDecisionGemm0.enableB,
/*ldsTransposeEnabled=*/false,
/*accelDDim=*/ldsDecisionGemm0.mfmaDDim,
/*accelKDim=*/ldsDecisionGemm0.mfmaKDim);

Expand Down Expand Up @@ -3384,6 +3384,41 @@ struct GridwiseGemmAccelRewritePattern
directToLDS, ldsLayoutConfigA, ldsLayoutConfigB, mPerBlock,
nPerBlock, kPerBlock, mPerWave, nPerWave, kpack, doubleBuffering);

// FP8 heuristic: check if LDS transpose should be disabled based on dims
// Rule 1: K >= 1280 causes regression, EXCEPT when K > otherDim AND
// otherDim > 1280
// Rule 2: Small square K-N matrices (K == N && K < 512)
auto shouldDisableLdsTranspose = [K, N](int64_t otherDim) -> bool {
if (K >= 1280 && !(K > otherDim && otherDim > 1280))
return true;
if (K == N && K < 512)
return true;
return false;
};

bool isFp8Type = isa<Float8E4M3FNType>(elementTypeA) ||
isa<Float8E5M2Type>(elementTypeA);

if (isFp8Type) {
// Case 1: TransA=false, TransB=false - only B uses LDS transpose
if (!ldsDecision.enableA && ldsDecision.enableB &&
shouldDisableLdsTranspose(N))
ldsDecision.enableB = false;

// Case 2: TransA=true, TransB=true - only A uses LDS transpose
if (ldsDecision.enableA && !ldsDecision.enableB &&
shouldDisableLdsTranspose(M))
ldsDecision.enableA = false;
}

// INT8 CONV heuristic: disable LDS transpose for N=1600 (40x40 spatial)
// when K<=M or K>2*M.
bool isInt8Type = elementTypeA.isInteger(8);
if (isInt8Type && N == 1600 && (K <= M || K > 2 * M)) {
ldsDecision.enableA = false;
ldsDecision.enableB = false;
}

LLVM_DEBUG(llvm::dbgs()
<< "M: " << M << "\n"
<< "N: " << N << "\n"
Expand Down
58 changes: 58 additions & 0 deletions mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,64 @@ Value MfmaEmitter::wrapLDSBufferForLoad(
toLDSRowCol.unmerge("k", 1, {"k_iter", "k_mfma", "blk_k", "k_base"},
{kIter, numMfmaPerKVec, numBlksInK, kBase});

} else if (useLdsTransposeLoad && kBase == 16) {
// This handles INT8 32x32 (instrK=32, kBase=16) and INT8 16x64
// (instrK=64, kBase=16) when kpack=1
int64_t instrK = mfmaAttr.k;
int64_t numBlksInK = instrK / kBase;
int64_t numBlksInD = (waveSize / inputSpanLen) / numBlksInK;

// Split blk_id into blk_d (for D dimension) and blk_k (for K dimension)
TopDownTMBuilder splitBlkId =
TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr);
splitBlkId.passThrough({"wave_m", "wave_n"}, {0, 1},
{"wave_m", "wave_n"});
splitBlkId.merge({"blk_d", "blk_k"}, {2, 3}, "blk_id",
{numBlksInD, numBlksInK});
splitBlkId.passThrough({"blk_td", "d_iter", "k_iter", "k_vec"},
{4, 5, 6, 7},
{"blk_td", "d_iter", "k_iter", "k_vec"});
TransformMapAttr splitBlkIdAttr = splitBlkId.get();
transformAttrs.push_back(splitBlkIdAttr);

// For kVec < kBase, we split k_iter into outer_k (MFMA iterations)
// and inner_k (within kBase for each blk_k).
// - Total K per thread = kIter * kVec
// - Each thread covers the K range blk_k * kBase to
//   blk_k * kBase + kBase - 1
// - Number of MFMA iterations (per blk_k) = (kIter * kVec) / kBase
// - K iterations within each kBase = kBase / kVec
// - Constraint: numMfmaIters * kIterPerBlk = kIter
int64_t kIterPerBlk = kBase / kVec; // iterations to cover kBase elements
int64_t numMfmaIters = kIter / kIterPerBlk; // number of outer iterations

TopDownTMBuilder splitKIter =
TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr);
splitKIter.passThrough({"wave_m", "wave_n"}, {0, 1},
{"wave_m", "wave_n"});
splitKIter.passThrough({"blk_d", "blk_k", "blk_td", "d_iter"},
{2, 3, 4, 5},
{"blk_d", "blk_k", "blk_td", "d_iter"});
splitKIter.merge({"outer_k", "inner_k"}, {6, 7}, "k_iter",
{numMfmaIters, kIterPerBlk});
splitKIter.passThrough({"k_vec"}, {8}, {"k_vec"});
TransformMapAttr splitKIterAttr = splitKIter.get();
transformAttrs.push_back(splitKIterAttr);

toLDSRowCol = TopDownTMBuilder::below(splitKIter, splitKIterAttr);

// d = d_iter * dWaves * numBlksInD * inputSpanLen + wave_d * numBlksInD *
// inputSpanLen + blk_d * inputSpanLen + blk_td
toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_d", "blk_td"},
{dRepeats, dWaves, numBlksInD, inputSpanLen});

// k = outer_k * instrK + blk_k * kBase + inner_k * kVec + k_vec
// This ensures K access pattern matches transpose load:
// - outer_k selects which MFMA iteration
// - blk_k selects which kBase-wide K block within one instruction
//   (there are numBlksInK = instrK / kBase such blocks)
// - inner_k * kVec + k_vec gives position within the K block
toLDSRowCol.unmerge("k", 1, {"outer_k", "blk_k", "inner_k", "k_vec"},
{numMfmaIters, numBlksInK, kIterPerBlk, kVec});

} else {
// Standard formula for regular load scenarios
toLDSRowCol = TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr);
Expand Down
Loading