-
Notifications
You must be signed in to change notification settings - Fork 54
Expand MFMA instruction selection for kpack values #2242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
bd7ce28
420f609
f440b8f
039aef3
3e49dfb
f644a1c
4c3b993
4759559
1fa9f3d
f439c0e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -546,9 +546,8 @@ Value MfmaEmitter::wrapLDSBufferForLoad( | |||||||||||||||
| TopDownTMBuilder toLDSRowCol(b, {}, {}, loc); | ||||||||||||||||
|
|
||||||||||||||||
| // Use LDS transpose compatible K formula when this operand uses LDS | ||||||||||||||||
| // transpose load (and kVec >= kBase to ensure proper K distribution) | ||||||||||||||||
| if (useLdsTransposeLoad && kVec >= kBase) { | ||||||||||||||||
|
|
||||||||||||||||
| // transpose load. Handles both kVec >= kBase and kVec < kBase cases. | ||||||||||||||||
| if (useLdsTransposeLoad) { | ||||||||||||||||
| // K access pattern must match the transpose load's pattern. | ||||||||||||||||
| // For double-rate MFMA, properly distribute K across threads | ||||||||||||||||
| int64_t instrK = mfmaAttr.k; | ||||||||||||||||
|
|
@@ -568,32 +567,63 @@ Value MfmaEmitter::wrapLDSBufferForLoad( | |||||||||||||||
| TransformMapAttr splitBlkIdAttr = splitBlkId.get(); | ||||||||||||||||
| transformAttrs.push_back(splitBlkIdAttr); | ||||||||||||||||
|
|
||||||||||||||||
| // Split k_vec into k_mfma and k_base for kpack > kBase | ||||||||||||||||
| int64_t numMfmaPerKVec = kVec / kBase; | ||||||||||||||||
|
|
||||||||||||||||
| TopDownTMBuilder splitKVec = | ||||||||||||||||
| TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr); | ||||||||||||||||
| splitKVec.passThrough({"wave_m", "wave_n"}, {0, 1}, {"wave_m", "wave_n"}); | ||||||||||||||||
| splitKVec.passThrough({"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"}, | ||||||||||||||||
| {2, 3, 4, 5, 6}, | ||||||||||||||||
| {"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"}); | ||||||||||||||||
| splitKVec.merge({"k_mfma", "k_base"}, {7, 8}, "k_vec", | ||||||||||||||||
| {numMfmaPerKVec, kBase}); | ||||||||||||||||
| TransformMapAttr splitKVecAttr = splitKVec.get(); | ||||||||||||||||
| transformAttrs.push_back(splitKVecAttr); | ||||||||||||||||
|
|
||||||||||||||||
| toLDSRowCol = TopDownTMBuilder::below(splitKVec, splitKVecAttr); | ||||||||||||||||
|
|
||||||||||||||||
| // d = d_iter * dWaves * numBlksInD * inputSpanLen + wave_d * numBlksInD * | ||||||||||||||||
| // inputSpanLen + blk_d * inputSpanLen + blk_td | ||||||||||||||||
| toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_d", "blk_td"}, | ||||||||||||||||
| {dRepeats, dWaves, numBlksInD, inputSpanLen}); | ||||||||||||||||
|
|
||||||||||||||||
| // k = k_iter * (numMfmaPerKVec * instrK) + k_mfma * instrK + blk_k * | ||||||||||||||||
| // kBase + k_base | ||||||||||||||||
| toLDSRowCol.unmerge("k", 1, {"k_iter", "k_mfma", "blk_k", "k_base"}, | ||||||||||||||||
| {kIter, numMfmaPerKVec, numBlksInK, kBase}); | ||||||||||||||||
|
|
||||||||||||||||
| if (kVec >= kBase) { | ||||||||||||||||
| // Case 1: kVec >= kBase - split k_vec into k_mfma and k_base | ||||||||||||||||
| int64_t numMfmaPerKVec = kVec / kBase; | ||||||||||||||||
|
|
||||||||||||||||
| TopDownTMBuilder splitKVec = | ||||||||||||||||
| TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr); | ||||||||||||||||
| splitKVec.passThrough({"wave_m", "wave_n"}, {0, 1}, | ||||||||||||||||
| {"wave_m", "wave_n"}); | ||||||||||||||||
| splitKVec.passThrough({"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"}, | ||||||||||||||||
| {2, 3, 4, 5, 6}, | ||||||||||||||||
| {"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"}); | ||||||||||||||||
| splitKVec.merge({"k_mfma", "k_base"}, {7, 8}, "k_vec", | ||||||||||||||||
| {numMfmaPerKVec, kBase}); | ||||||||||||||||
| TransformMapAttr splitKVecAttr = splitKVec.get(); | ||||||||||||||||
| transformAttrs.push_back(splitKVecAttr); | ||||||||||||||||
|
|
||||||||||||||||
| toLDSRowCol = TopDownTMBuilder::below(splitKVec, splitKVecAttr); | ||||||||||||||||
|
|
||||||||||||||||
| // d = d_iter * dWaves * numBlksInD * inputSpanLen + wave_d * numBlksInD | ||||||||||||||||
| // * inputSpanLen + blk_d * inputSpanLen + blk_td | ||||||||||||||||
| toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_d", "blk_td"}, | ||||||||||||||||
| {dRepeats, dWaves, numBlksInD, inputSpanLen}); | ||||||||||||||||
|
|
||||||||||||||||
| // k = k_iter * (numMfmaPerKVec * instrK) + k_mfma * instrK + blk_k * | ||||||||||||||||
| // kBase + k_base | ||||||||||||||||
| toLDSRowCol.unmerge("k", 1, {"k_iter", "k_mfma", "blk_k", "k_base"}, | ||||||||||||||||
| {kIter, numMfmaPerKVec, numBlksInK, kBase}); | ||||||||||||||||
| } else { | ||||||||||||||||
| // Case 2: kVec < kBase - split k_iter to accumulate multiple kVec | ||||||||||||||||
| // loads into one kBase worth of data (e.g., kVec=4, kBase=8) | ||||||||||||||||
| int64_t numKVecPerMfma = kBase / kVec; | ||||||||||||||||
| int64_t kOuter = kIter / numKVecPerMfma; | ||||||||||||||||
|
|
||||||||||||||||
| TopDownTMBuilder splitKIter = | ||||||||||||||||
| TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr); | ||||||||||||||||
| splitKIter.passThrough({"wave_m", "wave_n"}, {0, 1}, | ||||||||||||||||
| {"wave_m", "wave_n"}); | ||||||||||||||||
| splitKIter.passThrough({"blk_d", "blk_k", "blk_td", "d_iter"}, | ||||||||||||||||
| {2, 3, 4, 5}, | ||||||||||||||||
| {"blk_d", "blk_k", "blk_td", "d_iter"}); | ||||||||||||||||
| splitKIter.merge({"k_outer", "k_inner"}, {6, 7}, "k_iter", | ||||||||||||||||
| {kOuter, numKVecPerMfma}); | ||||||||||||||||
| splitKIter.passThrough({"k_vec"}, {8}, {"k_vec"}); | ||||||||||||||||
| TransformMapAttr splitKIterAttr = splitKIter.get(); | ||||||||||||||||
| transformAttrs.push_back(splitKIterAttr); | ||||||||||||||||
|
|
||||||||||||||||
| toLDSRowCol = TopDownTMBuilder::below(splitKIter, splitKIterAttr); | ||||||||||||||||
|
|
||||||||||||||||
| // d formula same as kVec >= kBase case | ||||||||||||||||
| toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_d", "blk_td"}, | ||||||||||||||||
| {dRepeats, dWaves, numBlksInD, inputSpanLen}); | ||||||||||||||||
|
|
||||||||||||||||
| // k = k_outer * instrK + blk_k * kBase + k_inner * kVec + k_vec | ||||||||||||||||
| // This accumulates numKVecPerMfma loads of kVec elements into kBase | ||||||||||||||||
| toLDSRowCol.unmerge("k", 1, {"k_outer", "blk_k", "k_inner", "k_vec"}, | ||||||||||||||||
| {kOuter, numBlksInK, numKVecPerMfma, kVec}); | ||||||||||||||||
| } | ||||||||||||||||
| } else { | ||||||||||||||||
| // Standard formula for regular load scenarios | ||||||||||||||||
| toLDSRowCol = TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr); | ||||||||||||||||
|
|
@@ -760,9 +790,11 @@ MfmaEmitter::createAccelGemmOperandTransforms( | |||||||||||||||
| TransformMapAttr splitWaveIdAttr = splitWaveId.get(); | ||||||||||||||||
| transformAttrs.push_back(splitWaveIdAttr); | ||||||||||||||||
| // Fourth coordinate transform | ||||||||||||||||
| // Check if we need LDS transpose compatible K formula | ||||||||||||||||
| // Check if we need LDS transpose compatible K formula. | ||||||||||||||||
| // When prefetch is used: kPack >= kBase allows LDS transpose load, | ||||||||||||||||
| // kPack < kBase disables it (falls back to regular load). | ||||||||||||||||
|
Comment on lines
+793
to
+795
|
||||||||||||||||
| // Check if we need LDS transpose compatible K formula. | |
| // When prefetch is used: kPack >= kBase allows LDS transpose load, | |
| // kPack < kBase disables it (falls back to regular load). | |
| // Check if we need an LDS transpose-compatible K formula. | |
| // LDS transpose-compatible loads are used only when the other operand | |
| // uses LDS transpose and this is a K-reduction; otherwise we fall back | |
| // to the regular (non-transpose-compatible) load path. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the
kVec < kBasebranch,kOuteris computed askIter / numKVecPerMfmaand then used as a dimension size inmerge({"k_outer","k_inner"}, ..., {kOuter, numKVecPerMfma}). IfkIteris not an exact multiple ofnumKVecPerMfma, this truncates and makes the merged size inconsistent with the originalk_iterextent. Please add a check/assert thatkIter % numKVecPerMfma == 0(or adjust the transform construction to handle the remainder safely).