Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 158 additions & 34 deletions waveasm/lib/Transforms/Liveness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,22 +232,22 @@ LivenessInfo computeLiveness(ProgramOp program) {
// the loop exits, so their live ranges should not overlap with the
// loop body. Using the LoopOp index would inflate register pressure
// by keeping these results "live" throughout the entire loop.
if (isa<LoopOp>(op)) {
// Find the next sibling op after this LoopOp in the parent block.
// If there is no next sibling (loop is block-terminating), use idx + 1
// as a synthetic "after loop" point so loop results still get def points.
int64_t loopResultDefPoint = idx + 1;
if (isa<LoopOp>(op) || isa<IfOp>(op)) {
// Find the next sibling op after this LoopOp/IfOp in the parent block.
// If there is no next sibling (op is block-terminating), use idx + 1
// as a synthetic "after" point so results still get def points.
int64_t resultDefPoint = idx + 1;
Operation *nextOp = op->getNextNode();
if (nextOp) {
auto nextIt = opToIdx.find(nextOp);
if (nextIt != opToIdx.end()) {
loopResultDefPoint = nextIt->second;
resultDefPoint = nextIt->second;
}
}
for (Value def : op->getResults()) {
if (isVirtualRegType(def.getType())) {
if (!info.defPoints.contains(def)) {
info.defPoints[def] = loopResultDefPoint;
info.defPoints[def] = resultDefPoint;
}
}
}
Expand Down Expand Up @@ -368,46 +368,63 @@ LivenessInfo computeLiveness(ProgramOp program) {
continue;
Operation *useOp = ops[useIdx];

// Walk up parent chain to find enclosing loop ops
// Walk up parent chain to find enclosing loop/if ops
Operation *parent = useOp->getParentOp();
while (parent && !isa<ProgramOp>(parent)) {
if (auto loopOp = dyn_cast<LoopOp>(parent)) {
// Check if the value is defined inside the loop body
// (at any nesting depth). Values defined inside are recomputed
// each iteration and should keep their natural live ranges
// within the iteration. Only values defined OUTSIDE need
// extension across the loop.
bool definedInside = false;
// Check if the value is defined inside a given ancestor op
// (at any nesting depth). Values defined inside are recomputed
// each iteration (for loops) or only live in one branch (for ifs)
// and should keep their natural live ranges. Only values defined
// OUTSIDE need extension across the region op.
auto isDefinedInside = [&](Operation *ancestor) -> bool {
if (auto defOp = value.getDefiningOp()) {
// Check if defOp is anywhere inside the loop's region,
// not just a direct child. This handles values defined
// inside nested if/loop ops within the loop body.
definedInside = loopOp->isProperAncestor(defOp);
return ancestor->isProperAncestor(defOp);
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
// BlockArguments don't have a defining op. Check if the
// block argument's parent op is the loop or nested inside it.
Operation *argParentOp = blockArg.getOwner()->getParentOp();
definedInside = (argParentOp == loopOp.getOperation()) ||
loopOp->isProperAncestor(argParentOp);
return (argParentOp == ancestor) ||
ancestor->isProperAncestor(argParentOp);
}

if (!definedInside) {
// Extend end to cover the entire loop body (value is
// used every iteration, must survive until loop exits)
Block &body = loopOp.getBodyBlock();
Operation *terminator = body.getTerminator();
if (terminator) {
auto termIt = opToIdx.find(terminator);
if (termIt != opToIdx.end()) {
it->second.end = std::max(it->second.end, termIt->second);
return false;
};

// Extend end to the last terminator in any region of the op.
auto extendToRegionEnd = [&](Operation *regionOp) {
for (Region &region : regionOp->getRegions()) {
for (Block &block : region) {
Operation *terminator = block.getTerminator();
if (terminator) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, getTerminator will assert if the block does not have a terminator, so its result should never be null, but it may be worth double-checking.

auto termIt = opToIdx.find(terminator);
if (termIt != opToIdx.end()) {
it->second.end = std::max(it->second.end, termIt->second);
}
}
}
// Extend start back to the loop op
}
};

if (auto loopOp = dyn_cast<LoopOp>(parent)) {
if (!isDefinedInside(loopOp)) {
// Extend end to cover the entire loop body (value is
// used every iteration, must survive until loop exits).
extendToRegionEnd(loopOp);
// Extend start back to the loop op.
auto loopIt = opToIdx.find(loopOp.getOperation());
if (loopIt != opToIdx.end()) {
it->second.start = std::min(it->second.start, loopIt->second);
}
}
} else if (auto ifOp = dyn_cast<IfOp>(parent)) {
if (!isDefinedInside(ifOp)) {
// Extend to cover both branches (conservative: only one
// executes at runtime, but the linear scan allocator
// flattens both into a single instruction stream).
extendToRegionEnd(ifOp);
// Extend start back to the if op.
auto ifIt = opToIdx.find(ifOp.getOperation());
if (ifIt != opToIdx.end()) {
it->second.start = std::min(it->second.start, ifIt->second);
}
}
}
parent = parent->getParentOp();
}
Expand Down Expand Up @@ -540,6 +557,113 @@ LivenessInfo computeLiveness(ProgramOp program) {
}
});

// Pass 3c: Build tied equivalence classes for IfOp results.
//
// IfOp results must share the same physical register as their
// corresponding yield operands from the then (and optionally else)
// region. Without this tying, the allocator may assign different
// registers (or sizes) to the yield operand and the IfOp result,
// causing incorrect assembly (e.g., MFMA accumulator tuple shrunk
// to a single register).
program.walk([&](IfOp ifOp) {
if (ifOp->getNumResults() == 0)
return;

auto &thenBlock = ifOp.getThenBlock();
auto thenYield = dyn_cast<YieldOp>(thenBlock.getTerminator());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use `cast` instead of `dyn_cast` here: the terminator of an if is guaranteed to be a yield.

if (!thenYield)
return;

YieldOp elseYield = nullptr;
if (Block *elseBlock = ifOp.getElseBlock()) {
elseYield = dyn_cast<YieldOp>(elseBlock->getTerminator());
}

for (unsigned i = 0; i < ifOp->getNumResults(); ++i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Value ifResult = ifOp->getResult(i);
auto resIt = info.ranges.find(ifResult);
if (resIt == info.ranges.end())
continue;

llvm::SmallVector<Value> members;
members.push_back(ifResult);

if (i < thenYield.getResults().size()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (i < thenYield.getResults().size()) {
if (i < thenYield->getNumResults()) {

Value thenVal = thenYield.getResults()[i];
if (info.ranges.contains(thenVal))
members.push_back(thenVal);
}
if (elseYield && i < elseYield.getResults().size()) {
Value elseVal = elseYield.getResults()[i];
if (info.ranges.contains(elseVal))
members.push_back(elseVal);
}

if (members.size() <= 1)
continue;

// Check if any member is already in a class.
int64_t classId = -1;
for (Value member : members) {
auto existingIt = tc.valueToClassId.find(member);
if (existingIt != tc.valueToClassId.end()) {
classId = existingIt->second;
break;
}
}

if (classId < 0) {
classId = static_cast<int64_t>(tc.classes.size());
tc.classes.push_back({});
tc.classes.back().id = classId;
tc.classes.back().canonical = ifResult;
tc.classes.back().size = resIt->second.size;
tc.classes.back().alignment = resIt->second.alignment;
tc.classes.back().regClass = resIt->second.regClass;
tc.classes.back().envelopeStart = resIt->second.start;
tc.classes.back().envelopeEnd = resIt->second.end;
}

TiedClass &cls = tc.classes[classId];

for (Value member : members) {
if (tc.valueToClassId.contains(member))
continue;
tc.valueToClassId[member] = classId;
cls.members.push_back(member);

auto rangeIt = info.ranges.find(member);
if (rangeIt != info.ranges.end()) {
cls.envelopeStart =
std::min(cls.envelopeStart, rangeIt->second.start);
cls.envelopeEnd = std::max(cls.envelopeEnd, rangeIt->second.end);
rangeIt->second.tiedClassId = classId;
}
}

// Build tiedPairs: all three (ifResult, thenVal, elseVal) must share
// one physical register. The then yield is processed first in linear
// order, so it's the canonical source:
// ifResult -> thenVal (ifResult picks up thenVal's phys reg)
// elseVal -> thenVal (elseVal picks up thenVal's phys reg)
Value thenVal;
if (i < thenYield.getResults().size()) {
thenVal = thenYield.getResults()[i];
if (info.ranges.contains(thenVal) && !tc.tiedPairs.contains(ifResult))
tc.tiedPairs[ifResult] = thenVal;
}
if (elseYield && i < elseYield.getResults().size()) {
Value elseVal = elseYield.getResults()[i];
if (info.ranges.contains(elseVal) && !tc.tiedPairs.contains(elseVal)) {
if (thenVal && info.ranges.contains(thenVal))
tc.tiedPairs[elseVal] = thenVal;
else
tc.tiedPairs[elseVal] = ifResult;
}
}
}
});

// Pass 4: Categorize ranges by register class and sort by start
for (const auto &[value, range] : info.ranges) {
if (range.regClass == RegClass::VGPR) {
Expand Down
30 changes: 25 additions & 5 deletions waveasm/lib/Transforms/RegionBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,16 @@ LoopOp RegionBuilder::buildLoopFromSCFFor(scf::ForOp forOp) {
return nullptr;
}

// Convert lower bound to sreg if it's an immediate (loop counter needs sreg
// type)
// The loop induction variable must be an SGPR (used in s_add_u32 / s_cmp).
Value lowerBoundValue = *lowerBound;
if (isa<ImmType>(lowerBoundValue.getType())) {
auto sregType = ctx.createSRegType();
lowerBoundValue =
S_MOV_B32::create(builder, loc, sregType, lowerBoundValue);
} else if (isVGPRType(lowerBoundValue.getType())) {
auto sregType = ctx.createSRegType();
lowerBoundValue =
V_READFIRSTLANE_B32::create(builder, loc, sregType, lowerBoundValue);
}
initArgs.push_back(lowerBoundValue);

Expand Down Expand Up @@ -350,7 +353,13 @@ IfOp RegionBuilder::buildIfFromSCFIf(scf::IfOp ifOp) {
}
}

YieldOp::create(builder, loc, yieldVals);
auto thenYieldOp = YieldOp::create(builder, loc, yieldVals);

for (unsigned i = 0; i < waveIfOp->getNumResults(); ++i) {
if (i < thenYieldOp.getResults().size()) {
waveIfOp->getResult(i).setType(thenYieldOp.getResults()[i].getType());
}
}
}

// Translate else region if present
Expand All @@ -369,9 +378,20 @@ IfOp RegionBuilder::buildIfFromSCFIf(scf::IfOp ifOp) {
cast<scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());

SmallVector<Value> yieldVals;
for (Value res : scfYield.getResults()) {
for (auto [idx, res] : llvm::enumerate(scfYield.getResults())) {
if (auto mapped = ctx.getMapper().getMapped(res)) {
yieldVals.push_back(*mapped);
Value val = *mapped;
// If the then-yield has a wider type (e.g. areg<4,4>) but this
// else-yield operand is an immediate, create a zero-initialized
// register of the matching type so the IfOp results are
// consistently typed across both branches.
if (idx < waveIfOp->getNumResults()) {
Type thenType = waveIfOp->getResult(idx).getType();
if (thenType != val.getType() && isAGPRType(thenType)) {
val = V_MOV_B32::create(builder, loc, thenType, val);
}
}
yieldVals.push_back(val);
} else {
scfYield.emitError("yield result not mapped");
return nullptr;
Expand Down
4 changes: 4 additions & 0 deletions waveasm/lib/Transforms/TranslateFromMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ LogicalResult handleVectorFma(Operation *op, TranslationContext &ctx);
LogicalResult handleVectorReduction(Operation *op, TranslationContext &ctx);
LogicalResult handleVectorExtractStridedSlice(Operation *op,
TranslationContext &ctx);
LogicalResult handleVectorFromElements(Operation *op, TranslationContext &ctx);

} // namespace waveasm

Expand Down Expand Up @@ -1504,6 +1505,7 @@ LogicalResult handleMemRefAtomicRMW(Operation *op, TranslationContext &ctx);
LogicalResult handleReadFirstLane(Operation *op, TranslationContext &ctx);
LogicalResult handleROCDLSBarrier(Operation *op, TranslationContext &ctx);
LogicalResult handleROCDLSetPrio(Operation *op, TranslationContext &ctx);
LogicalResult handleROCDLSchedBarrier(Operation *op, TranslationContext &ctx);
LogicalResult handleSWaitcnt(Operation *op, TranslationContext &ctx);

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1650,6 +1652,7 @@ void OpHandlerRegistry::registerDefaultHandlers(mlir::MLIRContext *ctx) {
REGISTER_HANDLER(vector::TransferWriteOp, handleVectorTransferWrite);
REGISTER_HANDLER(vector::FMAOp, handleVectorFma);
REGISTER_HANDLER(vector::ReductionOp, handleVectorReduction);
REGISTER_HANDLER(vector::FromElementsOp, handleVectorFromElements);

// AMDGPU dialect
REGISTER_HANDLER(amdgpu::LDSBarrierOp, handleAMDGPULdsBarrier);
Expand All @@ -1666,6 +1669,7 @@ void OpHandlerRegistry::registerDefaultHandlers(mlir::MLIRContext *ctx) {
REGISTER_HANDLER(ROCDL::ReadfirstlaneOp, handleReadFirstLane);
REGISTER_HANDLER(ROCDL::SBarrierOp, handleROCDLSBarrier);
REGISTER_HANDLER(ROCDL::SetPrioOp, handleROCDLSetPrio);
REGISTER_HANDLER(ROCDL::SchedBarrier, handleROCDLSchedBarrier);
REGISTER_HANDLER(ROCDL::SWaitcntOp, handleSWaitcnt);

// IREE/Stream dialect (unregistered operations)
Expand Down
14 changes: 14 additions & 0 deletions waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1112,6 +1112,20 @@ LogicalResult handleMemRefAtomicRMW(Operation *op, TranslationContext &ctx) {
return success();
}

// Handler registered for ROCDL::SchedBarrier: emits a raw `s_sched_barrier`
// assembly line through RawOp. The optional integer attribute "mask" is read
// from the op (0 is used when the attribute is absent) and appended to the
// mnemonic as a hexadecimal immediate. The handler always returns success().
// NOTE(review): `mask` is a signed int32_t, while llvm::utohexstr formats an
// unsigned 64-bit value; a negative mask would presumably be printed
// sign-extended -- confirm the attribute can never be negative.
LogicalResult handleROCDLSchedBarrier(Operation *op, TranslationContext &ctx) {
auto &builder = ctx.getBuilder();
auto loc = op->getLoc();

int32_t mask = 0;
if (auto maskAttr = op->getAttrOfType<IntegerAttr>("mask")) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normally I'd expect the function to take an op of a specific type, or at least to cast and use a named accessor.

mask = maskAttr.getInt();
}

RawOp::create(builder, loc,
"s_sched_barrier 0x" + llvm::utohexstr(mask));
return success();
}

LogicalResult handleSWaitcnt(Operation *op, TranslationContext &ctx) {
auto &builder = ctx.getBuilder();
auto loc = op->getLoc();
Expand Down
4 changes: 4 additions & 0 deletions waveasm/lib/Transforms/handlers/Handlers.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ mlir::LogicalResult handleVectorFma(mlir::Operation *op,
TranslationContext &ctx);
mlir::LogicalResult handleVectorReduction(mlir::Operation *op,
TranslationContext &ctx);
mlir::LogicalResult handleVectorFromElements(mlir::Operation *op,
TranslationContext &ctx);

//===----------------------------------------------------------------------===//
// AMDGPU Dialect Handlers
Expand Down Expand Up @@ -183,6 +185,8 @@ mlir::LogicalResult handleMemRefAtomicRMW(mlir::Operation *op,

mlir::LogicalResult handleReadFirstLane(mlir::Operation *op,
TranslationContext &ctx);
mlir::LogicalResult handleROCDLSchedBarrier(mlir::Operation *op,
TranslationContext &ctx);
mlir::LogicalResult handleSWaitcnt(mlir::Operation *op,
TranslationContext &ctx);

Expand Down
Loading
Loading