Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ class MfmaInsn {
MfmaInsnAttr getAttr() const;
Type getArgTypeFor(Type elementTypeA);
VectorType getRetType(Type elementType);
bool isCoherentWithK(int64_t kPack, int64_t kPerBlock);
bool isCoherentWithK(int64_t kPack, int64_t kPerBlock,
int64_t scheduleVersion);
};

template <typename T>
Expand Down Expand Up @@ -138,7 +139,8 @@ class MfmaInsnGroup {
public:
static FailureOr<MfmaInsnGroup> select(Type elementTypeA, Type elementTypeB,
StringRef arch, int64_t mnPerXdl,
int64_t kPack, int64_t kPackPerBlock);
int64_t kPack, int64_t kPackPerBlock,
int64_t scheduleVersion);
MfmaInsnGroup(Type elementTypeA, Type elementTypeB, const MfmaInsn &insn,
const MfmaInsnGroupAttr &groupAttr);
int64_t getMRepeats(int64_t mPerWave);
Expand All @@ -150,7 +152,8 @@ class MfmaInsnGroup {
Type getArgTypeA();
Type getArgTypeB();
VectorType getRetType();
bool isCoherentWithK(int64_t kPack, int64_t kPerBlock);
bool isCoherentWithK(int64_t kPack, int64_t kPerBlock,
int64_t scheduleVersion);
SmallString<16> getROCDLIntrinsicName() { return groupAttr.insn; }
};

Expand Down
43 changes: 33 additions & 10 deletions mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,13 +484,33 @@ VectorType MfmaInsn::getRetType(Type elementType) {
return VectorType::get({attr.nOutputsOfMfma}, vectorElem);
}

