Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -599,12 +599,16 @@ std::optional<std::string> KernelGenerator::generateOp(Operation *op) {
};

for (unsigned i = 0; i < numIter; ++i) {
auto [srcPhys, isSGPR] =
getPhysRegInfo(condOp.getIterArgs()[i]);
auto [dstPhys, dstIsSGPR] = getPhysRegInfo(body.getArgument(i));
Value iterArg = condOp.getIterArgs()[i];
Value blockArg = body.getArgument(i);
auto [srcPhys, isSGPR] = getPhysRegInfo(iterArg);
auto [dstPhys, dstIsSGPR] = getPhysRegInfo(blockArg);

if (srcPhys >= 0 && dstPhys >= 0 && srcPhys != dstPhys) {
pendingCopies.push_back({dstPhys, srcPhys, isSGPR});
// Decompose multi-register values into individual copies.
int64_t size = getRegSize(iterArg.getType());
for (int64_t j = 0; j < size; ++j)
pendingCopies.push_back({dstPhys + j, srcPhys + j, isSGPR});
}
}

Expand Down
36 changes: 33 additions & 3 deletions wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/Liveness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,34 @@ LivenessInfo computeLiveness(ProgramOp program) {
// tiedClassId field on each range identifies its class membership.
auto &tc = info.tiedClasses;

// Check whether tying an async memory load's iter_arg to its block arg
// is unsafe. An async load (MemoryOp trait with results) writes its
// destination register asynchronously. If the corresponding block arg
// is still read after the load issues, sharing a register would let the
// load clobber a value MFMAs are still consuming.
auto isUnsafeAsyncTie = [&](Value iterArg, BlockArgument blockArg) -> bool {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than have this as a heuristic in the pass, could we add this as a verification hook on ConditionOp or LoopOp that flags when an async memory result is passed as an iter_arg with an overlapping block arg use. This would catch the hazard at IR validation time rather than relying on the liveness pass to silently handle it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what IR validation will give us besides random compilation failures for user. Also, checking non-local properties (use-def chains) in verifier is a bad practice.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea with the verifier is that we always validate this as an invariant in the IR, rather than only evaluate it in this pass. Okay then maybe we should model it as a normal form but we can do that in a separate PR.

Operation *defOp = iterArg.getDefiningOp();
if (!defOp)
return false;
if (!defOp->hasTrait<OpTrait::MemoryOp>() || defOp->getNumResults() == 0)
return false;

auto defIt = opToIdx.find(defOp);
if (defIt == opToIdx.end())
return false;
int64_t defIdx = defIt->second;

// Find the last use of blockArg in the program-linear order.
int64_t lastUseIdx = -1;
for (OpOperand &use : blockArg.getUses()) {
auto userIt = opToIdx.find(use.getOwner());
if (userIt != opToIdx.end())
lastUseIdx = std::max(lastUseIdx, userIt->second);
}

return lastUseIdx > defIdx;
};

program.walk([&](LoopOp loopOp) {
Block &bodyBlock = loopOp.getBodyBlock();
auto condOp = dyn_cast<ConditionOp>(bodyBlock.getTerminator());
Expand Down Expand Up @@ -476,10 +504,11 @@ LivenessInfo computeLiveness(ProgramOp program) {
members.push_back(loopResult);
}

// Condition iter_arg -> block arg
// Condition iter_arg -> block arg (skip unsafe async memory loads).
if (i < condOp.getIterArgs().size()) {
Value iterArg = condOp.getIterArgs()[i];
if (info.ranges.contains(iterArg))
if (info.ranges.contains(iterArg) &&
!isUnsafeAsyncTie(iterArg, blockArg))
members.push_back(iterArg);
}

Expand Down Expand Up @@ -529,7 +558,8 @@ LivenessInfo computeLiveness(ProgramOp program) {
}
if (i < condOp.getIterArgs().size()) {
Value iterArg = condOp.getIterArgs()[i];
if (info.ranges.contains(iterArg) && !tc.tiedPairs.contains(iterArg))
if (info.ranges.contains(iterArg) && !tc.tiedPairs.contains(iterArg) &&
!isUnsafeAsyncTie(iterArg, blockArg))
tc.tiedPairs[iterArg] = blockArg;
}
if (i < loopOp->getNumResults()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// RUN: waveasm-translate --waveasm-linear-scan --emit-assembly %s | FileCheck %s
//
// Assembly emission tests for async memory iter_arg tying.
//
// Verifies that the emitter produces correct back-edge copies when iter_args
// and block args are NOT tied (unsafe async memory ops), and omits copies
// when they ARE tied (safe ordering or synchronous ops).

//===----------------------------------------------------------------------===//
// Test 1: UNSAFE buffer_load — emitter must produce 4 × v_mov_b32 copies
// to move the untied iter_arg into the block arg's register at the back edge.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: vmem_emit_unsafe:
waveasm.program @vmem_emit_unsafe
target = #waveasm.target<#waveasm.gfx942, 5>
abi = #waveasm.abi<> {

%c0 = waveasm.constant 0 : !waveasm.imm<0>
%c1 = waveasm.constant 1 : !waveasm.imm<1>
%c10 = waveasm.constant 10 : !waveasm.imm<10>
%srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
%a = waveasm.precolored.vreg 0, 4 : !waveasm.pvreg<0, 4>
%b = waveasm.precolored.vreg 4, 4 : !waveasm.pvreg<4, 4>

%init_i = waveasm.s_mov_b32 %c0 : !waveasm.imm<0> -> !waveasm.sreg
%init_acc = waveasm.v_mov_b32 %c0 : !waveasm.imm<0> -> !waveasm.vreg<4, 4>
%init_data = waveasm.buffer_load_dwordx4 %srd, %c0, %c0 : !waveasm.psreg<0, 4>, !waveasm.imm<0>, !waveasm.imm<0> -> !waveasm.vreg<4, 4>

// CHECK: L_loop_0:
// CHECK: buffer_load_dwordx4
// CHECK: v_mfma_f32_16x16x16_f16
// Back-edge copies: 4 individual v_mov_b32 for the wide register.
// CHECK: v_mov_b32
// CHECK: v_mov_b32
// CHECK: v_mov_b32
// CHECK: v_mov_b32
// CHECK: s_cbranch_scc1 L_loop_0
%ri, %racc, %rdata = waveasm.loop(
%i = %init_i, %acc = %init_acc, %data = %init_data)
: (!waveasm.sreg, !waveasm.vreg<4, 4>, !waveasm.vreg<4, 4>)
-> (!waveasm.sreg, !waveasm.vreg<4, 4>, !waveasm.vreg<4, 4>) {

%data_next = waveasm.buffer_load_dwordx4 %srd, %c0, %c0
: !waveasm.psreg<0, 4>, !waveasm.imm<0>, !waveasm.imm<0> -> !waveasm.vreg<4, 4>
%acc_new = waveasm.v_mfma_f32_16x16x16_f16 %data, %b, %acc
: !waveasm.vreg<4, 4>, !waveasm.pvreg<4, 4>, !waveasm.vreg<4, 4> -> !waveasm.vreg<4, 4>

%next_i = waveasm.s_add_u32 %i, %c1 : !waveasm.sreg, !waveasm.imm<1> -> !waveasm.sreg
%cond = waveasm.s_cmp_lt_u32 %next_i, %c10 : !waveasm.sreg, !waveasm.imm<10> -> !waveasm.sreg
waveasm.condition %cond : !waveasm.sreg
iter_args(%next_i, %acc_new, %data_next)
: !waveasm.sreg, !waveasm.vreg<4, 4>, !waveasm.vreg<4, 4>
}

waveasm.s_endpgm
}

//===----------------------------------------------------------------------===//
// Test 2: SAFE buffer_load — block arg dead before load, no copies needed.
// The buffer_load writes directly into the block arg's register.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: vmem_emit_safe:
waveasm.program @vmem_emit_safe
target = #waveasm.target<#waveasm.gfx942, 5>
abi = #waveasm.abi<> {

%c0 = waveasm.constant 0 : !waveasm.imm<0>
%c1 = waveasm.constant 1 : !waveasm.imm<1>
%c10 = waveasm.constant 10 : !waveasm.imm<10>
%srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
%b = waveasm.precolored.vreg 4, 4 : !waveasm.pvreg<4, 4>

%init_i = waveasm.s_mov_b32 %c0 : !waveasm.imm<0> -> !waveasm.sreg
%init_acc = waveasm.v_mov_b32 %c0 : !waveasm.imm<0> -> !waveasm.vreg<4, 4>
%init_data = waveasm.buffer_load_dwordx4 %srd, %c0, %c0 : !waveasm.psreg<0, 4>, !waveasm.imm<0>, !waveasm.imm<0> -> !waveasm.vreg<4, 4>

// CHECK: L_loop_0:
// CHECK: v_mfma_f32_16x16x16_f16
// CHECK: buffer_load_dwordx4
// No v_mov copies — tied registers, load writes to block arg directly.
// CHECK-NOT: v_mov_b32
// CHECK: s_cbranch_scc1 L_loop_0
%ri, %racc, %rdata = waveasm.loop(
%i = %init_i, %acc = %init_acc, %data = %init_data)
: (!waveasm.sreg, !waveasm.vreg<4, 4>, !waveasm.vreg<4, 4>)
-> (!waveasm.sreg, !waveasm.vreg<4, 4>, !waveasm.vreg<4, 4>) {

%acc_new = waveasm.v_mfma_f32_16x16x16_f16 %data, %b, %acc
: !waveasm.vreg<4, 4>, !waveasm.pvreg<4, 4>, !waveasm.vreg<4, 4> -> !waveasm.vreg<4, 4>
%data_next = waveasm.buffer_load_dwordx4 %srd, %c0, %c0
: !waveasm.psreg<0, 4>, !waveasm.imm<0>, !waveasm.imm<0> -> !waveasm.vreg<4, 4>

%next_i = waveasm.s_add_u32 %i, %c1 : !waveasm.sreg, !waveasm.imm<1> -> !waveasm.sreg
%cond = waveasm.s_cmp_lt_u32 %next_i, %c10 : !waveasm.sreg, !waveasm.imm<10> -> !waveasm.sreg
waveasm.condition %cond : !waveasm.sreg
iter_args(%next_i, %acc_new, %data_next)
: !waveasm.sreg, !waveasm.vreg<4, 4>, !waveasm.vreg<4, 4>
}

waveasm.s_endpgm
}

//===----------------------------------------------------------------------===//
// Test 3: UNSAFE ds_read_b128 — same back-edge copy pattern via LDS.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: lds_emit_unsafe:
waveasm.program @lds_emit_unsafe
target = #waveasm.target<#waveasm.gfx942, 5>
abi = #waveasm.abi<> {

%c0 = waveasm.constant 0 : !waveasm.imm<0>
%c1 = waveasm.constant 1 : !waveasm.imm<1>
%c10 = waveasm.constant 10 : !waveasm.imm<10>
%b = waveasm.precolored.vreg 4, 4 : !waveasm.pvreg<4, 4>
%lds_addr = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>

%init_i = waveasm.s_mov_b32 %c0 : !waveasm.imm<0> -> !waveasm.sreg
%init_acc = waveasm.v_mov_b32 %c0 : !waveasm.imm<0> -> !waveasm.vreg<4, 4>
%init_data = waveasm.ds_read_b128 %lds_addr : !waveasm.pvreg<0> -> !waveasm.vreg<4, 4>

// CHECK: L_loop_0:
// CHECK: ds_read_b128
// CHECK: v_mfma_f32_16x16x16_f16
// Back-edge copies: 4 individual v_mov_b32 for the wide register.
// CHECK: v_mov_b32
// CHECK: v_mov_b32
// CHECK: v_mov_b32
// CHECK: v_mov_b32
// CHECK: s_cbranch_scc1 L_loop_0
%ri, %racc, %rdata = waveasm.loop(
%i = %init_i, %acc = %init_acc, %data = %init_data)
: (!waveasm.sreg, !waveasm.vreg<4, 4>, !waveasm.vreg<4, 4>)
-> (!waveasm.sreg, !waveasm.vreg<4, 4>, !waveasm.vreg<4, 4>) {

%data_next = waveasm.ds_read_b128 %lds_addr : !waveasm.pvreg<0> -> !waveasm.vreg<4, 4>
%acc_new = waveasm.v_mfma_f32_16x16x16_f16 %data, %b, %acc
: !waveasm.vreg<4, 4>, !waveasm.pvreg<4, 4>, !waveasm.vreg<4, 4> -> !waveasm.vreg<4, 4>

%next_i = waveasm.s_add_u32 %i, %c1 : !waveasm.sreg, !waveasm.imm<1> -> !waveasm.sreg
%cond = waveasm.s_cmp_lt_u32 %next_i, %c10 : !waveasm.sreg, !waveasm.imm<10> -> !waveasm.sreg
waveasm.condition %cond : !waveasm.sreg
iter_args(%next_i, %acc_new, %data_next)
: !waveasm.sreg, !waveasm.vreg<4, 4>, !waveasm.vreg<4, 4>
}

waveasm.s_endpgm
}
Loading