Expand MFMA instruction selection for kpack values#2242
Expand MFMA instruction selection for kpack values#2242stefankoncarevic wants to merge 10 commits intodevelopfrom
Conversation
Allow kpack < k_base when k_base >= 8 and k_base % kpack == 0. This enables better utilization of double-rate MFMA instructions (e.g., gfx950 f16/bf16/int8, gfx942 int8/fp8) with kpack=4. Disable LDS transpose for prefetch when kpack < kBase as a necessary fix for the relaxed validation.
There was a problem hiding this comment.
Pull request overview
Relaxes MFMA kpack/k_base coherence constraints to enable selection of newer MFMA instructions (notably when kpack < k_base), and updates LDS-transpose attention e2e configs accordingly.
Changes:
- Relax
MfmaInsn::isCoherentWithKto allowkpack < k_basefor newer MFMA (with divisibility + K-coverage constraints). - Extend LDS-transpose compatible K-indexing logic to handle
kVec < kBaseinMfmaEmitter::wrapLDSBufferForLoad, plus add a prefetch-related transpose disable rule. - Update attention e2e TOML suites/configs to exercise the new behaviors and a GEMM1 fallback scenario.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| mlir/test/e2e/PrLdsTransposeLoadAttention.toml | Adjusts PR-focused attention configs (head dims / perf configs) for quicker validation coverage. |
| mlir/test/e2e/LdsTransposeLoadAttention.toml | Updates main attention LDS-transpose suites and adds a GEMM1 fallback test scenario for kpack < kBase. |
| mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp | Disables operand A LDS-transpose when operand B is prefetched and kpack < kBase due to incompatible linear K mapping. |
| mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp | Updates LDS-transpose-compatible K mapping to support kVec < kBase, and relaxes compatible-K gating in operand transforms. |
| mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | Removes the hard disable of LDS transpose for large head dimensions in attention GEMM0. |
| mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp | Implements relaxed kpack < k_base coherence rules for MFMA instruction selection. |
Comments suppressed due to low confidence (1)
mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp:807
useLdsTransposeCompatibleKpath still assumeskPack >= kBase. With the relaxed MFMA/kpack rules, this function can now be reached withkPack < kBase, makingnumMfmaPerKPack = kPack / kBaseevaluate to 0 and producing an invalidTransformMap(and incorrect K indexing). Please add a dedicatedkPack < kBasehandling here (similar to thekVec < kBasehandling inwrapLDSBufferForLoad) or explicitly guard and fall back to the regular path whenkPack < kBase.
// Check if we need LDS transpose compatible K formula.
// When prefetch is used: kPack >= kBase allows LDS transpose load,
// kPack < kBase disables it (falls back to regular load).
bool useLdsTransposeCompatibleK =
otherOperandUsesLdsTranspose && isKReduction;
int64_t numBlksInK = instrK / kBase;
int64_t numBlksInD = (waveSize / inputSpanLen) / numBlksInK;
TransformMapAttr toLDSRowColAttr;
if (useLdsTransposeCompatibleK) {
// LDS transpose compatible path: split blk_id into blk_d and blk_k
// Also split kpack into k_mfma and k_base to match LDS transpose pattern
int64_t numMfmaPerKPack = kPack / kBase;
// First, add a transform to split blk_id
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| // Case 2: kVec < kBase - split k_iter to accumulate multiple kVec | ||
| // loads into one kBase worth of data (e.g., kVec=4, kBase=8) | ||
| int64_t numKVecPerMfma = kBase / kVec; | ||
| int64_t kOuter = kIter / numKVecPerMfma; | ||
|
|
||
| TopDownTMBuilder splitKIter = | ||
| TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr); | ||
| splitKIter.passThrough({"wave_m", "wave_n"}, {0, 1}, | ||
| {"wave_m", "wave_n"}); | ||
| splitKIter.passThrough({"blk_d", "blk_k", "blk_td", "d_iter"}, | ||
| {2, 3, 4, 5}, | ||
| {"blk_d", "blk_k", "blk_td", "d_iter"}); | ||
| splitKIter.merge({"k_outer", "k_inner"}, {6, 7}, "k_iter", | ||
| {kOuter, numKVecPerMfma}); | ||
| splitKIter.passThrough({"k_vec"}, {8}, {"k_vec"}); |
There was a problem hiding this comment.
In the kVec < kBase branch, kOuter is computed as kIter / numKVecPerMfma and then used as a dimension size in merge({"k_outer","k_inner"}, ..., {kOuter, numKVecPerMfma}). If kIter is not an exact multiple of numKVecPerMfma, this truncates and makes the merged size inconsistent with the original k_iter extent. Please add a check/assert that kIter % numKVecPerMfma == 0 (or adjust the transform construction to handle the remainder safely).
| // Check if we need LDS transpose compatible K formula. | ||
| // When prefetch is used: kPack >= kBase allows LDS transpose load, | ||
| // kPack < kBase disables it (falls back to regular load). |
There was a problem hiding this comment.
The comment about prefetch behavior (“kPack >= kBase allows LDS transpose load, kPack < kBase disables it”) no longer matches the logic here now that the kPack >= kBase condition was removed from useLdsTransposeCompatibleK. Please update the comment to reflect the actual gating conditions (or reintroduce the check if that is still the intended behavior).
| // Check if we need LDS transpose compatible K formula. | |
| // When prefetch is used: kPack >= kBase allows LDS transpose load, | |
| // kPack < kBase disables it (falls back to regular load). | |
| // Check if we need an LDS transpose-compatible K formula. | |
| // LDS transpose-compatible loads are used only when the other operand | |
| // uses LDS transpose and this is a K-reduction; otherwise we fall back | |
| // to the regular (non-transpose-compatible) load path. |
|
We required kPack controls LDS layout and therefore LDS reads/writes vectorization and bank conflicts indirectly. in GEMM we do this for "Single buffer" pipeline. Here if That is the reason why we have this constraint Tuner will select a different "kPack" value will pick double rate MFMA eventually. It is a bit different story for "Double buffer" pipeline though where we load entire |
Relaxed kpack validation (kpack < k_base) now only applies to double-buffer pipelines (scheduleVersion 2 or 4).
You're right when comparing the same config directly - single-buffer pipelines show significant degradation with relaxed kPack < kBase validation. |
For detailed performance analysis and tuning data showing the impact of expanded MFMA selection for each data type, refer to the linked issue.
Resolves: https://amd-hub.atlassian.net/browse/AIROCMLIR-481
Motivation
Previously, the isCoherentWithK validation strictly required kpack >= k_base, which prevented valid configurations for newer MFMA instructions. This limitation was particularly impactful for:
This change relaxes the constraint to enable broader MFMA selection, improving hardware utilization and tuning flexibility.
Technical Details
Core Change: MfmaInsnGroup.cpp
Modified isCoherentWithK to allow kpack < k_base when:
Impact by Architecture and Data Type
gfx950:
gfx942:
Test Plan
All PR and nightly tests pass.
Test Result
Submission Checklist