diff --git a/waveasm/lib/Transforms/Liveness.cpp b/waveasm/lib/Transforms/Liveness.cpp index 54661f1ad..2291195f8 100644 --- a/waveasm/lib/Transforms/Liveness.cpp +++ b/waveasm/lib/Transforms/Liveness.cpp @@ -232,22 +232,22 @@ LivenessInfo computeLiveness(ProgramOp program) { // the loop exits, so their live ranges should not overlap with the // loop body. Using the LoopOp index would inflate register pressure // by keeping these results "live" throughout the entire loop. - if (isa(op)) { - // Find the next sibling op after this LoopOp in the parent block. - // If there is no next sibling (loop is block-terminating), use idx + 1 - // as a synthetic "after loop" point so loop results still get def points. - int64_t loopResultDefPoint = idx + 1; + if (isa(op) || isa(op)) { + // Find the next sibling op after this LoopOp/IfOp in the parent block. + // If there is no next sibling (op is block-terminating), use idx + 1 + // as a synthetic "after" point so results still get def points. + int64_t resultDefPoint = idx + 1; Operation *nextOp = op->getNextNode(); if (nextOp) { auto nextIt = opToIdx.find(nextOp); if (nextIt != opToIdx.end()) { - loopResultDefPoint = nextIt->second; + resultDefPoint = nextIt->second; } } for (Value def : op->getResults()) { if (isVirtualRegType(def.getType())) { if (!info.defPoints.contains(def)) { - info.defPoints[def] = loopResultDefPoint; + info.defPoints[def] = resultDefPoint; } } } @@ -368,46 +368,63 @@ LivenessInfo computeLiveness(ProgramOp program) { continue; Operation *useOp = ops[useIdx]; - // Walk up parent chain to find enclosing loop ops + // Walk up parent chain to find enclosing loop/if ops Operation *parent = useOp->getParentOp(); while (parent && !isa(parent)) { - if (auto loopOp = dyn_cast(parent)) { - // Check if the value is defined inside the loop body - // (at any nesting depth). Values defined inside are recomputed - // each iteration and should keep their natural live ranges - // within the iteration. 
Only values defined OUTSIDE need - // extension across the loop. - bool definedInside = false; + // Check if the value is defined inside a given ancestor op + // (at any nesting depth). Values defined inside are recomputed + // each iteration (for loops) or only live in one branch (for ifs) + // and should keep their natural live ranges. Only values defined + // OUTSIDE need extension across the region op. + auto isDefinedInside = [&](Operation *ancestor) -> bool { if (auto defOp = value.getDefiningOp()) { - // Check if defOp is anywhere inside the loop's region, - // not just a direct child. This handles values defined - // inside nested if/loop ops within the loop body. - definedInside = loopOp->isProperAncestor(defOp); + return ancestor->isProperAncestor(defOp); } else if (auto blockArg = dyn_cast(value)) { - // BlockArguments don't have a defining op. Check if the - // block argument's parent op is the loop or nested inside it. Operation *argParentOp = blockArg.getOwner()->getParentOp(); - definedInside = (argParentOp == loopOp.getOperation()) || - loopOp->isProperAncestor(argParentOp); + return (argParentOp == ancestor) || + ancestor->isProperAncestor(argParentOp); } - - if (!definedInside) { - // Extend end to cover the entire loop body (value is - // used every iteration, must survive until loop exits) - Block &body = loopOp.getBodyBlock(); - Operation *terminator = body.getTerminator(); - if (terminator) { - auto termIt = opToIdx.find(terminator); - if (termIt != opToIdx.end()) { - it->second.end = std::max(it->second.end, termIt->second); + return false; + }; + + // Extend end to the last terminator in any region of the op. 
+ auto extendToRegionEnd = [&](Operation *regionOp) { + for (Region &region : regionOp->getRegions()) { + for (Block &block : region) { + Operation *terminator = block.getTerminator(); + if (terminator) { + auto termIt = opToIdx.find(terminator); + if (termIt != opToIdx.end()) { + it->second.end = std::max(it->second.end, termIt->second); + } } } - // Extend start back to the loop op + } + }; + + if (auto loopOp = dyn_cast(parent)) { + if (!isDefinedInside(loopOp)) { + // Extend end to cover the entire loop body (value is + // used every iteration, must survive until loop exits). + extendToRegionEnd(loopOp); + // Extend start back to the loop op. auto loopIt = opToIdx.find(loopOp.getOperation()); if (loopIt != opToIdx.end()) { it->second.start = std::min(it->second.start, loopIt->second); } } + } else if (auto ifOp = dyn_cast(parent)) { + if (!isDefinedInside(ifOp)) { + // Extend to cover both branches (conservative: only one + // executes at runtime, but the linear scan allocator + // flattens both into a single instruction stream). + extendToRegionEnd(ifOp); + // Extend start back to the if op. + auto ifIt = opToIdx.find(ifOp.getOperation()); + if (ifIt != opToIdx.end()) { + it->second.start = std::min(it->second.start, ifIt->second); + } + } } parent = parent->getParentOp(); } @@ -540,6 +557,113 @@ LivenessInfo computeLiveness(ProgramOp program) { } }); + // Pass 3c: Build tied equivalence classes for IfOp results. + // + // IfOp results must share the same physical register as their + // corresponding yield operands from the then (and optionally else) + // region. Without this tying, the allocator may assign different + // registers (or sizes) to the yield operand and the IfOp result, + // causing incorrect assembly (e.g., MFMA accumulator tuple shrunk + // to a single register). 
+ program.walk([&](IfOp ifOp) { + if (ifOp->getNumResults() == 0) + return; + + auto &thenBlock = ifOp.getThenBlock(); + auto thenYield = dyn_cast(thenBlock.getTerminator()); + if (!thenYield) + return; + + YieldOp elseYield = nullptr; + if (Block *elseBlock = ifOp.getElseBlock()) { + elseYield = dyn_cast(elseBlock->getTerminator()); + } + + for (unsigned i = 0; i < ifOp->getNumResults(); ++i) { + Value ifResult = ifOp->getResult(i); + auto resIt = info.ranges.find(ifResult); + if (resIt == info.ranges.end()) + continue; + + llvm::SmallVector members; + members.push_back(ifResult); + + if (i < thenYield.getResults().size()) { + Value thenVal = thenYield.getResults()[i]; + if (info.ranges.contains(thenVal)) + members.push_back(thenVal); + } + if (elseYield && i < elseYield.getResults().size()) { + Value elseVal = elseYield.getResults()[i]; + if (info.ranges.contains(elseVal)) + members.push_back(elseVal); + } + + if (members.size() <= 1) + continue; + + // Check if any member is already in a class. 
+ int64_t classId = -1; + for (Value member : members) { + auto existingIt = tc.valueToClassId.find(member); + if (existingIt != tc.valueToClassId.end()) { + classId = existingIt->second; + break; + } + } + + if (classId < 0) { + classId = static_cast(tc.classes.size()); + tc.classes.push_back({}); + tc.classes.back().id = classId; + tc.classes.back().canonical = ifResult; + tc.classes.back().size = resIt->second.size; + tc.classes.back().alignment = resIt->second.alignment; + tc.classes.back().regClass = resIt->second.regClass; + tc.classes.back().envelopeStart = resIt->second.start; + tc.classes.back().envelopeEnd = resIt->second.end; + } + + TiedClass &cls = tc.classes[classId]; + + for (Value member : members) { + if (tc.valueToClassId.contains(member)) + continue; + tc.valueToClassId[member] = classId; + cls.members.push_back(member); + + auto rangeIt = info.ranges.find(member); + if (rangeIt != info.ranges.end()) { + cls.envelopeStart = + std::min(cls.envelopeStart, rangeIt->second.start); + cls.envelopeEnd = std::max(cls.envelopeEnd, rangeIt->second.end); + rangeIt->second.tiedClassId = classId; + } + } + + // Build tiedPairs: all three (ifResult, thenVal, elseVal) must share + // one physical register. 
The then yield is processed first in linear + // order, so it's the canonical source: + // ifResult -> thenVal (ifResult picks up thenVal's phys reg) + // elseVal -> thenVal (elseVal picks up thenVal's phys reg) + Value thenVal; + if (i < thenYield.getResults().size()) { + thenVal = thenYield.getResults()[i]; + if (info.ranges.contains(thenVal) && !tc.tiedPairs.contains(ifResult)) + tc.tiedPairs[ifResult] = thenVal; + } + if (elseYield && i < elseYield.getResults().size()) { + Value elseVal = elseYield.getResults()[i]; + if (info.ranges.contains(elseVal) && !tc.tiedPairs.contains(elseVal)) { + if (thenVal && info.ranges.contains(thenVal)) + tc.tiedPairs[elseVal] = thenVal; + else + tc.tiedPairs[elseVal] = ifResult; + } + } + } + }); + // Pass 4: Categorize ranges by register class and sort by start for (const auto &[value, range] : info.ranges) { if (range.regClass == RegClass::VGPR) { diff --git a/waveasm/lib/Transforms/RegionBuilder.cpp b/waveasm/lib/Transforms/RegionBuilder.cpp index bb91c6088..fc6e6e6a0 100644 --- a/waveasm/lib/Transforms/RegionBuilder.cpp +++ b/waveasm/lib/Transforms/RegionBuilder.cpp @@ -91,13 +91,16 @@ LoopOp RegionBuilder::buildLoopFromSCFFor(scf::ForOp forOp) { return nullptr; } - // Convert lower bound to sreg if it's an immediate (loop counter needs sreg - // type) + // The loop induction variable must be an SGPR (used in s_add_u32 / s_cmp). 
Value lowerBoundValue = *lowerBound; if (isa(lowerBoundValue.getType())) { auto sregType = ctx.createSRegType(); lowerBoundValue = S_MOV_B32::create(builder, loc, sregType, lowerBoundValue); + } else if (isVGPRType(lowerBoundValue.getType())) { + auto sregType = ctx.createSRegType(); + lowerBoundValue = + V_READFIRSTLANE_B32::create(builder, loc, sregType, lowerBoundValue); } initArgs.push_back(lowerBoundValue); @@ -350,7 +353,13 @@ IfOp RegionBuilder::buildIfFromSCFIf(scf::IfOp ifOp) { } } - YieldOp::create(builder, loc, yieldVals); + auto thenYieldOp = YieldOp::create(builder, loc, yieldVals); + + for (unsigned i = 0; i < waveIfOp->getNumResults(); ++i) { + if (i < thenYieldOp.getResults().size()) { + waveIfOp->getResult(i).setType(thenYieldOp.getResults()[i].getType()); + } + } } // Translate else region if present @@ -369,9 +378,20 @@ IfOp RegionBuilder::buildIfFromSCFIf(scf::IfOp ifOp) { cast(ifOp.getElseRegion().front().getTerminator()); SmallVector yieldVals; - for (Value res : scfYield.getResults()) { + for (auto [idx, res] : llvm::enumerate(scfYield.getResults())) { if (auto mapped = ctx.getMapper().getMapped(res)) { - yieldVals.push_back(*mapped); + Value val = *mapped; + // If the then-yield has a wider type (e.g. areg<4,4>) but this + // else-yield operand is an immediate, create a zero-initialized + // register of the matching type so the IfOp results are + // consistently typed across both branches. 
+ if (idx < waveIfOp->getNumResults()) { + Type thenType = waveIfOp->getResult(idx).getType(); + if (thenType != val.getType() && isAGPRType(thenType)) { + val = V_MOV_B32::create(builder, loc, thenType, val); + } + } + yieldVals.push_back(val); } else { scfYield.emitError("yield result not mapped"); return nullptr; diff --git a/waveasm/lib/Transforms/TranslateFromMLIR.cpp b/waveasm/lib/Transforms/TranslateFromMLIR.cpp index 37983f34d..31cda85fb 100644 --- a/waveasm/lib/Transforms/TranslateFromMLIR.cpp +++ b/waveasm/lib/Transforms/TranslateFromMLIR.cpp @@ -563,6 +563,7 @@ LogicalResult handleVectorFma(Operation *op, TranslationContext &ctx); LogicalResult handleVectorReduction(Operation *op, TranslationContext &ctx); LogicalResult handleVectorExtractStridedSlice(Operation *op, TranslationContext &ctx); +LogicalResult handleVectorFromElements(Operation *op, TranslationContext &ctx); } // namespace waveasm @@ -1504,6 +1505,7 @@ LogicalResult handleMemRefAtomicRMW(Operation *op, TranslationContext &ctx); LogicalResult handleReadFirstLane(Operation *op, TranslationContext &ctx); LogicalResult handleROCDLSBarrier(Operation *op, TranslationContext &ctx); LogicalResult handleROCDLSetPrio(Operation *op, TranslationContext &ctx); +LogicalResult handleROCDLSchedBarrier(Operation *op, TranslationContext &ctx); LogicalResult handleSWaitcnt(Operation *op, TranslationContext &ctx); //===----------------------------------------------------------------------===// @@ -1650,6 +1652,7 @@ void OpHandlerRegistry::registerDefaultHandlers(mlir::MLIRContext *ctx) { REGISTER_HANDLER(vector::TransferWriteOp, handleVectorTransferWrite); REGISTER_HANDLER(vector::FMAOp, handleVectorFma); REGISTER_HANDLER(vector::ReductionOp, handleVectorReduction); + REGISTER_HANDLER(vector::FromElementsOp, handleVectorFromElements); // AMDGPU dialect REGISTER_HANDLER(amdgpu::LDSBarrierOp, handleAMDGPULdsBarrier); @@ -1666,6 +1669,7 @@ void OpHandlerRegistry::registerDefaultHandlers(mlir::MLIRContext *ctx) { 
REGISTER_HANDLER(ROCDL::ReadfirstlaneOp, handleReadFirstLane); REGISTER_HANDLER(ROCDL::SBarrierOp, handleROCDLSBarrier); REGISTER_HANDLER(ROCDL::SetPrioOp, handleROCDLSetPrio); + REGISTER_HANDLER(ROCDL::SchedBarrier, handleROCDLSchedBarrier); REGISTER_HANDLER(ROCDL::SWaitcntOp, handleSWaitcnt); // IREE/Stream dialect (unregistered operations) diff --git a/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp b/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp index 3127d1625..e02f463a7 100644 --- a/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp @@ -1112,6 +1112,20 @@ LogicalResult handleMemRefAtomicRMW(Operation *op, TranslationContext &ctx) { return success(); } +LogicalResult handleROCDLSchedBarrier(Operation *op, TranslationContext &ctx) { + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + int32_t mask = 0; + if (auto maskAttr = op->getAttrOfType("mask")) { + mask = maskAttr.getInt(); + } + + RawOp::create(builder, loc, + "s_sched_barrier 0x" + llvm::utohexstr(mask)); + return success(); +} + LogicalResult handleSWaitcnt(Operation *op, TranslationContext &ctx) { auto &builder = ctx.getBuilder(); auto loc = op->getLoc(); diff --git a/waveasm/lib/Transforms/handlers/Handlers.h b/waveasm/lib/Transforms/handlers/Handlers.h index 8d6481c1d..8ca41bc8f 100644 --- a/waveasm/lib/Transforms/handlers/Handlers.h +++ b/waveasm/lib/Transforms/handlers/Handlers.h @@ -155,6 +155,8 @@ mlir::LogicalResult handleVectorFma(mlir::Operation *op, TranslationContext &ctx); mlir::LogicalResult handleVectorReduction(mlir::Operation *op, TranslationContext &ctx); +mlir::LogicalResult handleVectorFromElements(mlir::Operation *op, + TranslationContext &ctx); //===----------------------------------------------------------------------===// // AMDGPU Dialect Handlers @@ -183,6 +185,8 @@ mlir::LogicalResult handleMemRefAtomicRMW(mlir::Operation *op, mlir::LogicalResult handleReadFirstLane(mlir::Operation *op, 
TranslationContext &ctx); +mlir::LogicalResult handleROCDLSchedBarrier(mlir::Operation *op, + TranslationContext &ctx); mlir::LogicalResult handleSWaitcnt(mlir::Operation *op, TranslationContext &ctx); diff --git a/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp b/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp index 3243f449f..6b047e8a2 100644 --- a/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp @@ -213,25 +213,31 @@ LogicalResult handleMemRefLoad(Operation *op, TranslationContext &ctx) { ctx.getMapper().mapValue(loadOp.getResult(), readOp.getResult(0)); } else { // Global load - auto sregType = ctx.createSRegType(4, 4); - auto srd = PrecoloredSRegOp::create(builder, loc, sregType, 8, 4); - - Value voffset; - if (!loadOp.getIndices().empty()) { - if (auto mapped = ctx.getMapper().getMapped(loadOp.getIndices()[0])) { - voffset = *mapped; - } - } - if (!voffset) { - auto immType = ctx.createImmType(0); - voffset = ConstantOp::create(builder, loc, immType, 0); - } + auto [voffset, instOffset] = computeVOffsetFromIndices( + memrefType, loadOp.getIndices(), ctx, loc); + Value srd = lookupSRD(loadOp.getMemref(), ctx, loc); auto zeroImm = builder.getType(0); auto zeroConst = ConstantOp::create(builder, loc, zeroImm, 0); - auto loadInstr = BUFFER_LOAD_DWORD::create( - builder, loc, TypeRange{vregType}, srd, voffset, zeroConst); - ctx.getMapper().mapValue(loadOp.getResult(), loadInstr.getResult(0)); + + Type elemType = memrefType.getElementType(); + int64_t elemBytes = (elemType.getIntOrFloatBitWidth() + 7) / 8; + + Operation *loadInstr; + if (elemBytes <= 1) { + loadInstr = BUFFER_LOAD_UBYTE::create( + builder, loc, TypeRange{vregType}, srd, voffset, zeroConst, + instOffset); + } else if (elemBytes <= 2) { + loadInstr = BUFFER_LOAD_USHORT::create( + builder, loc, TypeRange{vregType}, srd, voffset, zeroConst, + instOffset); + } else { + loadInstr = BUFFER_LOAD_DWORD::create( + builder, loc, 
TypeRange{vregType}, srd, voffset, zeroConst, + instOffset); + } + ctx.getMapper().mapValue(loadOp.getResult(), loadInstr->getResult(0)); } return success(); @@ -266,21 +272,11 @@ LogicalResult handleMemRefStore(Operation *op, TranslationContext &ctx) { DS_WRITE_B32::create(builder, loc, *data, vaddr); } else { // Global store - auto sregType = ctx.createSRegType(4, 4); - auto srd = PrecoloredSRegOp::create(builder, loc, sregType, 8, 4); - - Value voffset; - if (!storeOp.getIndices().empty()) { - if (auto mapped = ctx.getMapper().getMapped(storeOp.getIndices()[0])) { - voffset = *mapped; - } - } - if (!voffset) { - auto immType = ctx.createImmType(0); - voffset = ConstantOp::create(builder, loc, immType, 0); - } + auto [voffset, instOffset] = computeVOffsetFromIndices( + memrefType, storeOp.getIndices(), ctx, loc); + Value srd = lookupSRD(storeOp.getMemref(), ctx, loc); - BUFFER_STORE_DWORD::create(builder, loc, *data, srd, voffset); + BUFFER_STORE_DWORD::create(builder, loc, *data, srd, voffset, instOffset); } return success(); diff --git a/waveasm/lib/Transforms/handlers/VectorHandlers.cpp b/waveasm/lib/Transforms/handlers/VectorHandlers.cpp index 03f74f122..45aa60642 100644 --- a/waveasm/lib/Transforms/handlers/VectorHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/VectorHandlers.cpp @@ -271,4 +271,62 @@ LogicalResult handleVectorExtractStridedSlice(Operation *op, return success(); } +LogicalResult handleVectorFromElements(Operation *op, TranslationContext &ctx) { + auto fromElemsOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + auto resultType = fromElemsOp.getDest().getType(); + auto elemType = resultType.getElementType(); + int64_t elemBitWidth = elemType.getIntOrFloatBitWidth(); + int64_t numElems = resultType.getNumElements(); + int64_t totalBits = numElems * elemBitWidth; + int64_t numDwords = (totalBits + 31) / 32; + int64_t elemsPerDword = 32 / elemBitWidth; + + SmallVector dwordValues; + for (int64_t d = 0; d < 
numDwords; ++d) { + Value packed; + for (int64_t e = 0; e < elemsPerDword; ++e) { + int64_t idx = d * elemsPerDword + e; + if (idx >= numElems) + break; + + auto elemMapped = ctx.getMapper().getMapped(fromElemsOp.getElements()[idx]); + if (!elemMapped) { + return op->emitError("element ") << idx << " not mapped"; + } + Value elem = *elemMapped; + + if (e == 0) { + if (isa(elem.getType())) { + auto vregType = ctx.createVRegType(1, 1); + packed = V_MOV_B32::create(builder, loc, vregType, elem); + } else { + packed = elem; + } + } else { + int64_t bitOffset = e * elemBitWidth; + auto shiftImm = ctx.createImmType(bitOffset); + auto shiftConst = + ConstantOp::create(builder, loc, shiftImm, bitOffset); + auto vregType = ctx.createVRegType(1, 1); + Value shifted = + V_LSHLREV_B32::create(builder, loc, vregType, shiftConst, elem); + packed = V_OR_B32::create(builder, loc, vregType, packed, shifted); + } + } + dwordValues.push_back(packed); + } + + if (dwordValues.size() == 1) { + ctx.getMapper().mapValue(fromElemsOp.getDest(), dwordValues[0]); + } else { + auto packedType = ctx.createVRegType(numDwords, numDwords > 1 ? numDwords : 1); + auto packResult = PackOp::create(builder, loc, packedType, dwordValues); + ctx.getMapper().mapValue(fromElemsOp.getDest(), packResult); + } + return success(); +} + } // namespace waveasm diff --git a/waveasm/test/Transforms/ifop-liveness.mlir b/waveasm/test/Transforms/ifop-liveness.mlir new file mode 100644 index 000000000..08b0a5b90 --- /dev/null +++ b/waveasm/test/Transforms/ifop-liveness.mlir @@ -0,0 +1,110 @@ +// RUN: waveasm-translate --disable-pass-verifier --waveasm-linear-scan %s 2>&1 | FileCheck %s +// +// Test that IfOp results are properly tied to their yield operands via +// the liveness analysis (Pass 3c). Without tied classes, the allocator +// may assign different physical registers to the yield operand and the +// IfOp result, causing verification failures or incorrect codegen. 
+ +//===----------------------------------------------------------------------===// +// Test 1: IfOp with vreg result - tied to then-yield and else-yield +//===----------------------------------------------------------------------===// + +// Register allocation should succeed (no "Failed to allocate" error). +// The IfOp result should get the same physical register as the yield operands. +// CHECK-LABEL: waveasm.program @ifop_vreg_tying +// CHECK-NOT: Failed to allocate +// CHECK: waveasm.if +// CHECK: waveasm.s_endpgm + +waveasm.program @ifop_vreg_tying + target = #waveasm.target<#waveasm.gfx942, 5> + abi = #waveasm.abi<> { + + %c0 = waveasm.constant 0 : !waveasm.imm<0> + %c1 = waveasm.constant 1 : !waveasm.imm<1> + %cond = waveasm.precolored.sreg 2 : !waveasm.sreg + + %a = waveasm.v_mov_b32 %c0 : !waveasm.imm<0> -> !waveasm.vreg + %b = waveasm.v_mov_b32 %c1 : !waveasm.imm<1> -> !waveasm.vreg + + %result = waveasm.if %cond : !waveasm.sreg -> !waveasm.vreg { + %sum = waveasm.v_add_u32 %a, %b : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg + waveasm.yield %sum : !waveasm.vreg + } else { + waveasm.yield %a : !waveasm.vreg + } + + %out = waveasm.v_add_u32 %result, %c1 : !waveasm.vreg, !waveasm.imm<1> -> !waveasm.vreg + waveasm.s_endpgm +} + +//===----------------------------------------------------------------------===// +// Test 2: IfOp inside a loop -- tests that IfOp result def points are +// placed after the IfOp body, not at the IfOp itself +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: waveasm.program @ifop_in_loop +// CHECK-NOT: Failed to allocate +// CHECK: waveasm.loop +// CHECK: waveasm.if +// CHECK: waveasm.s_endpgm + +waveasm.program @ifop_in_loop + target = #waveasm.target<#waveasm.gfx942, 5> + abi = #waveasm.abi<> { + + %c0 = waveasm.constant 0 : !waveasm.imm<0> + %c1 = waveasm.constant 1 : !waveasm.imm<1> + %c10 = waveasm.constant 10 : !waveasm.imm<10> + %init_i = waveasm.s_mov_b32 %c0 : 
!waveasm.imm<0> -> !waveasm.sreg + %init_acc = waveasm.v_mov_b32 %c0 : !waveasm.imm<0> -> !waveasm.vreg + %cond = waveasm.precolored.sreg 2 : !waveasm.sreg + + %final:2 = waveasm.loop(%i = %init_i, %acc = %init_acc) + : (!waveasm.sreg, !waveasm.vreg) -> (!waveasm.sreg, !waveasm.vreg) { + + %v0 = waveasm.v_mov_b32 %c0 : !waveasm.imm<0> -> !waveasm.vreg + %v1 = waveasm.v_mov_b32 %c1 : !waveasm.imm<1> -> !waveasm.vreg + + %step = waveasm.if %cond : !waveasm.sreg -> !waveasm.vreg { + waveasm.yield %v1 : !waveasm.vreg + } else { + waveasm.yield %v0 : !waveasm.vreg + } + + %new_acc = waveasm.v_add_u32 %acc, %step : !waveasm.vreg, !waveasm.vreg -> !waveasm.vreg + %next:2 = waveasm.s_add_u32 %i, %c1 : !waveasm.sreg, !waveasm.imm<1> -> !waveasm.sreg, !waveasm.sreg + %cont = waveasm.s_cmp_lt_u32 %next#0, %c10 : !waveasm.sreg, !waveasm.imm<10> -> !waveasm.sreg + waveasm.condition %cont : !waveasm.sreg iter_args(%next#0, %new_acc) : !waveasm.sreg, !waveasm.vreg + } + + waveasm.s_endpgm +} + +//===----------------------------------------------------------------------===// +// Test 3: IfOp with wide (vreg<4,4>) accumulator results +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: waveasm.program @ifop_wide_accum +// CHECK-NOT: Failed to allocate +// CHECK: waveasm.if +// CHECK: waveasm.s_endpgm + +waveasm.program @ifop_wide_accum + target = #waveasm.target<#waveasm.gfx942, 5> + abi = #waveasm.abi<> { + + %c0 = waveasm.constant 0 : !waveasm.imm<0> + %cond = waveasm.precolored.sreg 2 : !waveasm.sreg + + %init = waveasm.v_mov_b32 %c0 : !waveasm.imm<0> -> !waveasm.vreg<4, 4> + + %result = waveasm.if %cond : !waveasm.sreg -> !waveasm.vreg<4, 4> { + waveasm.yield %init : !waveasm.vreg<4, 4> + } else { + %zero = waveasm.v_mov_b32 %c0 : !waveasm.imm<0> -> !waveasm.vreg<4, 4> + waveasm.yield %zero : !waveasm.vreg<4, 4> + } + + waveasm.s_endpgm +} diff --git a/waveasm/test/Translate/memref-load-subdword.mlir 
b/waveasm/test/Translate/memref-load-subdword.mlir new file mode 100644 index 000000000..42ad78435 --- /dev/null +++ b/waveasm/test/Translate/memref-load-subdword.mlir @@ -0,0 +1,43 @@ +// RUN: waveasm-translate --target=gfx942 %s 2>&1 | FileCheck %s +// +// Test: memref.load handler emits subdword buffer loads (UBYTE, USHORT) +// for sub-32-bit element types. + +module { + gpu.module @test_subdword_load { + // CHECK-LABEL: waveasm.program @load_i8 + gpu.func @load_i8(%buf: memref<64xi8>) kernel { + %c0 = arith.constant 0 : index + // 8-bit load -> buffer_load_ubyte + // CHECK: waveasm.buffer_load_ubyte + %val = memref.load %buf[%c0] : memref<64xi8> + gpu.return + } + } +} + +module { + gpu.module @test_short_load { + // CHECK-LABEL: waveasm.program @load_i16 + gpu.func @load_i16(%buf: memref<64xi16>) kernel { + %c0 = arith.constant 0 : index + // 16-bit load -> buffer_load_ushort + // CHECK: waveasm.buffer_load_ushort + %val = memref.load %buf[%c0] : memref<64xi16> + gpu.return + } + } +} + +module { + gpu.module @test_dword_load { + // CHECK-LABEL: waveasm.program @load_i32 + gpu.func @load_i32(%buf: memref<64xi32>) kernel { + %c0 = arith.constant 0 : index + // 32-bit load -> buffer_load_dword + // CHECK: waveasm.buffer_load_dword + %val = memref.load %buf[%c0] : memref<64xi32> + gpu.return + } + } +} diff --git a/waveasm/test/Translate/rocdl-sched-barrier.mlir b/waveasm/test/Translate/rocdl-sched-barrier.mlir new file mode 100644 index 000000000..b50bdcc57 --- /dev/null +++ b/waveasm/test/Translate/rocdl-sched-barrier.mlir @@ -0,0 +1,27 @@ +// RUN: waveasm-translate %s 2>&1 | FileCheck %s +// +// Test: rocdl.sched.barrier handler emits waveasm.raw "s_sched_barrier". 
+ +// CHECK-LABEL: waveasm.program @sched_barrier_test + +// rocdl.sched.barrier 0 -> s_sched_barrier 0x0 +// CHECK: waveasm.raw "s_sched_barrier 0x0" + +// rocdl.sched.barrier 1 -> s_sched_barrier 0x1 +// CHECK: waveasm.raw "s_sched_barrier 0x1" + +// rocdl.sched.barrier 255 -> s_sched_barrier 0xFF +// CHECK: waveasm.raw "s_sched_barrier 0xFF" + +// CHECK: waveasm.s_endpgm + +module { + gpu.module @test_sched_barrier { + gpu.func @sched_barrier_test() kernel { + rocdl.sched.barrier 0 + rocdl.sched.barrier 1 + rocdl.sched.barrier 255 + gpu.return + } + } +} diff --git a/waveasm/test/Translate/scf-if.mlir b/waveasm/test/Translate/scf-if.mlir new file mode 100644 index 000000000..930f827c8 --- /dev/null +++ b/waveasm/test/Translate/scf-if.mlir @@ -0,0 +1,95 @@ +// RUN: waveasm-translate --target=gfx942 %s 2>&1 | FileCheck %s +// +// Test: scf.if translation to waveasm.if, including result type propagation +// from then-yield operands and else-branch type matching. + +//===----------------------------------------------------------------------===// +// Test 1: Simple if-then-else with computed results +//===----------------------------------------------------------------------===// + +module { + gpu.module @test_scf_if { + // CHECK-LABEL: waveasm.program @simple_if + gpu.func @simple_if() kernel { + %tid = gpu.thread_id x + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + %cond = arith.cmpi ult, %tid, %c10 : index + + // CHECK: waveasm.if + // CHECK: waveasm.yield + // CHECK: } else { + // CHECK: waveasm.yield + // CHECK: } + %result = scf.if %cond -> index { + %sum = arith.addi %tid, %c1 : index + scf.yield %sum : index + } else { + %diff = arith.subi %tid, %c1 : index + scf.yield %diff : index + } + gpu.return + } + } +} + +//===----------------------------------------------------------------------===// +// Test 2: If-then-else nested inside scf.for loop 
+//===----------------------------------------------------------------------===// + +module { + gpu.module @test_scf_if_in_loop { + // CHECK-LABEL: waveasm.program @if_in_loop + gpu.func @if_in_loop() kernel { + %tid = gpu.thread_id x + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c10 = arith.constant 10 : index + + %cond = arith.cmpi ult, %tid, %c10 : index + + // CHECK: waveasm.loop + %result = scf.for %i = %c0 to %c4 step %c1 iter_args(%acc = %c0) -> index { + // CHECK: waveasm.if + %step = scf.if %cond -> index { + %v = arith.addi %acc, %c1 : index + scf.yield %v : index + } else { + %v = arith.addi %acc, %c0 : index + scf.yield %v : index + } + scf.yield %step : index + } + gpu.return + } + } +} + +//===----------------------------------------------------------------------===// +// Test 3: If without results (void if) +//===----------------------------------------------------------------------===// + +module { + gpu.module @test_void_if { + // CHECK-LABEL: waveasm.program @void_if + gpu.func @void_if() kernel { + %tid = gpu.thread_id x + %c10 = arith.constant 10 : index + + %cond = arith.cmpi ult, %tid, %c10 : index + + // CHECK: waveasm.if + // CHECK: } else { + // CHECK: } + scf.if %cond { + // empty + } else { + // empty + } + gpu.return + } + } +} diff --git a/waveasm/test/Translate/vector-from-elements.mlir b/waveasm/test/Translate/vector-from-elements.mlir new file mode 100644 index 000000000..f94ea50c2 --- /dev/null +++ b/waveasm/test/Translate/vector-from-elements.mlir @@ -0,0 +1,28 @@ +// RUN: waveasm-translate --target=gfx942 %s 2>&1 | FileCheck %s +// +// Test: vector.from_elements handler packs scalar elements into VGPR dwords. 
+ +module { + gpu.module @test_from_elements { + // CHECK-LABEL: waveasm.program @from_elements_i32 + gpu.func @from_elements_i32(%a: i32, %b: i32) kernel { + // Two i32 elements -> 2-dword pack + // CHECK: waveasm.pack + %v = vector.from_elements %a, %b : vector<2xi32> + gpu.return + } + } +} + +module { + gpu.module @test_from_elements_f16 { + // CHECK-LABEL: waveasm.program @from_elements_f16_pair + gpu.func @from_elements_f16_pair(%a: f16, %b: f16) kernel { + // Two f16 elements -> 1-dword pack via shift+or + // CHECK: waveasm.v_lshlrev_b32 + // CHECK: waveasm.v_or_b32 + %v = vector.from_elements %a, %b : vector<2xf16> + gpu.return + } + } +}