-
Notifications
You must be signed in to change notification settings - Fork 28
waveasm: add more ops #1055
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
waveasm: add more ops #1055
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -232,22 +232,22 @@ LivenessInfo computeLiveness(ProgramOp program) { | |||||
| // the loop exits, so their live ranges should not overlap with the | ||||||
| // loop body. Using the LoopOp index would inflate register pressure | ||||||
| // by keeping these results "live" throughout the entire loop. | ||||||
| if (isa<LoopOp>(op)) { | ||||||
| // Find the next sibling op after this LoopOp in the parent block. | ||||||
| // If there is no next sibling (loop is block-terminating), use idx + 1 | ||||||
| // as a synthetic "after loop" point so loop results still get def points. | ||||||
| int64_t loopResultDefPoint = idx + 1; | ||||||
| if (isa<LoopOp>(op) || isa<IfOp>(op)) { | ||||||
| // Find the next sibling op after this LoopOp/IfOp in the parent block. | ||||||
| // If there is no next sibling (op is block-terminating), use idx + 1 | ||||||
| // as a synthetic "after" point so results still get def points. | ||||||
| int64_t resultDefPoint = idx + 1; | ||||||
| Operation *nextOp = op->getNextNode(); | ||||||
| if (nextOp) { | ||||||
| auto nextIt = opToIdx.find(nextOp); | ||||||
| if (nextIt != opToIdx.end()) { | ||||||
| loopResultDefPoint = nextIt->second; | ||||||
| resultDefPoint = nextIt->second; | ||||||
| } | ||||||
| } | ||||||
| for (Value def : op->getResults()) { | ||||||
| if (isVirtualRegType(def.getType())) { | ||||||
| if (!info.defPoints.contains(def)) { | ||||||
| info.defPoints[def] = loopResultDefPoint; | ||||||
| info.defPoints[def] = resultDefPoint; | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
@@ -368,46 +368,63 @@ LivenessInfo computeLiveness(ProgramOp program) { | |||||
| continue; | ||||||
| Operation *useOp = ops[useIdx]; | ||||||
|
|
||||||
| // Walk up parent chain to find enclosing loop ops | ||||||
| // Walk up parent chain to find enclosing loop/if ops | ||||||
| Operation *parent = useOp->getParentOp(); | ||||||
| while (parent && !isa<ProgramOp>(parent)) { | ||||||
| if (auto loopOp = dyn_cast<LoopOp>(parent)) { | ||||||
| // Check if the value is defined inside the loop body | ||||||
| // (at any nesting depth). Values defined inside are recomputed | ||||||
| // each iteration and should keep their natural live ranges | ||||||
| // within the iteration. Only values defined OUTSIDE need | ||||||
| // extension across the loop. | ||||||
| bool definedInside = false; | ||||||
| // Check if the value is defined inside a given ancestor op | ||||||
| // (at any nesting depth). Values defined inside are recomputed | ||||||
| // each iteration (for loops) or only live in one branch (for ifs) | ||||||
| // and should keep their natural live ranges. Only values defined | ||||||
| // OUTSIDE need extension across the region op. | ||||||
| auto isDefinedInside = [&](Operation *ancestor) -> bool { | ||||||
| if (auto defOp = value.getDefiningOp()) { | ||||||
| // Check if defOp is anywhere inside the loop's region, | ||||||
| // not just a direct child. This handles values defined | ||||||
| // inside nested if/loop ops within the loop body. | ||||||
| definedInside = loopOp->isProperAncestor(defOp); | ||||||
| return ancestor->isProperAncestor(defOp); | ||||||
| } else if (auto blockArg = dyn_cast<BlockArgument>(value)) { | ||||||
| // BlockArguments don't have a defining op. Check if the | ||||||
| // block argument's parent op is the loop or nested inside it. | ||||||
| Operation *argParentOp = blockArg.getOwner()->getParentOp(); | ||||||
| definedInside = (argParentOp == loopOp.getOperation()) || | ||||||
| loopOp->isProperAncestor(argParentOp); | ||||||
| return (argParentOp == ancestor) || | ||||||
| ancestor->isProperAncestor(argParentOp); | ||||||
| } | ||||||
|
|
||||||
| if (!definedInside) { | ||||||
| // Extend end to cover the entire loop body (value is | ||||||
| // used every iteration, must survive until loop exits) | ||||||
| Block &body = loopOp.getBodyBlock(); | ||||||
| Operation *terminator = body.getTerminator(); | ||||||
| if (terminator) { | ||||||
| auto termIt = opToIdx.find(terminator); | ||||||
| if (termIt != opToIdx.end()) { | ||||||
| it->second.end = std::max(it->second.end, termIt->second); | ||||||
| return false; | ||||||
| }; | ||||||
|
|
||||||
| // Extend end to the last terminator in any region of the op. | ||||||
| auto extendToRegionEnd = [&](Operation *regionOp) { | ||||||
| for (Region ®ion : regionOp->getRegions()) { | ||||||
| for (Block &block : region) { | ||||||
| Operation *terminator = block.getTerminator(); | ||||||
| if (terminator) { | ||||||
| auto termIt = opToIdx.find(terminator); | ||||||
| if (termIt != opToIdx.end()) { | ||||||
| it->second.end = std::max(it->second.end, termIt->second); | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| // Extend start back to the loop op | ||||||
| } | ||||||
| }; | ||||||
|
|
||||||
| if (auto loopOp = dyn_cast<LoopOp>(parent)) { | ||||||
| if (!isDefinedInside(loopOp)) { | ||||||
| // Extend end to cover the entire loop body (value is | ||||||
| // used every iteration, must survive until loop exits). | ||||||
| extendToRegionEnd(loopOp); | ||||||
| // Extend start back to the loop op. | ||||||
| auto loopIt = opToIdx.find(loopOp.getOperation()); | ||||||
| if (loopIt != opToIdx.end()) { | ||||||
| it->second.start = std::min(it->second.start, loopIt->second); | ||||||
| } | ||||||
| } | ||||||
| } else if (auto ifOp = dyn_cast<IfOp>(parent)) { | ||||||
| if (!isDefinedInside(ifOp)) { | ||||||
| // Extend to cover both branches (conservative: only one | ||||||
| // executes at runtime, but the linear scan allocator | ||||||
| // flattens both into a single instruction stream). | ||||||
| extendToRegionEnd(ifOp); | ||||||
| // Extend start back to the if op. | ||||||
| auto ifIt = opToIdx.find(ifOp.getOperation()); | ||||||
| if (ifIt != opToIdx.end()) { | ||||||
| it->second.start = std::min(it->second.start, ifIt->second); | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| parent = parent->getParentOp(); | ||||||
| } | ||||||
|
|
@@ -540,6 +557,113 @@ LivenessInfo computeLiveness(ProgramOp program) { | |||||
| } | ||||||
| }); | ||||||
|
|
||||||
| // Pass 3c: Build tied equivalence classes for IfOp results. | ||||||
| // | ||||||
| // IfOp results must share the same physical register as their | ||||||
| // corresponding yield operands from the then (and optionally else) | ||||||
| // region. Without this tying, the allocator may assign different | ||||||
| // registers (or sizes) to the yield operand and the IfOp result, | ||||||
| // causing incorrect assembly (e.g., MFMA accumulator tuple shrunk | ||||||
| // to a single register). | ||||||
| program.walk([&](IfOp ifOp) { | ||||||
| if (ifOp->getNumResults() == 0) | ||||||
| return; | ||||||
|
|
||||||
| auto &thenBlock = ifOp.getThenBlock(); | ||||||
| auto thenYield = dyn_cast<YieldOp>(thenBlock.getTerminator()); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
| if (!thenYield) | ||||||
| return; | ||||||
|
|
||||||
| YieldOp elseYield = nullptr; | ||||||
| if (Block *elseBlock = ifOp.getElseBlock()) { | ||||||
| elseYield = dyn_cast<YieldOp>(elseBlock->getTerminator()); | ||||||
| } | ||||||
|
|
||||||
| for (unsigned i = 0; i < ifOp->getNumResults(); ++i) { | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
| Value ifResult = ifOp->getResult(i); | ||||||
| auto resIt = info.ranges.find(ifResult); | ||||||
| if (resIt == info.ranges.end()) | ||||||
| continue; | ||||||
|
|
||||||
| llvm::SmallVector<Value> members; | ||||||
| members.push_back(ifResult); | ||||||
|
|
||||||
| if (i < thenYield.getResults().size()) { | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| Value thenVal = thenYield.getResults()[i]; | ||||||
| if (info.ranges.contains(thenVal)) | ||||||
| members.push_back(thenVal); | ||||||
| } | ||||||
| if (elseYield && i < elseYield.getResults().size()) { | ||||||
| Value elseVal = elseYield.getResults()[i]; | ||||||
| if (info.ranges.contains(elseVal)) | ||||||
| members.push_back(elseVal); | ||||||
| } | ||||||
|
|
||||||
| if (members.size() <= 1) | ||||||
| continue; | ||||||
|
|
||||||
| // Check if any member is already in a class. | ||||||
| int64_t classId = -1; | ||||||
| for (Value member : members) { | ||||||
| auto existingIt = tc.valueToClassId.find(member); | ||||||
| if (existingIt != tc.valueToClassId.end()) { | ||||||
| classId = existingIt->second; | ||||||
| break; | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| if (classId < 0) { | ||||||
| classId = static_cast<int64_t>(tc.classes.size()); | ||||||
| tc.classes.push_back({}); | ||||||
| tc.classes.back().id = classId; | ||||||
| tc.classes.back().canonical = ifResult; | ||||||
| tc.classes.back().size = resIt->second.size; | ||||||
| tc.classes.back().alignment = resIt->second.alignment; | ||||||
| tc.classes.back().regClass = resIt->second.regClass; | ||||||
| tc.classes.back().envelopeStart = resIt->second.start; | ||||||
| tc.classes.back().envelopeEnd = resIt->second.end; | ||||||
| } | ||||||
|
|
||||||
| TiedClass &cls = tc.classes[classId]; | ||||||
|
|
||||||
| for (Value member : members) { | ||||||
| if (tc.valueToClassId.contains(member)) | ||||||
| continue; | ||||||
| tc.valueToClassId[member] = classId; | ||||||
| cls.members.push_back(member); | ||||||
|
|
||||||
| auto rangeIt = info.ranges.find(member); | ||||||
| if (rangeIt != info.ranges.end()) { | ||||||
| cls.envelopeStart = | ||||||
| std::min(cls.envelopeStart, rangeIt->second.start); | ||||||
| cls.envelopeEnd = std::max(cls.envelopeEnd, rangeIt->second.end); | ||||||
| rangeIt->second.tiedClassId = classId; | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // Build tiedPairs: all three (ifResult, thenVal, elseVal) must share | ||||||
| // one physical register. The then yield is processed first in linear | ||||||
| // order, so it's the canonical source: | ||||||
| // ifResult -> thenVal (ifResult picks up thenVal's phys reg) | ||||||
| // elseVal -> thenVal (elseVal picks up thenVal's phys reg) | ||||||
| Value thenVal; | ||||||
| if (i < thenYield.getResults().size()) { | ||||||
| thenVal = thenYield.getResults()[i]; | ||||||
| if (info.ranges.contains(thenVal) && !tc.tiedPairs.contains(ifResult)) | ||||||
| tc.tiedPairs[ifResult] = thenVal; | ||||||
| } | ||||||
| if (elseYield && i < elseYield.getResults().size()) { | ||||||
| Value elseVal = elseYield.getResults()[i]; | ||||||
| if (info.ranges.contains(elseVal) && !tc.tiedPairs.contains(elseVal)) { | ||||||
| if (thenVal && info.ranges.contains(thenVal)) | ||||||
| tc.tiedPairs[elseVal] = thenVal; | ||||||
| else | ||||||
| tc.tiedPairs[elseVal] = ifResult; | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| }); | ||||||
|
|
||||||
| // Pass 4: Categorize ranges by register class and sort by start | ||||||
| for (const auto &[value, range] : info.ranges) { | ||||||
| if (range.regClass == RegClass::VGPR) { | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1112,6 +1112,20 @@ LogicalResult handleMemRefAtomicRMW(Operation *op, TranslationContext &ctx) { | |
| return success(); | ||
| } | ||
|
|
||
| LogicalResult handleROCDLSchedBarrier(Operation *op, TranslationContext &ctx) { | ||
| auto &builder = ctx.getBuilder(); | ||
| auto loc = op->getLoc(); | ||
|
|
||
| int32_t mask = 0; | ||
| if (auto maskAttr = op->getAttrOfType<IntegerAttr>("mask")) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Normally I'd expect the function to take an op of a specific type, or at least to cast and use a named accessor. |
||
| mask = maskAttr.getInt(); | ||
| } | ||
|
|
||
| RawOp::create(builder, loc, | ||
| "s_sched_barrier 0x" + llvm::utohexstr(mask)); | ||
| return success(); | ||
| } | ||
|
|
||
| LogicalResult handleSWaitcnt(Operation *op, TranslationContext &ctx) { | ||
| auto &builder = ctx.getBuilder(); | ||
| auto loc = op->getLoc(); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC,
getTerminatorwill assert if the block does not have a terminator so its result should never be null, but may be worth doubl-checking.