Expand MFMA instruction selection for kpack values#2242

Open
stefankoncarevic wants to merge 10 commits into develop from mfma-enable-kpack-values-gfx950

@stefankoncarevic stefankoncarevic commented Feb 18, 2026

For detailed performance analysis and tuning data showing the impact of expanded MFMA selection for each data type, refer to the linked issue.
Resolves: https://amd-hub.atlassian.net/browse/AIROCMLIR-481

Motivation

Previously, the isCoherentWithK validation strictly required kpack >= k_base, which prevented valid configurations for newer MFMA instructions. This limitation was particularly impactful for:

  • fp8 on gfx950/gfx942: kpack=4 could not use ANY MFMA (all have k_base=8)
  • int8 on gfx942: kpack=4 could not use the newer MFMA instructions (32x32x16, 16x16x32)
  • int8 on gfx950: kpack=4,8 could not use double-rate MFMA (32x32x32, 16x16x64)
  • f16/bf16 on gfx950: kpack=4 could not leverage double-rate MFMA (32x32x16, 16x16x32)

This change relaxes the constraint to enable broader MFMA selection, improving hardware utilization and tuning flexibility.

Technical Details

Core Change: MfmaInsnGroup.cpp
Modified isCoherentWithK to allow kpack < k_base when:

  1. The pipeline is double-buffered (scheduleVersion = 2 or 4); single-buffer pipelines still require strict kpack >= k_base to avoid wasting MFMA cycles
  2. k_base % kpack == 0 (kpack must evenly divide k_base)
  3. kpack × kPackPerBlock >= MFMA_K (must cover the full K dimension)
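
The conditions listed above can be sketched as a standalone predicate. This is only an illustration: `isCoherentWithKSketch`, its parameter list, and the treatment of the strict path (accepting any kpack >= k_base without further checks) are assumptions, not the actual code in MfmaInsnGroup.cpp.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the relaxed check; the name, signature and
// the simplified strict path are assumptions made for illustration.
bool isCoherentWithKSketch(int64_t kpack, int64_t kPackPerBlock,
                           int64_t kBase, int64_t mfmaK,
                           int64_t scheduleVersion) {
  assert(kpack > 0 && kPackPerBlock > 0 && kBase > 0 && mfmaK > 0);

  // Strict rule: kpack >= k_base is accepted on every pipeline.
  // (The real function may apply additional checks here.)
  if (kpack >= kBase)
    return true;

  // Relaxed rule: only double-buffer pipelines may use kpack < k_base.
  bool isDoubleBuffer = scheduleVersion == 2 || scheduleVersion == 4;
  if (!isDoubleBuffer)
    return false;

  // kpack must evenly divide k_base.
  if (kBase % kpack != 0)
    return false;

  // The block must cover the full K dimension of the MFMA.
  return kpack * kPackPerBlock >= mfmaK;
}
```

For example, with k_base = 8 and MFMA_K = 16, kpack = 4 with kPackPerBlock = 4 passes on scheduleVersion 2 but is rejected on a single-buffer schedule.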

Impact by Architecture and Data Type

gfx950:

| Data Type | kpack=4 | kpack=8 |
|---|---|---|
| f16/bf16 | Can now select double-rate MFMA (32x32x16, 16x16x32) | No change |
| fp8 | Now works - was completely disabled | No change |
| int8 | All 4 MFMA now available including double-rate | Double-rate MFMA (32x32x32, 16x16x64) now available |

gfx942:

| Data Type | kpack=4 | kpack=8 |
|---|---|---|
| f16/bf16 | No change (no double-rate MFMA available) | No change |
| fp8 | Now works - was completely disabled | No change |
| int8 | Now works - was disabled for k_base=8 MFMA | No change |

Test Plan

All PR and nightly tests pass.

Test Result

Submission Checklist

Allow kpack < k_base when k_base >= 8 and k_base % kpack == 0.
This enables better utilization of double-rate MFMA instructions
(e.g., gfx950 f16/bf16/int8, gfx942 int8/fp8) with kpack=4.
Disable LDS transpose for prefetch when kpack < kBase as a
necessary fix for the relaxed validation.
Copilot AI left a comment


Pull request overview

Relaxes MFMA kpack/k_base coherence constraints to enable selection of newer MFMA instructions (notably when kpack < k_base), and updates LDS-transpose attention e2e configs accordingly.

Changes:

  • Relax MfmaInsn::isCoherentWithK to allow kpack < k_base for newer MFMA (with divisibility + K-coverage constraints).
  • Extend LDS-transpose compatible K-indexing logic to handle kVec < kBase in MfmaEmitter::wrapLDSBufferForLoad, plus add a prefetch-related transpose disable rule.
  • Update attention e2e TOML suites/configs to exercise the new behaviors and a GEMM1 fallback scenario.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.

| File | Description |
|---|---|
| mlir/test/e2e/PrLdsTransposeLoadAttention.toml | Adjusts PR-focused attention configs (head dims / perf configs) for quicker validation coverage. |
| mlir/test/e2e/LdsTransposeLoadAttention.toml | Updates main attention LDS-transpose suites and adds a GEMM1 fallback test scenario for kpack < kBase. |
| mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp | Disables operand A LDS-transpose when operand B is prefetched and kpack < kBase due to incompatible linear K mapping. |
| mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp | Updates LDS-transpose-compatible K mapping to support kVec < kBase, and relaxes compatible-K gating in operand transforms. |
| mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | Removes the hard disable of LDS transpose for large head dimensions in attention GEMM0. |
| mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp | Implements relaxed kpack < k_base coherence rules for MFMA instruction selection. |
Comments suppressed due to low confidence (1)

mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp:807

  • useLdsTransposeCompatibleK path still assumes kPack >= kBase. With the relaxed MFMA/kpack rules, this function can now be reached with kPack < kBase, making numMfmaPerKPack = kPack / kBase evaluate to 0 and producing an invalid TransformMap (and incorrect K indexing). Please add a dedicated kPack < kBase handling here (similar to the kVec < kBase handling in wrapLDSBufferForLoad) or explicitly guard and fall back to the regular path when kPack < kBase.
    // Check if we need LDS transpose compatible K formula.
    // When prefetch is used: kPack >= kBase allows LDS transpose load,
    // kPack < kBase disables it (falls back to regular load).
    bool useLdsTransposeCompatibleK =
        otherOperandUsesLdsTranspose && isKReduction;
    int64_t numBlksInK = instrK / kBase;
    int64_t numBlksInD = (waveSize / inputSpanLen) / numBlksInK;

    TransformMapAttr toLDSRowColAttr;
    if (useLdsTransposeCompatibleK) {
      // LDS transpose compatible path: split blk_id into blk_d and blk_k
      // Also split kpack into k_mfma and k_base to match LDS transpose pattern
      int64_t numMfmaPerKPack = kPack / kBase;

      // First, add a transform to split blk_id
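
One of the two options suggested in this comment, guarding the path and falling back when kPack < kBase, could look roughly like the sketch below. The helper `shouldUseLdsTransposeCompatibleK` is hypothetical and not part of the PR; it only shows why the integer division is the problem and how a guard avoids it.

```cpp
#include <cstdint>

// Hypothetical guard, not taken from AccelEmitter.cpp. Assumes kBase > 0.
bool shouldUseLdsTransposeCompatibleK(bool otherOperandUsesLdsTranspose,
                                      bool isKReduction, int64_t kPack,
                                      int64_t kBase) {
  // Same gating as the excerpt above.
  if (!(otherOperandUsesLdsTranspose && isKReduction))
    return false;

  // Integer division truncates: kPack / kBase is 0 whenever kPack < kBase,
  // which would give the k_mfma split a zero-sized dimension.
  int64_t numMfmaPerKPack = kPack / kBase;

  // Fall back to the regular path unless the split is non-degenerate.
  return numMfmaPerKPack >= 1;
}
```

With kPack = 4 and kBase = 8 the helper returns false, so a caller would take the regular load path instead of building a transform with a zero-sized split.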


Comment on lines +598 to +612
// Case 2: kVec < kBase - split k_iter to accumulate multiple kVec
// loads into one kBase worth of data (e.g., kVec=4, kBase=8)
int64_t numKVecPerMfma = kBase / kVec;
int64_t kOuter = kIter / numKVecPerMfma;

TopDownTMBuilder splitKIter =
    TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr);
splitKIter.passThrough({"wave_m", "wave_n"}, {0, 1},
                       {"wave_m", "wave_n"});
splitKIter.passThrough({"blk_d", "blk_k", "blk_td", "d_iter"},
                       {2, 3, 4, 5},
                       {"blk_d", "blk_k", "blk_td", "d_iter"});
splitKIter.merge({"k_outer", "k_inner"}, {6, 7}, "k_iter",
                 {kOuter, numKVecPerMfma});
splitKIter.passThrough({"k_vec"}, {8}, {"k_vec"});

Copilot AI Feb 18, 2026


In the kVec < kBase branch, kOuter is computed as kIter / numKVecPerMfma and then used as a dimension size in merge({"k_outer","k_inner"}, ..., {kOuter, numKVecPerMfma}). If kIter is not an exact multiple of numKVecPerMfma, this truncates and makes the merged size inconsistent with the original k_iter extent. Please add a check/assert that kIter % numKVecPerMfma == 0 (or adjust the transform construction to handle the remainder safely).
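
The divisibility check requested here can be illustrated with a small standalone sketch. `KIterSplit` and `splitKIterSketch` are hypothetical names introduced for this example; the real code builds a TransformMap rather than returning a struct.

```cpp
#include <cstdint>

// Hypothetical result type for the k_iter split; not part of rocMLIR.
struct KIterSplit {
  bool valid;
  int64_t kOuter;
  int64_t kInner;
};

// Splits k_iter into {k_outer, k_inner} for the kVec < kBase case and
// rejects inputs where the split would not reproduce the original extent.
KIterSplit splitKIterSketch(int64_t kIter, int64_t kVec, int64_t kBase) {
  // kVec must evenly divide kBase for numKVecPerMfma to be meaningful.
  if (kVec <= 0 || kBase % kVec != 0)
    return {false, 0, 0};

  int64_t numKVecPerMfma = kBase / kVec;

  // Without this check, kIter / numKVecPerMfma truncates and
  // kOuter * numKVecPerMfma no longer equals kIter.
  if (kIter % numKVecPerMfma != 0)
    return {false, 0, 0};

  return {true, kIter / numKVecPerMfma, numKVecPerMfma};
}
```

For kVec = 4 and kBase = 8, kIter = 6 splits into 3 × 2, while kIter = 5 is rejected because 5 / 2 would truncate to 2 and the merged size (4) would not match the original extent.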

Comment on lines +793 to +795
// Check if we need LDS transpose compatible K formula.
// When prefetch is used: kPack >= kBase allows LDS transpose load,
// kPack < kBase disables it (falls back to regular load).

Copilot AI Feb 18, 2026


The comment about prefetch behavior (“kPack >= kBase allows LDS transpose load, kPack < kBase disables it”) no longer matches the logic here now that the kPack >= kBase condition was removed from useLdsTransposeCompatibleK. Please update the comment to reflect the actual gating conditions (or reintroduce the check if that is still the intended behavior).

Suggested change
- // Check if we need LDS transpose compatible K formula.
- // When prefetch is used: kPack >= kBase allows LDS transpose load,
- // kPack < kBase disables it (falls back to regular load).
+ // Check if we need an LDS transpose-compatible K formula.
+ // LDS transpose-compatible loads are used only when the other operand
+ // uses LDS transpose and this is a K-reduction; otherwise we fall back
+ // to the regular (non-transpose-compatible) load path.


umangyadav commented Feb 18, 2026

We required kPack >= kBase to avoid wasting MFMAs.

kPack controls the LDS layout and therefore, indirectly, the vectorization of LDS reads/writes and bank conflicts.

For example, in GEMM we do this for the "single buffer" pipeline:

kPerBlock = kPack * kPackPerBlock
scf.for 0 to K step kPerBlock {
  globalLoad {mPerBlock, kPerBlock} tile of matrix A
  globalLoad {nPerBlock, kPerBlock} tile of matrix B
  LDSWrite {kPackPerBlock, mPerBlock, kPack} of matrix A
  LDSWrite {kPackPerBlock, nPerBlock, kPack} of matrix B

  for i = 0 to kPackPerBlock step 1
    regA = LDSRead {i, mPerBlock, kPack} of matrix A
    for j = 0 to kPackPerBlock step 1
      regB = LDSRead {j, nPerBlock, kPack} of matrix B
      for k = 0 to kPack step kBase
        mfmaRegA = regA[:, k:k+kBase]
        mfmaRegB = regB[:, k:k+kBase]
        emitMFMA()
}

Here, if kPack < kBase, we would still perform the MFMA, but part of the computation would be wasted.

That is the reason why we have this constraint (kPack >= kBase) in isCoherentWithK.

The tuner will eventually select a different kPack value that picks the double-rate MFMA.

It is a bit different for the "double buffer" pipeline, though, where we load the entire [kPackPerBlock, mPerBlock] tile from LDS. In that case we should allow the kPerBlock % kBase == 0 cases and set kBasePerThread = kPerBlock / kBase.
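
The arithmetic behind this argument can be sketched as follows. Both helpers are hypothetical and simply restate the reasoning above: in the single-buffer loop each kPack chunk is fed to MFMAs that consume kBase elements of K, whereas the double-buffer case works on the whole kPerBlock tile.

```cpp
#include <cstdint>

// Fraction of each MFMA's K input that carries real data when the inner
// loop runs "for k = 0 to kPack step kBase" (single-buffer pseudocode).
// Hypothetical helper for illustration; assumes kPack > 0 and kBase > 0.
double singleBufferKUtilization(int64_t kPack, int64_t kBase) {
  // The loop body runs at least once, even when kPack < kBase.
  int64_t mfmasPerKPack = (kPack + kBase - 1) / kBase;
  return static_cast<double>(kPack) /
         static_cast<double>(mfmasPerKPack * kBase);
}

// kBasePerThread for the double-buffer case, or -1 when kPerBlock is not
// a multiple of kBase and the configuration should be rejected.
int64_t kBasePerThreadSketch(int64_t kPack, int64_t kPackPerBlock,
                             int64_t kBase) {
  int64_t kPerBlock = kPack * kPackPerBlock;
  return (kPerBlock % kBase == 0) ? kPerBlock / kBase : -1;
}
```

With kPack = 4 and kBase = 8, the single-buffer loop uses only half of each MFMA's K input, while a double-buffer tile with kPackPerBlock = 4 gives kPerBlock = 16 and kBasePerThread = 2 with no waste.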

Relaxed kpack validation (kpack < k_base) now only applies to
double-buffer pipelines (scheduleVersion 2 or 4).
@stefankoncarevic
Contributor Author

You're right: when comparing the same config directly, single-buffer pipelines show significant degradation with the relaxed kPack < kBase validation.
However, when comparing best-vs-best single-buffer configs, the results are more mixed (40-60 improved or regressed). Still, the risk of regression for single-buffer is clear.
I updated the PR to enable the relaxation only for double-buffer pipelines (pipeline 2 or 4), which should be safe since they load the entire kPerBlock tile before processing.

@stefankoncarevic stefankoncarevic mentioned this pull request Feb 20, 2026