From 838a812e81343ccf661ae4028a478d6ba416ac91 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 24 Feb 2026 19:25:02 +0100 Subject: [PATCH 1/2] constants pass Signed-off-by: Ivan Butygin --- .../Transforms/MemoryOffsetOptimization.cpp | 74 +++++++++++-------- .../test/Transforms/memory-offset-opt.mlir | 72 +++++++++++++++++- 2 files changed, 113 insertions(+), 33 deletions(-) diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/MemoryOffsetOptimization.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/MemoryOffsetOptimization.cpp index 92d77247b..e35156630 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/MemoryOffsetOptimization.cpp +++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/MemoryOffsetOptimization.cpp @@ -16,6 +16,13 @@ // V_LSHLREV_B32(N, V_ADD_U32(base, K)) -> V_LSHLREV_B32(N, base) + // offset:K< maxOffset +// -> V_ADD_U32(base, K_hi) offset:K_lo +// where K = K_hi + K_lo, K_lo = K % (maxOffset + 1). +// The downstream ScopedCSE pass then merges v_add ops that now share +// the same K_hi, turning e.g. v_add(base,68608) and v_add(base,67584) +// into a single v_add(base,65536) with offsets 3072 and 2048. // // After folding, dead instructions are removed by a DCE sweep. //===----------------------------------------------------------------------===// @@ -231,25 +238,25 @@ static AddrAnalysis extractConstant(Value addr, OpBuilder &builder, if (inner.constOffset == 0) return std::nullopt; - // Check shift overflow before creating any new ops + // Check shift overflow before creating any new ops. auto shiftedConst = safeShiftLeft(inner.constOffset, *shiftAmt); if (!shiftedConst) return std::nullopt; - // Create new shift of the stripped base + // Create new shift of the stripped base. auto newShift = V_LSHLREV_B32::create(builder, loc, shiftOp.getResult().getType(), shiftOp.getSrc0(), inner.base); - // Recurse on the other operand too + // Recurse on the other operand too. 
auto otherAnalysis = extractConstant(otherVal, builder, loc); - // Check addition overflow + // Check addition overflow. auto totalConst = safeAdd(*shiftedConst, otherAnalysis.constOffset); if (!totalConst) return std::nullopt; - // Create new add with the stripped shift and other operand + // Create new add with the stripped shift and other operand. Value newBase = V_ADD_U32::create(builder, loc, addLike->resultType, newShift.getResult(), otherAnalysis.base); @@ -282,7 +289,7 @@ static AddrAnalysis extractConstant(Value addr, OpBuilder &builder, auto srcAnalysis = extractConstant(src, builder, loc); if (srcAnalysis.constOffset == 0) return std::nullopt; - // Also recurse on the other operand + // Also recurse on the other operand. auto otherAnalysis = extractConstant(other, builder, loc); auto totalConst = safeAdd(srcAnalysis.constOffset, otherAnalysis.constOffset); @@ -307,10 +314,10 @@ static AddrAnalysis extractConstant(Value addr, OpBuilder &builder, if (shiftAmt && *shiftAmt >= 0 && *shiftAmt < 32) { auto inner = extractConstant(shiftOp.getSrc1(), builder, loc); if (inner.constOffset != 0) { - // Check shift overflow before creating any new ops + // Check shift overflow before creating any new ops. auto shiftedConst = safeShiftLeft(inner.constOffset, *shiftAmt); if (shiftedConst) { - // Create new shift of the stripped base + // Create new shift of the stripped base. auto newShift = V_LSHLREV_B32::create(builder, loc, shiftOp.getResult().getType(), shiftOp.getSrc0(), inner.base); @@ -426,7 +433,7 @@ struct MemoryOffsetOptPass module.walk([&](ProgramOp program) { OpBuilder builder(program.getBody().front().getParentOp()); - // Collect memory ops to process (avoid modifying while iterating) + // Collect memory ops to process (avoid modifying while iterating). 
SmallVector memOps; program.walk([&](Operation *op) { if (getMemOpKind(op) != MemOpKind::Unknown) @@ -444,11 +451,11 @@ struct MemoryOffsetOptPass int64_t maxOffset = getMaxOffset(kind); // Set insertion point right before the memory op for any new - // instructions + // instructions. builder.setInsertionPoint(op); Location loc = op->getLoc(); - // Extract constants from the address tree + // Extract constants from the address tree. AddrAnalysis analysis = extractConstant(addr, builder, loc); if (analysis.constOffset == 0) @@ -460,40 +467,47 @@ struct MemoryOffsetOptPass continue; if (newOffset <= maxOffset) { - // Constant fits in hardware offset field: fold into offset:N + // Constant fits entirely in the hardware offset field. op->setOperand(addrIdx, analysis.base); setOffset(op, newOffset, kind); totalFolded++; } else { - // Constant exceeds hardware offset limit. Still apply the shift - // distribution to simplify the address tree, but leave the constant - // as an explicit v_add_u32 with a literal. This enables CSE to - // deduplicate the shared base (analysis.base) across multiple - // memory ops that differ only by their constant offset. + // Constant exceeds the hardware offset limit. Split it: + // v_add_u32(base, K) where K > maxOffset + // → v_add_u32(base, K_hi) offset:(K_lo + existingOffset) + // where K = K_hi + K_lo, K_lo = K % (maxOffset + 1). // - // Check if constant fits in 32-bit integer (V_ADD_U32 limitation) - if (analysis.constOffset > std::numeric_limits::max() || - analysis.constOffset < std::numeric_limits::min()) { + // The downstream ScopedCSE pass then merges v_add_u32 ops that + // now share the same K_hi and base (e.g., loads that originally + // had constants 67584 and 68608 both become v_add_u32(base, 65536) + // with offsets 2048 and 3072 respectively). 
+ int64_t K = analysis.constOffset; + int64_t K_lo = K % (maxOffset + 1); + int64_t K_hi = K - K_lo; + int64_t splitOffset = existingOffset + K_lo; + + // Sanity: K_lo is in [0, maxOffset], so splitOffset is bounded. + assert(K_lo >= 0 && K_lo <= maxOffset && "bad constant split"); + + if (splitOffset < 0 || splitOffset > maxOffset) + continue; + // Check if K_hi fits in 32-bit integer (V_ADD_U32 limitation). + if (K_hi > std::numeric_limits::max() || + K_hi < std::numeric_limits::min()) continue; - } - // Before: (base + K) << N + col [3 ops, K< maxOffset] - // After: (base << N + col) + K<(analysis.constOffset); - auto constOp = - ConstantOp::create(builder, loc, constImm, analysis.constOffset); + auto constImm = builder.getType(K_hi); + auto constOp = ConstantOp::create(builder, loc, constImm, K_hi); auto vregType = builder.getType(1, 1); - // NOTE: constant must be src0 (first operand) for VOP2 encoding. - // src1 must be a VGPR on AMDGCN. + // NOTE: constant must be src0 for VOP2 encoding. auto newAddr = V_ADD_U32::create(builder, loc, vregType, constOp, analysis.base); op->setOperand(addrIdx, newAddr.getResult()); + setOffset(op, splitOffset, kind); totalFolded++; } } - // Remove dead instructions created by the folding - // totalDead += removeDeadOps(program); // NOTE: Dead code elimination is delegated to the standard Canonicalizer // or CSE passes that should run after this pass. 
}); diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/memory-offset-opt.mlir b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/memory-offset-opt.mlir index f0ba07213..272323e96 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/memory-offset-opt.mlir +++ b/wave_lang/kernel/wave/asm/wave_asm/test/Transforms/memory-offset-opt.mlir @@ -93,7 +93,8 @@ waveasm.program @test_shift_distribution target = #waveasm.target<#waveasm.gfx94 } //===----------------------------------------------------------------------===// -// Test: Buffer store offset exceeds limit -> NO fold +// Test: Buffer store offset exceeds limit -> split into K_hi + K_lo +// 8192 = 8192 + 0 (K_lo = 8192 % 4096 = 0), so no useful split. //===----------------------------------------------------------------------===// // CHECK-LABEL: waveasm.program @test_offset_limit @@ -106,15 +107,80 @@ waveasm.program @test_offset_limit target = #waveasm.target<#waveasm.gfx942, 5> %c8192 = waveasm.constant 8192 : !waveasm.imm<8192> %addr = waveasm.v_add_u32 %base, %c8192 : !waveasm.pvreg<1>, !waveasm.imm<8192> -> !waveasm.vreg - // Should NOT be folded because 8192 > 4095 (buffer max offset) + // 8192 % 4096 = 0, so K_hi = 8192 and K_lo = 0. The v_add stays. 
+ // CHECK: waveasm.constant 8192 // CHECK: waveasm.v_add_u32 // CHECK: waveasm.buffer_store_dword - // CHECK-NOT: instOffset = 8192 waveasm.buffer_store_dword %data, %srd, %addr : !waveasm.pvreg<0>, !waveasm.psreg<0, 4>, !waveasm.vreg waveasm.s_endpgm } +//===----------------------------------------------------------------------===// +// Test: Constant splitting for oversized buffer offset +// v_add_u32(base, 5120) where 5120 > 4095 +// K_lo = 5120 % 4096 = 1024, K_hi = 4096 +// -> v_add_u32(base, 4096) offset:1024 +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: waveasm.program @test_constant_split_buffer +waveasm.program @test_constant_split_buffer target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4> + %base = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %soff = waveasm.precolored.sreg 4 : !waveasm.psreg<4> + + // Address = base + 5120 + %c5120 = waveasm.constant 5120 : !waveasm.imm<5120> + %addr = waveasm.v_add_u32 %base, %c5120 : !waveasm.pvreg<0>, !waveasm.imm<5120> -> !waveasm.vreg + + // 5120 = 4096 + 1024. The v_add gets K_hi=4096, offset gets K_lo=1024. + // CHECK: waveasm.constant 4096 + // CHECK: waveasm.v_add_u32 %{{.*}}, %{{.*}} : !waveasm.imm<4096> + // CHECK: waveasm.buffer_load_dwordx4 + // CHECK-SAME: offset : 1024 + %result = waveasm.buffer_load_dwordx4 %srd, %addr, %soff : !waveasm.psreg<0, 4>, !waveasm.vreg, !waveasm.psreg<4> -> !waveasm.vreg<4, 4> + + waveasm.s_endpgm +} + +//===----------------------------------------------------------------------===// +// Test: Two buffer loads differing by a small constant get the same K_hi +// after splitting, enabling CSE to merge them downstream. +// v_add_u32(base, 67584) -> v_add_u32(base, 65536) offset:2048 +// v_add_u32(base, 68608) -> v_add_u32(base, 65536) offset:3072 +// Both emit the same K_hi=65536, so ScopedCSE can share the v_add. 
+//===----------------------------------------------------------------------===// + +// CHECK-LABEL: waveasm.program @test_constant_split_shared_khi +waveasm.program @test_constant_split_shared_khi target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> { + %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4> + %base = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %soff = waveasm.precolored.sreg 4 : !waveasm.psreg<4> + + %c67584 = waveasm.constant 67584 : !waveasm.imm<67584> + %addr1 = waveasm.v_add_u32 %base, %c67584 : !waveasm.pvreg<0>, !waveasm.imm<67584> -> !waveasm.vreg + + %c68608 = waveasm.constant 68608 : !waveasm.imm<68608> + %addr2 = waveasm.v_add_u32 %base, %c68608 : !waveasm.pvreg<0>, !waveasm.imm<68608> -> !waveasm.vreg + + // 67584 = 65536 + 2048 -> v_add(base, 65536) offset:2048 + // 68608 = 65536 + 3072 -> v_add(base, 65536) offset:3072 + // Both produce the same K_hi = 65536. + // CHECK: waveasm.constant 65536 + // CHECK: waveasm.v_add_u32 %{{.*}}, %{{.*}} : !waveasm.imm<65536> + // CHECK: waveasm.buffer_load_dwordx4 + // CHECK-SAME: offset : 2048 + %r1 = waveasm.buffer_load_dwordx4 %srd, %addr1, %soff : !waveasm.psreg<0, 4>, !waveasm.vreg, !waveasm.psreg<4> -> !waveasm.vreg<4, 4> + + // CHECK: waveasm.constant 65536 + // CHECK: waveasm.v_add_u32 %{{.*}}, %{{.*}} : !waveasm.imm<65536> + // CHECK: waveasm.buffer_load_dwordx4 + // CHECK-SAME: offset : 3072 + %r2 = waveasm.buffer_load_dwordx4 %srd, %addr2, %soff : !waveasm.psreg<0, 4>, !waveasm.vreg, !waveasm.psreg<4> -> !waveasm.vreg<4, 4> + + waveasm.s_endpgm +} + //===----------------------------------------------------------------------===// // Test: DS read with large offset IS allowed (DS max = 65535) //===----------------------------------------------------------------------===// From e6b1da115974b55f766e45eeb10da815cdec1687 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 24 Feb 2026 19:30:01 +0100 Subject: [PATCH 2/2] use LDBG Signed-off-by: Ivan Butygin --- 
.../Transforms/MemoryOffsetOptimization.cpp | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/MemoryOffsetOptimization.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/MemoryOffsetOptimization.cpp index e35156630..18a7bb0b4 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/MemoryOffsetOptimization.cpp +++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/MemoryOffsetOptimization.cpp @@ -38,6 +38,7 @@ #include "mlir/Pass/Pass.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "waveasm-memory-offset-opt" @@ -184,9 +185,8 @@ static AddrAnalysis extractConstant(Value addr, OpBuilder &builder, return {addr, 0}; } if (check == OrOverlapCheck::Unknown) { - LLVM_DEBUG(llvm::dbgs() - << "MemoryOffsetOpt: skipping V_OR_B32 - cannot prove " - << "non-overlapping bits for constant " << *c << "\n"); + LDBG() << "skipping V_OR_B32 - cannot prove " + << "non-overlapping bits for constant " << *c; return {addr, 0}; } } @@ -210,9 +210,8 @@ static AddrAnalysis extractConstant(Value addr, OpBuilder &builder, return {addr, 0}; } if (check == OrOverlapCheck::Unknown) { - LLVM_DEBUG(llvm::dbgs() - << "MemoryOffsetOpt: skipping V_OR_B32 - cannot prove " - << "non-overlapping bits for constant " << *c << "\n"); + LDBG() << "skipping V_OR_B32 - cannot prove " + << "non-overlapping bits for constant " << *c; return {addr, 0}; } } @@ -512,10 +511,9 @@ struct MemoryOffsetOptPass // or CSE passes that should run after this pass. }); - LLVM_DEBUG(if (totalFolded > 0) { - llvm::dbgs() << "MemoryOffsetOpt: folded " << totalFolded - << " constant address components into offset fields\n"; - }); + if (totalFolded > 0) + LDBG() << "folded " << totalFolded + << " constant address components into offset fields"; } };