bool MfmaInsn::isCoherentWithK(int64_t kpack, int64_t kPerBlock) {
// Check if the MFMA instruction is coherent with the K dimension configuration.
// Double-buffer pipelines allow kpack < k_base if k_base % kpack == 0.
// Single-buffer pipelines require kpack >= k_base to avoid wasting MFMA cycles.
bool MfmaInsn::isCoherentWithK(int64_t kpack, int64_t kPerBlock,
int64_t scheduleVersion) {
int64_t totalKPerBlock = kpack * kPerBlock;
// Double-buffer pipelines: scheduleVersion 2 or 4
bool isDoubleBuffer = (scheduleVersion == 2 || scheduleVersion == 4);

if (kpack > 1) {
if (kpack < attr.k_base) {
LLVM_DEBUG(llvm::dbgs()
<< "Should pack at least k_base elements and avoid waste "
"xdlopsgemm cycles\n");
return false;
if (!isDoubleBuffer) {
LLVM_DEBUG(llvm::dbgs()
<< "Should pack at least k_base elements and avoid waste "
"xdlopsgemm cycles\n");
return false;
}
if (attr.k_base % kpack != 0) {
LLVM_DEBUG(llvm::dbgs()
<< "kpack must divide k_base when kpack < k_base\n");
return false;
}
if (totalKPerBlock < attr.k) {
LLVM_DEBUG(llvm::dbgs() << "totalKPerBlock (" << totalKPerBlock
<< ") must be >= MFMA K (" << attr.k << ")\n");
return false;
}
}
if (attr.isKReduction && kPerBlock < attr.inputSpansPerMfmaIn) {
LLVM_DEBUG(
Expand Down Expand Up @@ -565,14 +585,16 @@ static MfmaTypeId convertTypesToId(Type dataTypeA, Type dataTypeB) {

FailureOr<MfmaInsnGroup>
MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch,
int64_t mnPerXdl, int64_t kPack, int64_t kPackPerBlock) {
int64_t mnPerXdl, int64_t kPack, int64_t kPackPerBlock,
int64_t scheduleVersion) {
LLVM_DEBUG(llvm::dbgs() << "Invoke Mfma group selection:\n"
<< "elementType A: " << elementTypeA << "\n"
<< "elementType B: " << elementTypeB << "\n"
<< "arch: " << arch << "\n"
<< "mnPerXdl: " << mnPerXdl << "\n"
<< "kPack: " << kPack << "\n"
<< "KPackPerBlock: " << kPackPerBlock << "\n");
<< "KPackPerBlock: " << kPackPerBlock << "\n"
<< "scheduleVersion: " << scheduleVersion << "\n");

// Use 64x64 as base unit in large waves
int64_t mPerMfmaGroup = getLenPerMfmaGroup(mnPerXdl);
Expand Down Expand Up @@ -605,7 +627,7 @@ MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch,
// gfx950 has double rate instructions. Select from those first.
selectFrom(getMfmaInsnGroupAttrMapGfx950());
if (succeeded(result)) {
if (result->isCoherentWithK(kPack, kPackPerBlock)) {
if (result->isCoherentWithK(kPack, kPackPerBlock, scheduleVersion)) {
LLVM_DEBUG(llvm::dbgs() << "Selected gfx950 double rate instruction\n");
return;
}
Expand Down Expand Up @@ -688,6 +710,7 @@ SmallVector<mlir::rock::MFMAParams, 2> MfmaInsnGroup::getImms() {
return groupAttr.imms;
}

bool MfmaInsnGroup::isCoherentWithK(int64_t kpack, int64_t kPerBlock) {
return insn.isCoherentWithK(kpack, kPerBlock);
bool MfmaInsnGroup::isCoherentWithK(int64_t kpack, int64_t kPerBlock,
int64_t scheduleVersion) {
return insn.isCoherentWithK(kpack, kPerBlock, scheduleVersion);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2476,13 +2476,6 @@ struct GridwiseAttentionAccelRewritePattern
gemm0TuningParams.getNPerWave(), gemm0kpack,
/*doubleBuffering=*/false, /*bLoadsFromLDS=*/qLoadsFromLDS);

// Disable LDS transpose for large head dimensions (HeadDimQK >= 512)
// Note: gemm0N = qShape[2] = head_dim_qk
if (gemm0N >= 512) {
ldsDecisionGemm0.enableA = false;
ldsDecisionGemm0.enableB = false;
}

// create matrix params
BlockwiseMatrixParamsAttr matrixParamsK = BlockwiseMatrixParamsAttr::get(
rewriter.getContext(), elemTypeK, elemTypeKLoad,
Expand Down
19 changes: 10 additions & 9 deletions mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,15 +503,16 @@ PopulateParamsXDL::isValidBlockwiseGemm(RockAccelTuningParamAttrInterface param,
if (auto derivedParam = dyn_cast<AccelGemmParamsAttr>(param)) {
mnPerXdl = derivedParam.getMnPerXdl();
}
auto maybeMfmaInsnGroup =
MfmaInsnGroup::select(dataTypeA, dataTypeB, arch, mnPerXdl,
param.getKpack(), param.getKpackPerBlock());
auto maybeMfmaInsnGroup = MfmaInsnGroup::select(
dataTypeA, dataTypeB, arch, mnPerXdl, param.getKpack(),
param.getKpackPerBlock(), param.getScheduleVersion());
if (failed(maybeMfmaInsnGroup)) {
LLVM_DEBUG(llvm::dbgs() << "Failed to select xdlops instruction group.\n");
return failure();
}
MfmaInsnGroup mfmaGroup = *maybeMfmaInsnGroup;
if (!mfmaGroup.isCoherentWithK(param.getKpack(), param.getKpackPerBlock())) {
if (!mfmaGroup.isCoherentWithK(param.getKpack(), param.getKpackPerBlock(),
param.getScheduleVersion())) {
LLVM_DEBUG(
llvm::dbgs()
<< "Mfma instruction group selection is not compatible with k.\n");
Expand All @@ -537,15 +538,15 @@ PopulateParamsXDL::getTuningParameters(OpBuilder &b, KernelType opType,
continue;

int64_t mnPerXdl = params.getMnPerXdl();
auto maybeMfmaInsnGroup =
MfmaInsnGroup::select(dataTypeA, dataTypeB, arch, mnPerXdl,
params.getKpack(), params.getKpackPerBlock());
auto maybeMfmaInsnGroup = MfmaInsnGroup::select(
dataTypeA, dataTypeB, arch, mnPerXdl, params.getKpack(),
params.getKpackPerBlock(), params.getScheduleVersion());
if (failed(maybeMfmaInsnGroup)) {
continue;
}
MfmaInsnGroup mfmaGroup = *maybeMfmaInsnGroup;
if (mfmaGroup.isCoherentWithK(params.getKpack(),
params.getKpackPerBlock())) {
if (mfmaGroup.isCoherentWithK(params.getKpack(), params.getKpackPerBlock(),
params.getScheduleVersion())) {
res.push_back(params);
}
}
Expand Down
97 changes: 65 additions & 32 deletions mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -546,9 +546,8 @@ Value MfmaEmitter::wrapLDSBufferForLoad(
TopDownTMBuilder toLDSRowCol(b, {}, {}, loc);

// Use LDS transpose compatible K formula when this operand uses LDS
// transpose load (and kVec >= kBase to ensure proper K distribution)
if (useLdsTransposeLoad && kVec >= kBase) {

// transpose load. Handles both kVec >= kBase and kVec < kBase cases.
if (useLdsTransposeLoad) {
// K access pattern must match the transpose load's pattern.
// For double-rate MFMA, properly distribute K across threads
int64_t instrK = mfmaAttr.k;
Expand All @@ -568,32 +567,63 @@ Value MfmaEmitter::wrapLDSBufferForLoad(
TransformMapAttr splitBlkIdAttr = splitBlkId.get();
transformAttrs.push_back(splitBlkIdAttr);

// Split k_vec into k_mfma and k_base for kpack > kBase
int64_t numMfmaPerKVec = kVec / kBase;

TopDownTMBuilder splitKVec =
TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr);
splitKVec.passThrough({"wave_m", "wave_n"}, {0, 1}, {"wave_m", "wave_n"});
splitKVec.passThrough({"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"},
{2, 3, 4, 5, 6},
{"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"});
splitKVec.merge({"k_mfma", "k_base"}, {7, 8}, "k_vec",
{numMfmaPerKVec, kBase});
TransformMapAttr splitKVecAttr = splitKVec.get();
transformAttrs.push_back(splitKVecAttr);

toLDSRowCol = TopDownTMBuilder::below(splitKVec, splitKVecAttr);

// d = d_iter * dWaves * numBlksInD * inputSpanLen + wave_d * numBlksInD *
// inputSpanLen + blk_d * inputSpanLen + blk_td
toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_d", "blk_td"},
{dRepeats, dWaves, numBlksInD, inputSpanLen});

// k = k_iter * (numMfmaPerKVec * instrK) + k_mfma * instrK + blk_k *
// kBase + k_base
toLDSRowCol.unmerge("k", 1, {"k_iter", "k_mfma", "blk_k", "k_base"},
{kIter, numMfmaPerKVec, numBlksInK, kBase});

if (kVec >= kBase) {
// Case 1: kVec >= kBase - split k_vec into k_mfma and k_base
int64_t numMfmaPerKVec = kVec / kBase;

TopDownTMBuilder splitKVec =
TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr);
splitKVec.passThrough({"wave_m", "wave_n"}, {0, 1},
{"wave_m", "wave_n"});
splitKVec.passThrough({"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"},
{2, 3, 4, 5, 6},
{"blk_d", "blk_k", "blk_td", "d_iter", "k_iter"});
splitKVec.merge({"k_mfma", "k_base"}, {7, 8}, "k_vec",
{numMfmaPerKVec, kBase});
TransformMapAttr splitKVecAttr = splitKVec.get();
transformAttrs.push_back(splitKVecAttr);

toLDSRowCol = TopDownTMBuilder::below(splitKVec, splitKVecAttr);

// d = d_iter * dWaves * numBlksInD * inputSpanLen + wave_d * numBlksInD
// * inputSpanLen + blk_d * inputSpanLen + blk_td
toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_d", "blk_td"},
{dRepeats, dWaves, numBlksInD, inputSpanLen});

// k = k_iter * (numMfmaPerKVec * instrK) + k_mfma * instrK + blk_k *
// kBase + k_base
toLDSRowCol.unmerge("k", 1, {"k_iter", "k_mfma", "blk_k", "k_base"},
{kIter, numMfmaPerKVec, numBlksInK, kBase});
} else {
// Case 2: kVec < kBase - split k_iter to accumulate multiple kVec
// loads into one kBase worth of data (e.g., kVec=4, kBase=8)
int64_t numKVecPerMfma = kBase / kVec;
int64_t kOuter = kIter / numKVecPerMfma;

TopDownTMBuilder splitKIter =
TopDownTMBuilder::below(splitBlkId, splitBlkIdAttr);
splitKIter.passThrough({"wave_m", "wave_n"}, {0, 1},
{"wave_m", "wave_n"});
splitKIter.passThrough({"blk_d", "blk_k", "blk_td", "d_iter"},
{2, 3, 4, 5},
{"blk_d", "blk_k", "blk_td", "d_iter"});
splitKIter.merge({"k_outer", "k_inner"}, {6, 7}, "k_iter",
{kOuter, numKVecPerMfma});
splitKIter.passThrough({"k_vec"}, {8}, {"k_vec"});
Comment on lines +598 to +612
Copy link

Copilot AI Feb 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the kVec < kBase branch, kOuter is computed as kIter / numKVecPerMfma and then used as a dimension size in merge({"k_outer","k_inner"}, ..., {kOuter, numKVecPerMfma}). If kIter is not an exact multiple of numKVecPerMfma, this truncates and makes the merged size inconsistent with the original k_iter extent. Please add a check/assert that kIter % numKVecPerMfma == 0 (or adjust the transform construction to handle the remainder safely).

Copilot uses AI. Check for mistakes.
TransformMapAttr splitKIterAttr = splitKIter.get();
transformAttrs.push_back(splitKIterAttr);

toLDSRowCol = TopDownTMBuilder::below(splitKIter, splitKIterAttr);

// d formula same as kVec >= kBase case
toLDSRowCol.unmerge("d", 0, {"d_iter", thisWaveDim, "blk_d", "blk_td"},
{dRepeats, dWaves, numBlksInD, inputSpanLen});

// k = k_outer * instrK + blk_k * kBase + k_inner * kVec + k_vec
// This accumulates numKVecPerMfma loads of kVec elements into kBase
toLDSRowCol.unmerge("k", 1, {"k_outer", "blk_k", "k_inner", "k_vec"},
{kOuter, numBlksInK, numKVecPerMfma, kVec});
}
} else {
// Standard formula for regular load scenarios
toLDSRowCol = TopDownTMBuilder::below(splitWaveId, splitWaveIdAttr);
Expand Down Expand Up @@ -760,9 +790,11 @@ MfmaEmitter::createAccelGemmOperandTransforms(
TransformMapAttr splitWaveIdAttr = splitWaveId.get();
transformAttrs.push_back(splitWaveIdAttr);
// Fourth coordinate transform
// Check if we need LDS transpose compatible K formula
// Check if we need LDS transpose compatible K formula.
// When prefetch is used: kPack >= kBase allows LDS transpose load,
// kPack < kBase disables it (falls back to regular load).
Comment on lines +793 to +795
Copy link

Copilot AI Feb 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment about prefetch behavior (“kPack >= kBase allows LDS transpose load, kPack < kBase disables it”) no longer matches the logic here now that the kPack >= kBase condition was removed from useLdsTransposeCompatibleK. Please update the comment to reflect the actual gating conditions (or reintroduce the check if that is still the intended behavior).

Suggested change
// Check if we need LDS transpose compatible K formula.
// When prefetch is used: kPack >= kBase allows LDS transpose load,
// kPack < kBase disables it (falls back to regular load).
// Check if we need an LDS transpose-compatible K formula.
// LDS transpose-compatible loads are used only when the other operand
// uses LDS transpose and this is a K-reduction; otherwise we fall back
// to the regular (non-transpose-compatible) load path.

Copilot uses AI. Check for mistakes.
bool useLdsTransposeCompatibleK =
otherOperandUsesLdsTranspose && isKReduction && (kPack >= kBase);
otherOperandUsesLdsTranspose && isKReduction;
int64_t numBlksInK = instrK / kBase;
int64_t numBlksInD = (waveSize / inputSpanLen) / numBlksInK;

Expand Down Expand Up @@ -1429,7 +1461,8 @@ AccelEmitter::select(GemmFeatures features, Type dataTypeA, Type dataTypeB,
if (isMfma) {
auto maybeMfmaInsnGroup = MfmaInsnGroup::select(
dataTypeA, dataTypeB, arch, tuningParams.getMnPerXdl(),
tuningParams.getKpack(), tuningParams.getKpackPerBlock());
tuningParams.getKpack(), tuningParams.getKpackPerBlock(),
tuningParams.getScheduleVersion());
if (failed(maybeMfmaInsnGroup)) {
return nullptr;
}
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,18 @@ LDSTransposeDecision decideLDSTransposeForOperands(
<< (decB.usable ? "USABLE" : "NOT USABLE") << "\n");
}

// If B is prefetched (doesn't load from LDS) and kpack < kBase,
// disable A's LDS transpose. The prefetched B operand needs a compatible
// K formula that requires kpack >= kBase (floordiv/mod can't be expressed
// with linear transforms).
int64_t kBase = mfmaEmitter->getParams().kBase;
if (!bLoadsFromLDS && kpack < kBase) {
decA.usable = false;
LLVM_DEBUG(llvm::dbgs()
<< "[lds_transpose] Disabling A: prefetch with "
<< "kPack(" << kpack << ") < kBase(" << kBase << ")\n");
}

// Enable LDS transpose load for each operand that supports it.
// The K access pattern formula in AccelEmitter.cpp (useLdsTransposeLoad)
// ensures compatibility when mixing regular load with transpose load.
Expand Down
25 changes: 9 additions & 16 deletions mlir/test/e2e/LdsTransposeLoadAttention.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 64 -head_dim_v 128 --transK
config = "-seq_len_q 256 -seq_len_k 256 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=true -perf_config attn:v2:64,64,16,16,16,16,8,1,4,2,1"

[[suite.test]]
config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=true -perf_config attn:v2:16,16,16,16,16,16,4,1,3,2,1"
config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=true -perf_config attn:v2:32,32,16,16,16,16,8,1,3,2,1"

[[suite.test]]
config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=true -perf_config attn:v2:16,16,32,16,16,16,4,1,4,2,1"
config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=true -perf_config attn:v2:32,32,16,32,32,16,8,1,3,2,1"

# ============================================================================
# Suite 3: transK=false, transQ=false - Only K uses LDS transpose
Expand Down Expand Up @@ -83,13 +83,13 @@ config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --trans
config = "-seq_len_q 256 -seq_len_k 256 -head_dim_qk 256 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:64,64,32,32,32,32,8,1,4,2,1"

[[suite.test]]
config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v2:32,32,16,16,16,16,4,1,3,2,1"
config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v2:32,32,16,16,16,16,8,1,3,2,1"

[[suite.test]]
config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:16,16,32,16,16,16,4,1,4,2,1"
config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:32,32,16,32,32,16,8,1,3,2,1"

[[suite.test]]
config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:32,32,32,32,32,32,4,1,3,2,1"
config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v2:32,32,32,32,32,32,8,1,3,2,1"

# ============================================================================
# Suite 5: transK=true, transQ=true - Only Q can use LDS transpose
Expand All @@ -108,19 +108,12 @@ config = "-seq_len_q 256 -seq_len_k 256 -head_dim_qk 256 -head_dim_v 256 --trans
config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 256 -head_dim_v 64 --transK=true --transQ=true -perf_config attn:v2:32,32,16,16,16,16,8,1,3,2,1"

# ============================================================================
# Suite 6: GEMM1 LDS Transpose - V uses LDS transpose, P is prefetched
# Suite 6: GEMM1 fallback - kpack=4 < kBase=8, LDS transpose disabled for V
# Tests that the fallback to regular path works correctly when P is prefetched
# but kpack < kBase prevents LDS transpose for V operand.
# ============================================================================
[[suite]]
name = "lds_transpose_gemm1_v_with_p_prefetch"
name = "lds_transpose_gemm1_fallback"

# 32x8 MFMA, medium dimensions
[[suite.test]]
config = "-seq_len_q 128 -seq_len_k 128 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v3:32,32,32,32,32,32,32,4,1,3,2,0,1"

# 32x8 MFMA, larger dimensions
[[suite.test]]
config = "-seq_len_q 256 -seq_len_k 256 -head_dim_qk 128 -head_dim_v 128 --transK=false --transQ=false -perf_config attn:v3:32,32,32,32,32,32,32,4,1,3,2,0,1"

# 16x16 MFMA, smaller dimensions
[[suite.test]]
config = "-seq_len_q 64 -seq_len_k 64 -head_dim_qk 64 -head_dim_v 64 --transK=false --transQ=false -perf_config attn:v3:16,16,16,16,16,16,16,4,1,3,2,0,1"
2 changes: 1 addition & 1 deletion mlir/test/e2e/PrLdsTransposeLoad.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ config = "-g 1 -m 16 -k 64 -n 16 --transA=true --transB=true --perf_config v3:16
config = "-g 1 -m 32 -k 128 -n 32 --transA=true --transB=true --perf_config v3:64,32,16,32,32,1,1,4,2,1,1"

[[suite.test]]
config = "-g 1 -m 32 -k 256 -n 32 --transA=true --transB=true --perf_config v3:32,64,16,32,32,1,8,3,2,1,1"
config = "-g 1 -m 32 -k 256 -n 32 --transA=true --transB=true --perf_config v3:32,64,16,32,32,1,1,3,2,1,1"

[[suite.test]]
config = "-g 256 -m 32 -k 32 -n 32 --transA=true --transB=true --perf_config v3:32,32,32,32,16,1,1,4,2,1,1"
Expand Down
Loading