Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
// V_LSHLREV_B32(N, V_ADD_U32(base, K)) -> V_LSHLREV_B32(N, base) +
// offset:K<<N
// 4. Multi-level combinations of the above
// 5. Constant splitting for oversized offsets:
// V_ADD_U32(base, K) where K > maxOffset
// -> V_ADD_U32(base, K_hi) offset:K_lo
// where K = K_hi + K_lo, K_lo = K % (maxOffset + 1).
// The downstream ScopedCSE pass then merges v_add ops that now share
// the same K_hi: with the buffer maxOffset of 4095 (K_lo = K % 4096),
// e.g. v_add(base,67584) and v_add(base,68608) both become
// v_add(base,65536), with offsets 2048 and 3072 respectively.
//
// After folding, dead instructions are removed by a DCE sweep.
//===----------------------------------------------------------------------===//
Expand All @@ -31,6 +38,7 @@
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "waveasm-memory-offset-opt"

Expand Down Expand Up @@ -177,9 +185,8 @@ static AddrAnalysis extractConstant(Value addr, OpBuilder &builder,
return {addr, 0};
}
if (check == OrOverlapCheck::Unknown) {
LLVM_DEBUG(llvm::dbgs()
<< "MemoryOffsetOpt: skipping V_OR_B32 - cannot prove "
<< "non-overlapping bits for constant " << *c << "\n");
LDBG() << "skipping V_OR_B32 - cannot prove "
<< "non-overlapping bits for constant " << *c;
return {addr, 0};
}
}
Expand All @@ -203,9 +210,8 @@ static AddrAnalysis extractConstant(Value addr, OpBuilder &builder,
return {addr, 0};
}
if (check == OrOverlapCheck::Unknown) {
LLVM_DEBUG(llvm::dbgs()
<< "MemoryOffsetOpt: skipping V_OR_B32 - cannot prove "
<< "non-overlapping bits for constant " << *c << "\n");
LDBG() << "skipping V_OR_B32 - cannot prove "
<< "non-overlapping bits for constant " << *c;
return {addr, 0};
}
}
Expand All @@ -231,25 +237,25 @@ static AddrAnalysis extractConstant(Value addr, OpBuilder &builder,
if (inner.constOffset == 0)
return std::nullopt;

// Check shift overflow before creating any new ops
// Check shift overflow before creating any new ops.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I usually avoid making irrelevant stylistic changes in a load-bearing commit. These show up in git history with unrelated commit message and you will be the one to blame for this code. Putting them in a separate NFC is a better idea.

auto shiftedConst = safeShiftLeft(inner.constOffset, *shiftAmt);
if (!shiftedConst)
return std::nullopt;

// Create new shift of the stripped base
// Create new shift of the stripped base.
auto newShift =
V_LSHLREV_B32::create(builder, loc, shiftOp.getResult().getType(),
shiftOp.getSrc0(), inner.base);

// Recurse on the other operand too
// Recurse on the other operand too.
auto otherAnalysis = extractConstant(otherVal, builder, loc);

// Check addition overflow
// Check addition overflow.
auto totalConst = safeAdd(*shiftedConst, otherAnalysis.constOffset);
if (!totalConst)
return std::nullopt;

// Create new add with the stripped shift and other operand
// Create new add with the stripped shift and other operand.
Value newBase =
V_ADD_U32::create(builder, loc, addLike->resultType,
newShift.getResult(), otherAnalysis.base);
Expand Down Expand Up @@ -282,7 +288,7 @@ static AddrAnalysis extractConstant(Value addr, OpBuilder &builder,
auto srcAnalysis = extractConstant(src, builder, loc);
if (srcAnalysis.constOffset == 0)
return std::nullopt;
// Also recurse on the other operand
// Also recurse on the other operand.
auto otherAnalysis = extractConstant(other, builder, loc);
auto totalConst =
safeAdd(srcAnalysis.constOffset, otherAnalysis.constOffset);
Expand All @@ -307,10 +313,10 @@ static AddrAnalysis extractConstant(Value addr, OpBuilder &builder,
if (shiftAmt && *shiftAmt >= 0 && *shiftAmt < 32) {
auto inner = extractConstant(shiftOp.getSrc1(), builder, loc);
if (inner.constOffset != 0) {
// Check shift overflow before creating any new ops
// Check shift overflow before creating any new ops.
auto shiftedConst = safeShiftLeft(inner.constOffset, *shiftAmt);
if (shiftedConst) {
// Create new shift of the stripped base
// Create new shift of the stripped base.
auto newShift =
V_LSHLREV_B32::create(builder, loc, shiftOp.getResult().getType(),
shiftOp.getSrc0(), inner.base);
Expand Down Expand Up @@ -426,7 +432,7 @@ struct MemoryOffsetOptPass
module.walk([&](ProgramOp program) {
OpBuilder builder(program.getBody().front().getParentOp());

// Collect memory ops to process (avoid modifying while iterating)
// Collect memory ops to process (avoid modifying while iterating).
SmallVector<Operation *, 32> memOps;
program.walk([&](Operation *op) {
if (getMemOpKind(op) != MemOpKind::Unknown)
Expand All @@ -444,11 +450,11 @@ struct MemoryOffsetOptPass
int64_t maxOffset = getMaxOffset(kind);

// Set insertion point right before the memory op for any new
// instructions
// instructions.
builder.setInsertionPoint(op);
Location loc = op->getLoc();

// Extract constants from the address tree
// Extract constants from the address tree.
AddrAnalysis analysis = extractConstant(addr, builder, loc);

if (analysis.constOffset == 0)
Expand All @@ -460,48 +466,54 @@ struct MemoryOffsetOptPass
continue;

if (newOffset <= maxOffset) {
// Constant fits in hardware offset field: fold into offset:N
// Constant fits entirely in the hardware offset field.
op->setOperand(addrIdx, analysis.base);
setOffset(op, newOffset, kind);
totalFolded++;
} else {
// Constant exceeds hardware offset limit. Still apply the shift
// distribution to simplify the address tree, but leave the constant
// as an explicit v_add_u32 with a literal. This enables CSE to
// deduplicate the shared base (analysis.base) across multiple
// memory ops that differ only by their constant offset.
// Constant exceeds the hardware offset limit. Split it:
// v_add_u32(base, K) where K > maxOffset
// → v_add_u32(base, K_hi) offset:(K_lo + existingOffset)
// where K = K_hi + K_lo, K_lo = K % (maxOffset + 1).
//
// Check if constant fits in 32-bit integer (V_ADD_U32 limitation)
if (analysis.constOffset > std::numeric_limits<int32_t>::max() ||
analysis.constOffset < std::numeric_limits<int32_t>::min()) {
// The downstream ScopedCSE pass then merges v_add_u32 ops that
// now share the same K_hi and base (e.g., with buffer maxOffset
// 4095, loads that originally had constants 67584 and 68608 both
// become v_add_u32(base, 65536) with offsets 2048 and 3072).
int64_t K = analysis.constOffset;
int64_t K_lo = K % (maxOffset + 1);
int64_t K_hi = K - K_lo;
int64_t splitOffset = existingOffset + K_lo;

// Sanity: K_lo is in [0, maxOffset], so splitOffset is bounded.
assert(K_lo >= 0 && K_lo <= maxOffset && "bad constant split");

if (splitOffset < 0 || splitOffset > maxOffset)
continue;
// Check if K_hi fits in 32-bit integer (V_ADD_U32 limitation).
if (K_hi > std::numeric_limits<int32_t>::max() ||
K_hi < std::numeric_limits<int32_t>::min())
continue;
}

// Before: (base + K) << N + col [3 ops, K<<N > maxOffset]
// After: (base << N + col) + K<<N [1 op + shared base via CSE]
auto constImm = builder.getType<ImmType>(analysis.constOffset);
auto constOp =
ConstantOp::create(builder, loc, constImm, analysis.constOffset);
auto constImm = builder.getType<ImmType>(K_hi);
auto constOp = ConstantOp::create(builder, loc, constImm, K_hi);
auto vregType = builder.getType<VRegType>(1, 1);
// NOTE: constant must be src0 (first operand) for VOP2 encoding.
// src1 must be a VGPR on AMDGCN.
// NOTE: constant must be src0 for VOP2 encoding.
auto newAddr =
V_ADD_U32::create(builder, loc, vregType, constOp, analysis.base);
op->setOperand(addrIdx, newAddr.getResult());
setOffset(op, splitOffset, kind);
totalFolded++;
}
}

// Remove dead instructions created by the folding
// totalDead += removeDeadOps(program);
// NOTE: Dead code elimination is delegated to the standard Canonicalizer
// or CSE passes that should run after this pass.
});

LLVM_DEBUG(if (totalFolded > 0) {
llvm::dbgs() << "MemoryOffsetOpt: folded " << totalFolded
<< " constant address components into offset fields\n";
});
if (totalFolded > 0)
LDBG() << "folded " << totalFolded
<< " constant address components into offset fields";
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ waveasm.program @test_shift_distribution target = #waveasm.target<#waveasm.gfx94
}

//===----------------------------------------------------------------------===//
// Test: Buffer store offset exceeds limit -> NO fold
// Test: Buffer store offset exceeds limit -> split into K_hi + K_lo
// 8192 = 8192 + 0 (K_lo = 8192 % 4096 = 0), so no useful split.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: waveasm.program @test_offset_limit
Expand All @@ -106,15 +107,80 @@ waveasm.program @test_offset_limit target = #waveasm.target<#waveasm.gfx942, 5>
%c8192 = waveasm.constant 8192 : !waveasm.imm<8192>
%addr = waveasm.v_add_u32 %base, %c8192 : !waveasm.pvreg<1>, !waveasm.imm<8192> -> !waveasm.vreg

// Should NOT be folded because 8192 > 4095 (buffer max offset)
// 8192 % 4096 = 0, so K_hi = 8192 and K_lo = 0. The v_add stays.
// CHECK: waveasm.constant 8192
// CHECK: waveasm.v_add_u32
// CHECK: waveasm.buffer_store_dword
// CHECK-NOT: instOffset = 8192
waveasm.buffer_store_dword %data, %srd, %addr : !waveasm.pvreg<0>, !waveasm.psreg<0, 4>, !waveasm.vreg

waveasm.s_endpgm
}

//===----------------------------------------------------------------------===//
// Test: Constant splitting for oversized buffer offset
// v_add_u32(base, 5120) where 5120 > 4095
// K_lo = 5120 % 4096 = 1024, K_hi = 4096
// -> v_add_u32(base, 4096) offset:1024
//===----------------------------------------------------------------------===//

// CHECK-LABEL: waveasm.program @test_constant_split_buffer
waveasm.program @test_constant_split_buffer target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
%srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
%base = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
%soff = waveasm.precolored.sreg 4 : !waveasm.psreg<4>

// Address = base + 5120
%c5120 = waveasm.constant 5120 : !waveasm.imm<5120>
%addr = waveasm.v_add_u32 %base, %c5120 : !waveasm.pvreg<0>, !waveasm.imm<5120> -> !waveasm.vreg

// 5120 = 4096 + 1024. The v_add gets K_hi=4096, offset gets K_lo=1024.
// (Buffer max offset is 4095, so K_lo = 5120 % 4096 = 1024.)
// CHECK: waveasm.constant 4096
// CHECK: waveasm.v_add_u32 %{{.*}}, %{{.*}} : !waveasm.imm<4096>
// CHECK: waveasm.buffer_load_dwordx4
// CHECK-SAME: offset : 1024
%result = waveasm.buffer_load_dwordx4 %srd, %addr, %soff : !waveasm.psreg<0, 4>, !waveasm.vreg, !waveasm.psreg<4> -> !waveasm.vreg<4, 4>

waveasm.s_endpgm
}

//===----------------------------------------------------------------------===//
// Test: Two buffer loads differing by a small constant get the same K_hi
// after splitting, enabling CSE to merge them downstream.
// v_add_u32(base, 67584) -> v_add_u32(base, 65536) offset:2048
// v_add_u32(base, 68608) -> v_add_u32(base, 65536) offset:3072
// Both emit the same K_hi=65536, so ScopedCSE can share the v_add.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: waveasm.program @test_constant_split_shared_khi
waveasm.program @test_constant_split_shared_khi target = #waveasm.target<#waveasm.gfx942, 5> abi = #waveasm.abi<> {
%srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4>
%base = waveasm.precolored.vreg 0 : !waveasm.pvreg<0>
%soff = waveasm.precolored.sreg 4 : !waveasm.psreg<4>

%c67584 = waveasm.constant 67584 : !waveasm.imm<67584>
%addr1 = waveasm.v_add_u32 %base, %c67584 : !waveasm.pvreg<0>, !waveasm.imm<67584> -> !waveasm.vreg

%c68608 = waveasm.constant 68608 : !waveasm.imm<68608>
%addr2 = waveasm.v_add_u32 %base, %c68608 : !waveasm.pvreg<0>, !waveasm.imm<68608> -> !waveasm.vreg

// 67584 = 65536 + 2048 -> v_add(base, 65536) offset:2048
// 68608 = 65536 + 3072 -> v_add(base, 65536) offset:3072
// Both produce the same K_hi = 65536.
// NOTE: this test only checks the split; merging the two now-identical
// v_add_u32(base, 65536) ops is left to a later CSE pass.
// CHECK: waveasm.constant 65536
// CHECK: waveasm.v_add_u32 %{{.*}}, %{{.*}} : !waveasm.imm<65536>
// CHECK: waveasm.buffer_load_dwordx4
// CHECK-SAME: offset : 2048
%r1 = waveasm.buffer_load_dwordx4 %srd, %addr1, %soff : !waveasm.psreg<0, 4>, !waveasm.vreg, !waveasm.psreg<4> -> !waveasm.vreg<4, 4>

// CHECK: waveasm.constant 65536
// CHECK: waveasm.v_add_u32 %{{.*}}, %{{.*}} : !waveasm.imm<65536>
// CHECK: waveasm.buffer_load_dwordx4
// CHECK-SAME: offset : 3072
%r2 = waveasm.buffer_load_dwordx4 %srd, %addr2, %soff : !waveasm.psreg<0, 4>, !waveasm.vreg, !waveasm.psreg<4> -> !waveasm.vreg<4, 4>

waveasm.s_endpgm
}

//===----------------------------------------------------------------------===//
// Test: DS read with large offset IS allowed (DS max = 65535)
//===----------------------------------------------------------------------===//
Expand Down
Loading