From 820bdca7d48d611c203c7086b29a2dd240fae4d5 Mon Sep 17 00:00:00 2001 From: ShangkunLI Date: Thu, 15 Jan 2026 11:52:54 +0800 Subject: [PATCH 1/8] enbale baseline hyperblock fusion --- include/TaskflowDialect/TaskflowPasses.h | 1 + include/TaskflowDialect/TaskflowPasses.td | 12 ++ lib/TaskflowDialect/Transforms/CMakeLists.txt | 1 + .../FuseHyperblockToPipelinableTaskPass.cpp | 117 ++++++++++++++++++ 4 files changed, 131 insertions(+) create mode 100644 lib/TaskflowDialect/Transforms/FuseHyperblockToPipelinableTaskPass.cpp diff --git a/include/TaskflowDialect/TaskflowPasses.h b/include/TaskflowDialect/TaskflowPasses.h index f621951..4a623ec 100644 --- a/include/TaskflowDialect/TaskflowPasses.h +++ b/include/TaskflowDialect/TaskflowPasses.h @@ -16,6 +16,7 @@ namespace taskflow { #define GEN_PASS_DECL #include "TaskflowDialect/TaskflowPasses.h.inc" std::unique_ptr createConstructHyperblockFromTaskPass(); +std::unique_ptr createFuseHyperblockToPipelinableTaskPass(); #define GEN_PASS_REGISTRATION #include "TaskflowDialect/TaskflowPasses.h.inc" diff --git a/include/TaskflowDialect/TaskflowPasses.td b/include/TaskflowDialect/TaskflowPasses.td index 1bcf3b2..6daf938 100644 --- a/include/TaskflowDialect/TaskflowPasses.td +++ b/include/TaskflowDialect/TaskflowPasses.td @@ -15,4 +15,16 @@ def ConstructHyperblockFromTask : Pass<"construct-hyperblock-from-task", "func:: }]; let constructor = "taskflow::createConstructHyperblockFromTaskPass()"; } + +def FuseHyperblockToPipelinableTask : Pass<"fuse-hyperblock-to-pipelinable-task", "func::FuncOp"> { + let summary = "Fuses hyperblocks into pipelinable tasks"; + let description = [{ + This pass fuses hyperblock into pipelinable tasks by analyzing data dependencies + and segments memory access dependencies into different tasks. + + Conservative hyperblock fusion: split hyperblocks with memory access dependencies into different tasks, + ensure each task has exactly one hyperblock" + }]; + let constructor = "taskflow::createFuseHyperblockToPipelinableTaskPass()"; +} #endif // TASKFLOW_PASSES_TD \ No newline at end of file diff --git a/lib/TaskflowDialect/Transforms/CMakeLists.txt b/lib/TaskflowDialect/Transforms/CMakeLists.txt index 270ce96..f2abb43 100644 --- a/lib/TaskflowDialect/Transforms/CMakeLists.txt +++ b/lib/TaskflowDialect/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) add_mlir_library(MLIRTaskflowTransforms ConstructHyperblockFromTaskPass.cpp + FuseHyperblockToPipelinableTaskPass.cpp DEPENDS MLIRTaskflowTransformsIncGen diff --git a/lib/TaskflowDialect/Transforms/FuseHyperblockToPipelinableTaskPass.cpp b/lib/TaskflowDialect/Transforms/FuseHyperblockToPipelinableTaskPass.cpp new file mode 100644 index 0000000..33b1dad --- /dev/null +++ b/lib/TaskflowDialect/Transforms/FuseHyperblockToPipelinableTaskPass.cpp @@ -0,0 +1,117 @@ +#include "TaskflowDialect/TaskflowDialect.h" +#include "TaskflowDialect/TaskflowOps.h" +#include "TaskflowDialect/TaskflowPasses.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::taskflow; + +namespace { +//--------------------------------------------------------------------------- +// Memory Access Analysis +//---------------------------------------------------------------------------- +struct MemoryAccessInfo { + SetVector read_memrefs; + SetVector write_memrefs; + SmallVector counter_indices; +}; + +//--------------------------------------------------------------------------- +// Hyperblocks Grouping +//---------------------------------------------------------------------------- + +struct HyperblockGroup { + SmallVector hyperblocks; + SmallVector shared_indices; + SetVector all_read_memrefs; + SetVector all_write_memrefs; + + void addHyperblock(TaskflowHyperblockOp hb_op, + const MemoryAccessInfo &mem_info) { + this->hyperblocks.push_back(hb_op); + if (this->shared_indices.empty()) { + this->shared_indices = mem_info.counter_indices; + } + all_read_memrefs.insert(mem_info.read_memrefs.begin(), + mem_info.read_memrefs.end()); + all_write_memrefs.insert(mem_info.write_memrefs.begin(), + mem_info.write_memrefs.end()); + } + + bool canAddHyperblock(const MemoryAccessInfo &mem_info) const {} +}; + +// Groups hyperblocks that can be fused together. +static SmallVector +groupHyperblocks(SmallVector &hyperblocks) { + SmallVector groups; + DenseMap hb_to_meminfo_map; +} + +struct FuseHyperblockToPipelinableTaskPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + FuseHyperblockToPipelinableTaskPass) + + StringRef getArgument() const final { + return "fuse-hyperblock-to-pipelinable-task"; + } + + StringRef getDescription() const final { + return "Conservative hyperblock fusion: split hyperblocks with memory " + "access dependencies into different tasks, ensure each task has " + "exactly one hyperblock"; + }; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + func::FuncOp func_op = getOperation(); + + // Collects all tasks. + SmallVector tasks; + func_op.walk([&](TaskflowTaskOp task_op) { tasks.push_back(task_op); }); + + // Process each task. + for (TaskflowTaskOp task_op : tasks) { + SmallVector hyperblocks; + task_op.walk( + [&](TaskflowHyperblockOp hb_op) { hyperblocks.push_back(hb_op); }); + + llvm::errs() << "Found " << hyperblocks.size() << " hyperblocks in task " + << task_op.getTaskName() << "\n"; + if (hyperblocks.size() <= 1) { + llvm::errs() << "Task already has <=1 hyperblock, skip.\n"; + continue; + } + + // Group hyperblocks that can be fused together (Do not have memory access + // dependencies). + auto hyperblock_groups = groupHyperblocks(hyperblocks); + } + } +}; +} // namespace + +std::unique_ptr +mlir::taskflow::createFuseHyperblockToPipelinableTaskPass() { + return std::make_unique(); +} \ No newline at end of file From 3075283e79fc989475861840f7190b770a6d1c57 Mon Sep 17 00:00:00 2001 From: ShangkunLI Date: Fri, 16 Jan 2026 16:48:10 +0800 Subject: [PATCH 2/8] enbale task canonicalization --- include/TaskflowDialect/TaskflowOps.td | 12 +- include/TaskflowDialect/TaskflowPasses.h | 2 +- include/TaskflowDialect/TaskflowPasses.td | 16 +- lib/TaskflowDialect/Transforms/CMakeLists.txt | 2 +- .../Transforms/CanonicalizeTaskPass.cpp | 361 ++++++++++++++++++ .../FuseHyperblockToPipelinableTaskPass.cpp | 117 ------ 6 files changed, 378 insertions(+), 132 deletions(-) create mode 100644 lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp delete mode 100644 lib/TaskflowDialect/Transforms/FuseHyperblockToPipelinableTaskPass.cpp diff --git a/include/TaskflowDialect/TaskflowOps.td b/include/TaskflowDialect/TaskflowOps.td index 66603f7..138f272 100644 --- a/include/TaskflowDialect/TaskflowOps.td +++ b/include/TaskflowDialect/TaskflowOps.td @@ -219,12 +219,12 @@ def TaskflowHyperblockOp : TaskflowOpBase<"hyperblock",[ let regions = (region SizedRegion<1>:$body); - let assemblyFormat = [{ - (`indices` `(` $indices^ `:` type($indices) `)`)? - attr-dict-with-keyword - $body - `->` `(` type($outputs) `)` - }]; + // let assemblyFormat = [{ + // (`indices` `(` $indices^ `:` type($indices) `)`)? + // attr-dict-with-keyword + // $body + // `->` `(` type($outputs) `)` + // }]; } def TaskflowHyperblockYieldOp : TaskflowOpBase<"hyperblock.yield", [ diff --git a/include/TaskflowDialect/TaskflowPasses.h b/include/TaskflowDialect/TaskflowPasses.h index 4a623ec..50f28d0 100644 --- a/include/TaskflowDialect/TaskflowPasses.h +++ b/include/TaskflowDialect/TaskflowPasses.h @@ -16,7 +16,7 @@ namespace taskflow { #define GEN_PASS_DECL #include "TaskflowDialect/TaskflowPasses.h.inc" std::unique_ptr createConstructHyperblockFromTaskPass(); -std::unique_ptr createFuseHyperblockToPipelinableTaskPass(); +std::unique_ptr createCanonicalizeTaskPass(); #define GEN_PASS_REGISTRATION #include "TaskflowDialect/TaskflowPasses.h.inc" diff --git a/include/TaskflowDialect/TaskflowPasses.td b/include/TaskflowDialect/TaskflowPasses.td index 6daf938..4728f13 100644 --- a/include/TaskflowDialect/TaskflowPasses.td +++ b/include/TaskflowDialect/TaskflowPasses.td @@ -16,15 +16,17 @@ def ConstructHyperblockFromTask : Pass<"construct-hyperblock-from-task", "func:: let constructor = "taskflow::createConstructHyperblockFromTaskPass()"; } -def FuseHyperblockToPipelinableTask : Pass<"fuse-hyperblock-to-pipelinable-task", "func::FuncOp"> { - let summary = "Fuses hyperblocks into pipelinable tasks"; +def CanonicalizeTask: Pass<"canonicalize-task", "func::FuncOp">{ + let summary = "Canonicalizes tasks by splitting each hyperblock into a separate atomic task"; let description = [{ - This pass fuses hyperblock into pipelinable tasks by analyzing data dependencies - and segments memory access dependencies into different tasks. + This pass splits tasks so that each task contains exactly one hyperblock. + This creates atomic task units that can be analyzed and optimized independently. - Conservative hyperblock fusion: split hyperblocks with memory access dependencies into different tasks, - ensure each task has exactly one hyperblock" + Input: Task with N hyperblocks + Output: N atomic tasks, each containing one hyperblock + + This is a prerequisite pass before fusion optimizations. }]; - let constructor = "taskflow::createFuseHyperblockToPipelinableTaskPass()"; + let constructor = "taskflow::createCanonicalizeTaskPass()"; } #endif // TASKFLOW_PASSES_TD \ No newline at end of file diff --git a/lib/TaskflowDialect/Transforms/CMakeLists.txt b/lib/TaskflowDialect/Transforms/CMakeLists.txt index f2abb43..ab118c8 100644 --- a/lib/TaskflowDialect/Transforms/CMakeLists.txt +++ b/lib/TaskflowDialect/Transforms/CMakeLists.txt @@ -2,7 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) add_mlir_library(MLIRTaskflowTransforms ConstructHyperblockFromTaskPass.cpp - FuseHyperblockToPipelinableTaskPass.cpp + CanonicalizeTaskPass.cpp DEPENDS MLIRTaskflowTransformsIncGen diff --git a/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp b/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp new file mode 100644 index 0000000..0a05814 --- /dev/null +++ b/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp @@ -0,0 +1,361 @@ +#include "TaskflowDialect/TaskflowDialect.h" +#include "TaskflowDialect/TaskflowOps.h" +#include "TaskflowDialect/TaskflowPasses.h" + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Unit.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::taskflow; + +namespace { +//--------------------------------------------------------------------------- +// Memory Access Info: Information about memory accesses in a hyperblock. +//---------------------------------------------------------------------------- +struct MemoryAccessInfo { + SetVector reads; // MemRefs that are read + SetVector writes; // MemRefs that are written + + void analyze(TaskflowHyperblockOp hyperblock) { + hyperblock.walk([&](Operation *op) { + if (auto load = dyn_cast(op)) { + reads.insert(load.getMemRef()); + } else if (auto store = dyn_cast(op)) { + writes.insert(store.getMemRef()); + } + }); + } + + // Get all memrefs (reads + writes, deduplicated) + SetVector getAllMemRefs() const { + SetVector all; + all.insert(reads.begin(), reads.end()); + all.insert(writes.begin(), writes.end()); + return all; + } +}; + +//--------------------------------------------------------------------------- +// Counter Collector: Collects and sorts counter operations. +//---------------------------------------------------------------------------- +class CounterCollector { +public: + // Collect all counters needed by a hyperblock (including parents) + void collect(TaskflowHyperblockOp hyperblock) { + for (Value idx : hyperblock.getIndices()) { + collectRecursively(idx); + } + } + + // Get counters sorted by depth (parents first) + SmallVector getSortedCounters() const { + SmallVector result(counters.begin(), counters.end()); + llvm::sort(result, [this](TaskflowCounterOp a, TaskflowCounterOp b) { + return getDepth(a) < getDepth(b); + }); + return result; + } + +private: + void collectRecursively(Value idx) { + auto counter = idx.getDefiningOp(); + if (!counter) + return; + + counters.insert(counter); + + if (Value parent = counter.getParentIndex()) { + collectRecursively(parent); + } + } + + size_t getDepth(TaskflowCounterOp counter) const { + size_t depth = 0; + Value parent = counter.getParentIndex(); + while (parent) { + depth++; + if (auto p = parent.getDefiningOp()) { + parent = p.getParentIndex(); + } else { + break; + } + } + return depth; + } + + SetVector counters; +}; + +//--------------------------------------------------------------------------- +// Block Argument Resolver: Resolves block arguments to their source values. +//--------------------------------------------------------------------------- +class BlockArgResolver { +public: + explicit BlockArgResolver(TaskflowTaskOp task) { + Block *body = &task.getBody().front(); + auto inputs = task.getMemoryInputs(); + auto args = body->getArguments(); + + for (auto [input, arg] : llvm::zip(inputs, args)) { + blockArgToSource[arg] = input; + sourceToBlockArg[input] = arg; + } + } + + // Given a value (possibly a block arg), return the source memref + Value resolveToSource(Value val) const { + auto it = blockArgToSource.find(val); + return it != blockArgToSource.end() ? it->second : val; + } + + // Given a source memref, return the block argument + Value getBlockArg(Value source) const { + auto it = sourceToBlockArg.find(source); + return it != sourceToBlockArg.end() ? it->second : Value(); + } + +private: + DenseMap blockArgToSource; + DenseMap sourceToBlockArg; +}; + +//--------------------------------------------------------------------------- +// Atomic Task Builder: Builds an atomic task from a single hyperblock. +//---------------------------------------------------------------------------- +class AtomicTaskBuilder { +public: + AtomicTaskBuilder(OpBuilder &builder, Location loc, unsigned global_task_idx, + DenseMap &memref_to_latest_version) + : builder(builder), loc(loc), global_task_idx(global_task_idx), + memref_to_latest_version(memref_to_latest_version) {} + TaskflowTaskOp build(TaskflowHyperblockOp hyperblock, + TaskflowTaskOp originalTask) { + // Step 1: Analyze memory accesses + MemoryAccessInfo memInfo; + memInfo.analyze(hyperblock); + + // Step 2: Resolve block arguments to source memrefs + BlockArgResolver resolver(originalTask); + + // Step 3: Determine task inputs (use latest versions) + SmallVector taskInputs; + DenseMap sourceToInputIdx; + + for (Value memref : memInfo.getAllMemRefs()) { + Value source = resolver.resolveToSource(memref); + Value inputVal = getLatestVersion(source); + + // Avoid duplicates + if (!sourceToInputIdx.count(source)) { + sourceToInputIdx[source] = taskInputs.size(); + taskInputs.push_back(inputVal); + } + } + + // Step 4: Determine task outputs (written memrefs) + SmallVector outputTypes; + SmallVector writtenSources; + + for (Value memref : memInfo.writes) { + Value source = resolver.resolveToSource(memref); + outputTypes.push_back(source.getType()); + writtenSources.push_back(source); + } + + // Step 5: Create the task operation + std::string taskName = "Task_" + std::to_string(this->global_task_idx); + auto newTask = builder.create( + loc, outputTypes, TypeRange{}, taskInputs, ValueRange{}, + builder.getStringAttr(taskName)); + + // Step 6: Create task body + Block *taskBody = new Block(); + newTask.getBody().push_back(taskBody); + + for (Value input : taskInputs) { + taskBody->addArgument(input.getType(), loc); + } + + // Step 7: Build value mapping + IRMapping mapping; + + // Map source memrefs -> new task's block arguments + for (auto [source, idx] : sourceToInputIdx) { + BlockArgument newArg = taskBody->getArgument(idx); + mapping.map(source, newArg); + + // Also map original block arguments that refer to this source + if (Value origArg = resolver.getBlockArg(source)) { + mapping.map(origArg, newArg); + } + } + + // Step 8: Clone counters and hyperblock + OpBuilder taskBuilder(taskBody, taskBody->begin()); + cloneCounters(taskBuilder, hyperblock, mapping); + cloneHyperblock(taskBuilder, hyperblock, mapping); + + // Step 9: Create yield + SmallVector yieldOperands; + for (Value memref : memInfo.writes) { + yieldOperands.push_back(mapping.lookupOrDefault(memref)); + } + taskBuilder.setInsertionPointToEnd(taskBody); + taskBuilder.create(loc, yieldOperands, ValueRange{}); + + // Step 10: Update latest versions + auto outputs = newTask.getMemoryOutputs(); + for (auto [source, output] : llvm::zip(writtenSources, outputs)) { + this->memref_to_latest_version[source] = output; + } + + return newTask; + } + +private: + Value getLatestVersion(Value source) { + auto it = this->memref_to_latest_version.find(source); + return it != this->memref_to_latest_version.end() ? it->second : source; + } + + void cloneCounters(OpBuilder &taskBuilder, TaskflowHyperblockOp hyperblock, + IRMapping &mapping) { + CounterCollector collector; + collector.collect(hyperblock); + + for (TaskflowCounterOp counter : collector.getSortedCounters()) { + taskBuilder.clone(*counter.getOperation(), mapping); + } + } + + void cloneHyperblock(OpBuilder &taskBuilder, TaskflowHyperblockOp hyperblock, + IRMapping &mapping) { + // Map indices + SmallVector mappedIndices; + for (Value idx : hyperblock.getIndices()) { + mappedIndices.push_back(mapping.lookupOrDefault(idx)); + } + + // Create new hyperblock + SmallVector outputTypes(hyperblock.getOutputs().getTypes()); + auto newHB = taskBuilder.create(loc, outputTypes, + mappedIndices); + + // Create body + Block *newBody = new Block(); + newHB.getBody().push_back(newBody); + + for (Value idx : mappedIndices) { + newBody->addArgument(idx.getType(), loc); + } + + // Map old block args -> new block args + Block *oldBody = &hyperblock.getBody().front(); + for (auto [oldArg, newArg] : + llvm::zip(oldBody->getArguments(), newBody->getArguments())) { + mapping.map(oldArg, newArg); + } + + // Clone operations + OpBuilder hbBuilder(newBody, newBody->begin()); + for (Operation &op : oldBody->without_terminator()) { + hbBuilder.clone(op, mapping); + } + + // Clone terminator + if (auto yield = + dyn_cast(oldBody->getTerminator())) { + SmallVector yieldOps; + for (Value v : yield.getOutputs()) { + yieldOps.push_back(mapping.lookupOrDefault(v)); + } + hbBuilder.create(loc, yieldOps); + } else { + hbBuilder.create(loc, ValueRange{}); + } + } + OpBuilder &builder; + Location loc; + unsigned global_task_idx; + DenseMap &memref_to_latest_version; +}; + +//--------------------------------------------------------------------------- +// Canonicalize Task Pass +//---------------------------------------------------------------------------- +struct CanonicalizeTaskPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CanonicalizeTaskPass) + + StringRef getArgument() const final { return "canonicalize-task"; } + + StringRef getDescription() const final { + return "Canonicalizes tasks by splitting each hyperblock into a separate " + "atomic task (one hyperblock per task)"; + }; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + func::FuncOp func_op = getOperation(); + + // Collects all tasks. + SmallVector tasks_to_process; + func_op.walk( + [&](TaskflowTaskOp task_op) { tasks_to_process.push_back(task_op); }); + + unsigned global_task_idx = 0; + + for (TaskflowTaskOp original_task : tasks_to_process) { + // Collects hyperblocks. + SmallVector hyperblocks; + original_task.walk( + [&](TaskflowHyperblockOp hb) { hyperblocks.push_back(hb); }); + assert(!hyperblocks.empty() && + "Expected at least one hyperblock in the task"); + if (hyperblocks.size() == 1) { + // No need to canonicalize single-hyperblock tasks. + continue; + } + + // Tracks latest versions of memrefs for dependency chaining. + DenseMap memref_to_latest_version; + + // Creates atomic tasks for each hyperblock. + OpBuilder builder(original_task); + + for (TaskflowHyperblockOp hb : hyperblocks) { + AtomicTaskBuilder task_builder(builder, original_task.getLoc(), + global_task_idx, + memref_to_latest_version); + task_builder.build(hb, original_task); + } + + // Erases the original task. + original_task.erase(); + } + } +}; +} // namespace + +std::unique_ptr mlir::taskflow::createCanonicalizeTaskPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/lib/TaskflowDialect/Transforms/FuseHyperblockToPipelinableTaskPass.cpp b/lib/TaskflowDialect/Transforms/FuseHyperblockToPipelinableTaskPass.cpp deleted file mode 100644 index 33b1dad..0000000 --- a/lib/TaskflowDialect/Transforms/FuseHyperblockToPipelinableTaskPass.cpp +++ /dev/null @@ -1,117 +0,0 @@ -#include "TaskflowDialect/TaskflowDialect.h" -#include "TaskflowDialect/TaskflowOps.h" -#include "TaskflowDialect/TaskflowPasses.h" - -#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/LogicalResult.h" -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; -using namespace mlir::taskflow; - -namespace { -//--------------------------------------------------------------------------- -// Memory Access Analysis -//---------------------------------------------------------------------------- -struct MemoryAccessInfo { - SetVector read_memrefs; - SetVector write_memrefs; - SmallVector counter_indices; -}; - -//--------------------------------------------------------------------------- -// Hyperblocks Grouping -//---------------------------------------------------------------------------- - -struct HyperblockGroup { - SmallVector hyperblocks; - SmallVector shared_indices; - SetVector all_read_memrefs; - SetVector all_write_memrefs; - - void addHyperblock(TaskflowHyperblockOp hb_op, - const MemoryAccessInfo &mem_info) { - this->hyperblocks.push_back(hb_op); - if (this->shared_indices.empty()) { - this->shared_indices = mem_info.counter_indices; - } - all_read_memrefs.insert(mem_info.read_memrefs.begin(), - mem_info.read_memrefs.end()); - all_write_memrefs.insert(mem_info.write_memrefs.begin(), - mem_info.write_memrefs.end()); - } - - bool canAddHyperblock(const MemoryAccessInfo &mem_info) const {} -}; - -// Groups hyperblocks that can be fused together. -static SmallVector -groupHyperblocks(SmallVector &hyperblocks) { - SmallVector groups; - DenseMap hb_to_meminfo_map; -} - -struct FuseHyperblockToPipelinableTaskPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - FuseHyperblockToPipelinableTaskPass) - - StringRef getArgument() const final { - return "fuse-hyperblock-to-pipelinable-task"; - } - - StringRef getDescription() const final { - return "Conservative hyperblock fusion: split hyperblocks with memory " - "access dependencies into different tasks, ensure each task has " - "exactly one hyperblock"; - }; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - func::FuncOp func_op = getOperation(); - - // Collects all tasks. - SmallVector tasks; - func_op.walk([&](TaskflowTaskOp task_op) { tasks.push_back(task_op); }); - - // Process each task. - for (TaskflowTaskOp task_op : tasks) { - SmallVector hyperblocks; - task_op.walk( - [&](TaskflowHyperblockOp hb_op) { hyperblocks.push_back(hb_op); }); - - llvm::errs() << "Found " << hyperblocks.size() << " hyperblocks in task " - << task_op.getTaskName() << "\n"; - if (hyperblocks.size() <= 1) { - llvm::errs() << "Task already has <=1 hyperblock, skip.\n"; - continue; - } - - // Group hyperblocks that can be fused together (Do not have memory access - // dependencies). - auto hyperblock_groups = groupHyperblocks(hyperblocks); - } - } -}; -} // namespace - -std::unique_ptr -mlir::taskflow::createFuseHyperblockToPipelinableTaskPass() { - return std::make_unique(); -} \ No newline at end of file From ad288fdf0b894397f7ee88e9cdbe25a6a612bcc2 Mon Sep 17 00:00:00 2001 From: ShangkunLI Date: Fri, 16 Jan 2026 21:50:01 +0800 Subject: [PATCH 3/8] fix bugs in construct-hyperblock-from-task pass --- include/TaskflowDialect/TaskflowOps.td | 10 +- .../AffineToTaskflow/AffineToTaskflowPass.cpp | 81 +++++----- .../Transforms/CanonicalizeTaskPass.cpp | 151 +++++++++++------- .../ConstructHyperblockFromTaskPass.cpp | 150 +++++++++++++---- 4 files changed, 259 insertions(+), 133 deletions(-) diff --git a/include/TaskflowDialect/TaskflowOps.td b/include/TaskflowDialect/TaskflowOps.td index 138f272..2e6159a 100644 --- a/include/TaskflowDialect/TaskflowOps.td +++ b/include/TaskflowDialect/TaskflowOps.td @@ -190,6 +190,7 @@ def TaskflowCounterOp : TaskflowOpBase<"counter", [Pure]>{ def TaskflowHyperblockOp : TaskflowOpBase<"hyperblock",[ AutomaticAllocationScope, + AttrSizedOperandSegments, SingleBlockImplicitTerminator<"TaskflowHyperblockYieldOp"> ]>{ let summary = "Hyperblock operation containing loop body computation"; @@ -201,16 +202,17 @@ def TaskflowHyperblockOp : TaskflowOpBase<"hyperblock",[ If the hyperblock has a return value, it must return the final value produced by the hyperblock (i.e., from the last iteration). Example: - %result = taskflow.hyperblock indices(%i : index) { - ^bb0(%idx: index): - // Loop body computation using %idx + %result = taskflow.hyperblock indices(%i : index), iter_args(%init_val : i32) { + ^bb0(%idx: index, %arg: i32): + // Loop body computation using %idx and %arg ... taskflow.hyperblock.yield %output : i32 } -> i32 }]; let arguments = (ins - Variadic:$indices + Variadic:$indices, + Variadic:$iter_args ); let results = (outs diff --git a/lib/Conversion/AffineToTaskflow/AffineToTaskflowPass.cpp b/lib/Conversion/AffineToTaskflow/AffineToTaskflowPass.cpp index f628364..111dec0 100644 --- a/lib/Conversion/AffineToTaskflow/AffineToTaskflowPass.cpp +++ b/lib/Conversion/AffineToTaskflow/AffineToTaskflowPass.cpp @@ -31,21 +31,6 @@ namespace { // Helper Functions. //------------------------------------------------------------------------------ -// Collects all top-level affine.for operations in a function. -static SmallVector -collectTopLevelLooops(func::FuncOp func_op) { - SmallVector top_level_loops; - for (Block &block : func_op.getBlocks()) { - for (Operation &op : block) { - if (auto for_op = dyn_cast(op)) { - top_level_loops.push_back(for_op); - } - } - } - - return top_level_loops; -} - // Collects memrefs that are loaded (read) within a given operation scope. static void collectReadMemrefs(Operation *op, SetVector &read_memrefs) { op->walk([&](Operation *nested_op) { @@ -106,6 +91,19 @@ static void collectExternalValues(Operation *root_op, } } +// Updates operands of an operation using the value mapping. +static void +updateOperationOperands(Operation *op, + const DenseMap &value_mapping) { + for (OpOperand &operand : op->getOpOperands()) { + Value original_value = operand.get(); + auto it = value_mapping.find(original_value); + if (it != value_mapping.end()) { + operand.set(it->second); + } + } +} + //------------------------------------------------------------------------------ // Task Conversion //------------------------------------------------------------------------------ @@ -284,42 +282,45 @@ static TaskflowTaskOp convertLoopToTask(OpBuilder &builder, //------------------------------------------------------------------------------ // Converts a single function to TaskFlow operations. static LogicalResult convertFuncToTaskflow(func::FuncOp func_op) { - // Collects top-level loops for conversion. - SmallVector top_level_loops = - collectTopLevelLooops(func_op); - - if (top_level_loops.empty()) { - // No loops to convert. - llvm::errs() << "No top-level affine.for loops found in function '" - << func_op.getName() << "'.\n"; - return success(); - } llvm::errs() << "\n===Converting function: " << func_op.getName() << "===\n"; - llvm::errs() << "Found " << top_level_loops.size() - << " top-level affine.for loops to convert:\n"; - for (affine::AffineForOp for_op : top_level_loops) { - llvm::errs() << for_op.getLoc() << "\n"; - } OpBuilder builder(func_op.getContext()); + SmallVector loops_to_erase; DenseMap value_mapping; + int task_id_counter = 0; - // Converts each top-level loop to taskflow.task operation. - for (auto [idx, loop] : llvm::enumerate(top_level_loops)) { - builder.setInsertionPoint(loop); - TaskflowTaskOp task_op = - convertLoopToTask(builder, loop, value_mapping, idx); + // Processes each block in the function. + for (Block &block : func_op.getBlocks()) { + // Collects operations to process (to avoid iterator invalidation). + SmallVector ops_to_process; + for (Operation &op : block) { + ops_to_process.push_back(&op); + } - // Replaces uses of loop results with task value outputs. - for (auto [loop_result, task_value_output] : - llvm::zip(loop.getResults(), task_op.getValueOutputs())) { - loop_result.replaceAllUsesWith(task_value_output); + // Processes each operation in order (top to bottom). + for (Operation *op : ops_to_process) { + if (auto for_op = dyn_cast(op)) { + // Converts affine.for to taskflow.task. + OpBuilder builder(for_op); + TaskflowTaskOp task_op = convertLoopToTask( + builder, for_op, value_mapping, task_id_counter++); + + // Replaces uses of loop results with task value outputs. + for (auto [loop_result, task_value_output] : + llvm::zip(for_op.getResults(), task_op.getValueOutputs())) { + loop_result.replaceAllUsesWith(task_value_output); + } + loops_to_erase.push_back(for_op); + } else { + // Updates operands of non-loop operations based on value_mapping. + updateOperationOperands(op, value_mapping); + } } } // Erases the original loops after conversion. - for (affine::AffineForOp for_op : top_level_loops) { + for (affine::AffineForOp for_op : loops_to_erase) { for_op.erase(); } diff --git a/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp b/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp index 0a05814..e00281b 100644 --- a/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp +++ b/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp @@ -24,12 +24,13 @@ using namespace mlir; using namespace mlir::taskflow; namespace { -//--------------------------------------------------------------------------- -// Memory Access Info: Information about memory accesses in a hyperblock. -//---------------------------------------------------------------------------- +//===----------------------------------------------------------------------===// +// Memory Access Info +//===----------------------------------------------------------------------===// + struct MemoryAccessInfo { - SetVector reads; // MemRefs that are read - SetVector writes; // MemRefs that are written + SetVector reads; + SetVector writes; void analyze(TaskflowHyperblockOp hyperblock) { hyperblock.walk([&](Operation *op) { @@ -41,7 +42,6 @@ struct MemoryAccessInfo { }); } - // Get all memrefs (reads + writes, deduplicated) SetVector getAllMemRefs() const { SetVector all; all.insert(reads.begin(), reads.end()); @@ -50,19 +50,18 @@ struct MemoryAccessInfo { } }; -//--------------------------------------------------------------------------- -// Counter Collector: Collects and sorts counter operations. -//---------------------------------------------------------------------------- +//===----------------------------------------------------------------------===// +// Counter Collector +//===----------------------------------------------------------------------===// + class CounterCollector { public: - // Collect all counters needed by a hyperblock (including parents) void collect(TaskflowHyperblockOp hyperblock) { for (Value idx : hyperblock.getIndices()) { collectRecursively(idx); } } - // Get counters sorted by depth (parents first) SmallVector getSortedCounters() const { SmallVector result(counters.begin(), counters.end()); llvm::sort(result, [this](TaskflowCounterOp a, TaskflowCounterOp b) { @@ -76,9 +75,7 @@ class CounterCollector { auto counter = idx.getDefiningOp(); if (!counter) return; - counters.insert(counter); - if (Value parent = counter.getParentIndex()) { collectRecursively(parent); } @@ -101,9 +98,10 @@ class CounterCollector { SetVector counters; }; -//--------------------------------------------------------------------------- -// Block Argument Resolver: Resolves block arguments to their source values. -//--------------------------------------------------------------------------- +//===----------------------------------------------------------------------===// +// Block Argument Resolver +//===----------------------------------------------------------------------===// + class BlockArgResolver { public: explicit BlockArgResolver(TaskflowTaskOp task) { @@ -117,13 +115,11 @@ class BlockArgResolver { } } - // Given a value (possibly a block arg), return the source memref Value resolveToSource(Value val) const { auto it = blockArgToSource.find(val); return it != blockArgToSource.end() ? it->second : val; } - // Given a source memref, return the block argument Value getBlockArg(Value source) const { auto it = sourceToBlockArg.find(source); return it != sourceToBlockArg.end() ? it->second : Value(); @@ -134,25 +130,25 @@ class BlockArgResolver { DenseMap sourceToBlockArg; }; -//--------------------------------------------------------------------------- -// Atomic Task Builder: Builds an atomic task from a single hyperblock. -//---------------------------------------------------------------------------- +//===----------------------------------------------------------------------===// +// Atomic Task Builder +//===----------------------------------------------------------------------===// + class AtomicTaskBuilder { public: AtomicTaskBuilder(OpBuilder &builder, Location loc, unsigned global_task_idx, DenseMap &memref_to_latest_version) : builder(builder), loc(loc), global_task_idx(global_task_idx), memref_to_latest_version(memref_to_latest_version) {} + TaskflowTaskOp build(TaskflowHyperblockOp hyperblock, TaskflowTaskOp originalTask) { - // Step 1: Analyze memory accesses MemoryAccessInfo memInfo; memInfo.analyze(hyperblock); - // Step 2: Resolve block arguments to source memrefs BlockArgResolver resolver(originalTask); - // Step 3: Determine task inputs (use latest versions) + // Determine task inputs SmallVector taskInputs; DenseMap sourceToInputIdx; @@ -160,14 +156,13 @@ class AtomicTaskBuilder { Value source = resolver.resolveToSource(memref); Value inputVal = getLatestVersion(source); - // Avoid duplicates if (!sourceToInputIdx.count(source)) { sourceToInputIdx[source] = taskInputs.size(); taskInputs.push_back(inputVal); } } - // Step 4: Determine task outputs (written memrefs) + // Determine task outputs SmallVector outputTypes; SmallVector writtenSources; @@ -177,13 +172,13 @@ class AtomicTaskBuilder { writtenSources.push_back(source); } - // Step 5: Create the task operation - std::string taskName = "Task_" + std::to_string(this->global_task_idx); + // Create task + std::string taskName = "Task_" + std::to_string(global_task_idx); auto newTask = builder.create( loc, outputTypes, TypeRange{}, taskInputs, ValueRange{}, builder.getStringAttr(taskName)); - // Step 6: Create task body + // Create task body Block *taskBody = new Block(); newTask.getBody().push_back(taskBody); @@ -191,26 +186,24 @@ class AtomicTaskBuilder { taskBody->addArgument(input.getType(), loc); } - // Step 7: Build value mapping + // Build value mapping IRMapping mapping; - // Map source memrefs -> new task's block arguments for (auto [source, idx] : sourceToInputIdx) { BlockArgument newArg = taskBody->getArgument(idx); mapping.map(source, newArg); - // Also map original block arguments that refer to this source if (Value origArg = resolver.getBlockArg(source)) { mapping.map(origArg, newArg); } } - // Step 8: Clone counters and hyperblock + // Clone counters and hyperblock OpBuilder taskBuilder(taskBody, taskBody->begin()); cloneCounters(taskBuilder, hyperblock, mapping); cloneHyperblock(taskBuilder, hyperblock, mapping); - // Step 9: Create yield + // Create yield SmallVector yieldOperands; for (Value memref : memInfo.writes) { yieldOperands.push_back(mapping.lookupOrDefault(memref)); @@ -218,10 +211,10 @@ class AtomicTaskBuilder { taskBuilder.setInsertionPointToEnd(taskBody); taskBuilder.create(loc, yieldOperands, ValueRange{}); - // Step 10: Update latest versions + // Update latest versions auto outputs = newTask.getMemoryOutputs(); for (auto [source, output] : llvm::zip(writtenSources, outputs)) { - this->memref_to_latest_version[source] = output; + memref_to_latest_version[source] = output; } return newTask; @@ -229,8 +222,8 @@ class AtomicTaskBuilder { private: Value getLatestVersion(Value source) { - auto it = this->memref_to_latest_version.find(source); - return it != this->memref_to_latest_version.end() ? it->second : source; + auto it = memref_to_latest_version.find(source); + return it != memref_to_latest_version.end() ? it->second : source; } void cloneCounters(OpBuilder &taskBuilder, TaskflowHyperblockOp hyperblock, @@ -245,18 +238,15 @@ class AtomicTaskBuilder { void cloneHyperblock(OpBuilder &taskBuilder, TaskflowHyperblockOp hyperblock, IRMapping &mapping) { - // Map indices SmallVector mappedIndices; for (Value idx : hyperblock.getIndices()) { mappedIndices.push_back(mapping.lookupOrDefault(idx)); } - // Create new hyperblock SmallVector outputTypes(hyperblock.getOutputs().getTypes()); auto newHB = taskBuilder.create(loc, outputTypes, mappedIndices); - // Create body Block *newBody = new Block(); newHB.getBody().push_back(newBody); @@ -264,20 +254,17 @@ class AtomicTaskBuilder { newBody->addArgument(idx.getType(), loc); } - // Map old block args -> new block args Block *oldBody = &hyperblock.getBody().front(); for (auto [oldArg, newArg] : llvm::zip(oldBody->getArguments(), newBody->getArguments())) { mapping.map(oldArg, newArg); } - // Clone operations OpBuilder hbBuilder(newBody, newBody->begin()); for (Operation &op : oldBody->without_terminator()) { hbBuilder.clone(op, mapping); } - // Clone terminator if (auto yield = dyn_cast(oldBody->getTerminator())) { SmallVector yieldOps; @@ -289,15 +276,17 @@ class AtomicTaskBuilder { hbBuilder.create(loc, ValueRange{}); } } + OpBuilder &builder; Location loc; unsigned global_task_idx; DenseMap &memref_to_latest_version; }; -//--------------------------------------------------------------------------- -// Canonicalize Task Pass -//---------------------------------------------------------------------------- +//===----------------------------------------------------------------------===// +// Pass Implementation +//===----------------------------------------------------------------------===// + struct CanonicalizeTaskPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CanonicalizeTaskPass) @@ -307,7 +296,7 @@ struct CanonicalizeTaskPass StringRef getDescription() const final { return "Canonicalizes tasks by splitting each hyperblock into a separate " "atomic task (one hyperblock per task)"; - }; + } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert tasks_to_process; func_op.walk( [&](TaskflowTaskOp task_op) { tasks_to_process.push_back(task_op); }); @@ -325,35 +313,82 @@ struct CanonicalizeTaskPass unsigned global_task_idx = 0; for (TaskflowTaskOp original_task : tasks_to_process) { - // Collects hyperblocks. SmallVector hyperblocks; original_task.walk( [&](TaskflowHyperblockOp hb) { hyperblocks.push_back(hb); }); + assert(!hyperblocks.empty() && "Expected at least one hyperblock in the task"); + if (hyperblocks.size() == 1) { - // No need to canonicalize single-hyperblock tasks. continue; } - // Tracks latest versions of memrefs for dependency chaining. - DenseMap memref_to_latest_version; + //===----------------------------------------------------------------===// + // Step 1: Build mapping from original task's memory outputs to their + // corresponding source memrefs (the original inputs). + //===----------------------------------------------------------------===// + + // Get the yield operation to find which memrefs are yielded + auto yield_op = cast( + original_task.getBody().front().getTerminator()); + auto original_outputs = original_task.getMemoryOutputs(); + auto yielded_memrefs = yield_op.getMemoryResults(); + + // Map: yielded block argument -> original task output + DenseMap yielded_to_output; + for (auto [yielded, output] : + llvm::zip(yielded_memrefs, original_outputs)) { + yielded_to_output[yielded] = output; + } + + // Map: original input memref -> original task output (if it's yielded) + // This tells us which original outputs correspond to which input memrefs + Block *orig_body = &original_task.getBody().front(); + auto orig_mem_inputs = original_task.getMemoryInputs(); + DenseMap source_to_original_output; + + for (auto [input, arg] : + llvm::zip(orig_mem_inputs, orig_body->getArguments())) { + if (yielded_to_output.count(arg)) { + source_to_original_output[input] = yielded_to_output[arg]; + } + } + + //===----------------------------------------------------------------===// + // Step 2: Create atomic tasks for each hyperblock. + //===----------------------------------------------------------------===// - // Creates atomic tasks for each hyperblock. + DenseMap memref_to_latest_version; OpBuilder builder(original_task); - for (TaskflowHyperblockOp hb : hyperblocks) { + for (size_t i = 0; i < hyperblocks.size(); ++i) { AtomicTaskBuilder task_builder(builder, original_task.getLoc(), - global_task_idx, + global_task_idx++, memref_to_latest_version); - task_builder.build(hb, original_task); + task_builder.build(hyperblocks[i], original_task); } - // Erases the original task. + //===----------------------------------------------------------------===// + // Step 3: Replace uses of original task outputs with the latest versions. + //===----------------------------------------------------------------===// + + for (auto [source, original_output] : source_to_original_output) { + if (memref_to_latest_version.count(source)) { + Value latest = memref_to_latest_version[source]; + original_output.replaceAllUsesWith(latest); + } + } + + //===----------------------------------------------------------------===// + // Step 4: Erase the original task. + //===----------------------------------------------------------------===// + original_task.erase(); } } }; + } // namespace std::unique_ptr mlir::taskflow::createCanonicalizeTaskPass() { diff --git a/lib/TaskflowDialect/Transforms/ConstructHyperblockFromTaskPass.cpp b/lib/TaskflowDialect/Transforms/ConstructHyperblockFromTaskPass.cpp index 7ba0506..69f0de4 100644 --- a/lib/TaskflowDialect/Transforms/ConstructHyperblockFromTaskPass.cpp +++ b/lib/TaskflowDialect/Transforms/ConstructHyperblockFromTaskPass.cpp @@ -6,11 +6,14 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" #include @@ -30,11 +33,11 @@ struct LoopInfo { int upper_bound; int step; - // For nested loops + // For nested loops. LoopInfo *parent_loop_info = nullptr; SmallVector child_loops; - // Generated counter index + // Generated counter index. Value counter_index; }; @@ -50,7 +53,7 @@ struct HyperblockInfo { // operations before any loops). SmallVector trigger_indices; - // Whther this hyperblock is nested within loops. + // Whether this hyperblock is nested within loops. bool is_loop_body = false; // The corresponding loop. @@ -175,27 +178,22 @@ getTopLevelLoopsInfo(SmallVector &loops_info) { // Hyperblock Creation //---------------------------------------------------------------------------- // Recursively extracts hyperblocks from a region. +// Key insight: Operations in a loop body that are used by nested loops should +// be inlined into the nested loop's hyperblock. static void extractHyperblocksInfoFromRegion( Region ®ion, const DenseMap &loop_info_map, SmallVector parent_indices, - SmallVector &hyperblocks_info) { + SmallVector &hyperblocks_info, + affine::AffineForOp enclosing_loop = nullptr, + SmallVector inherited_ops = {}) { Block &block = region.front(); SmallVector current_block_ops; + current_block_ops.append(inherited_ops.begin(), inherited_ops.end()); + for (Operation &op : block.getOperations()) { if (auto for_op = dyn_cast(&op)) { - // Before processing the loop, emits any accumulated operations as a - // hyperblock. - if (!current_block_ops.empty()) { - HyperblockInfo info; - info.operations = current_block_ops; - info.trigger_indices = parent_indices; - info.is_loop_body = !parent_indices.empty(); - hyperblocks_info.push_back(info); - current_block_ops.clear(); - } - // Gets the loop info. LoopInfo *loop_info = loop_info_map.lookup(for_op); assert(loop_info && "Loop not found in loop_info_map"); @@ -205,12 +203,51 @@ static void extractHyperblocksInfoFromRegion( SmallVector loop_indices = parent_indices; loop_indices.push_back(loop_info->counter_index); + // 分析哪些 current_ops 被这个循环使用 + DenseSet values_used_in_loop; + for_op.walk([&](Operation *nested_op) { + for (Value operand : nested_op->getOperands()) { + values_used_in_loop.insert(operand); + } + }); + + SmallVector ops_for_nested_loop; + SmallVector ops_not_used; + bool used_by_loop = false; + for (Operation *current_op : current_block_ops) { + for (Value result : current_op->getResults()) { + if (values_used_in_loop.contains(result)) { + used_by_loop = true; + break; + } + } + } + if (used_by_loop) { + ops_for_nested_loop.append(current_block_ops.begin(), + current_block_ops.end()); + } else { + ops_not_used.append(current_block_ops.begin(), current_block_ops.end()); + } + + // Before processing the loop, emits any accumulated operations as a + // hyperblock. + if (!ops_not_used.empty()) { + HyperblockInfo info; + info.operations = ops_not_used; + info.trigger_indices = parent_indices; + info.is_loop_body = !parent_indices.empty(); + info.loop_op = enclosing_loop; + hyperblocks_info.push_back(info); + } + // Recursively extracts hyperblocks from the loop body. extractHyperblocksInfoFromRegion(for_op.getRegion(), loop_info_map, - loop_indices, hyperblocks_info); + loop_indices, hyperblocks_info, for_op, + ops_for_nested_loop); + current_block_ops.clear(); } else if (isa(&op) || (isa(&op) && op.getOperands().empty())) { - // Skips TaskflowYieldOp and TaskflowCounterOp. + // Skips TaskflowYieldOp, TaskflowCounterOp, and empty affine.yield. continue; } else { // Regular operation, accumulates it. @@ -224,6 +261,7 @@ static void extractHyperblocksInfoFromRegion( info.operations = current_block_ops; info.trigger_indices = parent_indices; info.is_loop_body = !parent_indices.empty(); + info.loop_op = enclosing_loop; hyperblocks_info.push_back(info); current_block_ops.clear(); } @@ -299,8 +337,7 @@ determineHyperblockOutputTypes(const SmallVector &operations) { // Creates a taskflow.hyperblock operation from HyperblockInfo. static TaskflowHyperblockOp createHyperblock( - OpBuilder &builder, Location loc, const HyperblockInfo &info, - Block *task_body, + OpBuilder &builder, Location loc, HyperblockInfo &info, Block *task_body, const DenseMap &loop_info_map) { // Collects only the indices that are actually used in the hyperblock. SmallVector used_indices = @@ -310,9 +347,25 @@ static TaskflowHyperblockOp createHyperblock( SmallVector output_types = determineHyperblockOutputTypes(info.operations); + // Checks if there is a reduction in the hyperblock (with iter_args). + SmallVector iter_args_init_values; + bool is_reduction = false; + if (info.loop_op && info.loop_op.getNumIterOperands() > 0) { + is_reduction = true; + for (Value init : info.loop_op.getInits()) { + iter_args_init_values.push_back(init); + } + } // Creates the hyperblock operation. - TaskflowHyperblockOp hyperblock_op = - builder.create(loc, output_types, used_indices); + TaskflowHyperblockOp hyperblock_op; + if (is_reduction) { + hyperblock_op = builder.create( + loc, output_types, used_indices, iter_args_init_values); + } else { + hyperblock_op = builder.create( + loc, output_types, used_indices, /*iter_args=*/ValueRange{}); + } + Block *hyperblock_body = new Block(); hyperblock_op.getBody().push_back(hyperblock_body); @@ -321,6 +374,14 @@ static TaskflowHyperblockOp createHyperblock( hyperblock_body->addArgument(idx.getType(), loc); } + SmallVector iter_args_block_args; + if (is_reduction) { + for (Value init : iter_args_init_values) { + BlockArgument arg = hyperblock_body->addArgument(init.getType(), loc); + iter_args_block_args.push_back(arg); + } + } + // Clone operations into the hyperblock body. OpBuilder hyperblock_builder(hyperblock_body, hyperblock_body->begin()); IRMapping mapping; @@ -346,6 +407,17 @@ static TaskflowHyperblockOp createHyperblock( } } + // If this hyperblock comes from a loop with iter_args, maps them. + if (is_reduction) { + Block &loop_body = info.loop_op.getRegion().front(); + auto loop_iter_args = loop_body.getArguments().drop_front(1); + + for (auto [loop_iter_arg, hb_iter_arg] : + llvm::zip(loop_iter_args, iter_args_block_args)) { + mapping.map(loop_iter_arg, hb_iter_arg); + } + } + // Clones all operations and handle terminators. bool has_terminator = false; for (Operation *op : info.operations) { @@ -376,11 +448,11 @@ static TaskflowHyperblockOp createHyperblock( MLIRContext *context = hyperblock_op.getContext(); RewritePatternSet patterns(context); - populateAffineToStdConversionPatterns(patterns); ConversionTarget target(*context); target.addLegalDialect(); + func::FuncDialect, taskflow::TaskflowDialect, + scf::SCFDialect>(); target.addIllegalOp(); if (failed( @@ -432,19 +504,34 @@ static LogicalResult transformTask(TaskflowTaskOp task_op) { // Step 4: Creates taskflow.hyperblock operations for each hyperblock. builder.setInsertionPoint(first_loop_op); + // Creates hyperblock ops. + for (auto &info : hyperblocks_info) { + TaskflowHyperblockOp hyperblock_op = + createHyperblock(builder, loc, info, task_body, loop_info_map); + + // If this hyperblock has outputs and belongs to a loop with iter_args, + // replace the loop results with the hyperblock outputs. + if (info.loop_op && info.loop_op.getNumResults() > 0 && + (hyperblock_op.getNumResults() == info.loop_op.getNumResults())) { + auto loop_results = info.loop_op.getResults(); + auto hyperblock_results = hyperblock_op.getOutputs(); + + for (auto [loop_result, hb_result] : + llvm::zip(loop_results, hyperblock_results)) { + loop_result.replaceAllUsesWith(hb_result); + } + } + } + + // Step 6: Collects and erases original loop operations. // Collects all operations to erase. SmallVector ops_to_erase; for (Operation &op : llvm::make_early_inc_range(task_body->getOperations())) { - if (!isa(&op)) { + if (!isa(&op)) { ops_to_erase.push_back(&op); } } - // Creates hyperblock ops. - for (const auto &info : hyperblocks_info) { - createHyperblock(builder, loc, info, task_body, loop_info_map); - } - // Erases original operations. for (Operation *op : ops_to_erase) { op->erase(); @@ -467,8 +554,9 @@ struct ConstructHyperblockFromTaskPass } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); } void runOnOperation() override { From 2c105db6cc104518e65eb35c8f8fe134a703a619 Mon Sep 17 00:00:00 2001 From: ShangkunLI Date: Sat, 17 Jan 2026 09:56:07 +0800 Subject: [PATCH 4/8] modify the canonicalize-task pass --- .../Transforms/CanonicalizeTaskPass.cpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp b/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp index e00281b..6bf3c3c 100644 --- a/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp +++ b/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp @@ -243,9 +243,14 @@ class AtomicTaskBuilder { mappedIndices.push_back(mapping.lookupOrDefault(idx)); } + SmallVector mapped_iter_args; + for (Value arg : hyperblock.getIterArgs()) { + mapped_iter_args.push_back(mapping.lookupOrDefault(arg)); + } + SmallVector outputTypes(hyperblock.getOutputs().getTypes()); - auto newHB = taskBuilder.create(loc, outputTypes, - mappedIndices); + auto newHB = taskBuilder.create( + loc, outputTypes, mappedIndices, mapped_iter_args); Block *newBody = new Block(); newHB.getBody().push_back(newBody); @@ -254,6 +259,10 @@ class AtomicTaskBuilder { newBody->addArgument(idx.getType(), loc); } + for (Value arg : mapped_iter_args) { + newBody->addArgument(arg.getType(), loc); + } + Block *oldBody = &hyperblock.getBody().front(); for (auto [oldArg, newArg] : llvm::zip(oldBody->getArguments(), newBody->getArguments())) { From cd47a98ec2f1f486081bc9b7d1c01fafd42cb878 Mon Sep 17 00:00:00 2001 From: ShangkunLI Date: Sun, 18 Jan 2026 20:10:53 +0800 Subject: [PATCH 5/8] converse hyperblock with result to task --- .../Transforms/CanonicalizeTaskPass.cpp | 294 +++++++++++++----- 1 file changed, 209 insertions(+), 85 deletions(-) diff --git a/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp b/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp index 6bf3c3c..8ebfdce 100644 --- a/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp +++ b/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp @@ -7,6 +7,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/IRMapping.h" @@ -24,30 +25,44 @@ using namespace mlir; using namespace mlir::taskflow; namespace { -//===----------------------------------------------------------------------===// -// Memory Access Info -//===----------------------------------------------------------------------===// - -struct MemoryAccessInfo { - SetVector reads; - SetVector writes; +//---------------------------------------------------------------------- +// Memory and Value Access Info +//---------------------------------------------------------------------- +// This struct analyzes memory accesses within a hyperblock. +struct AccessInfo { + SetVector memref_reads; + SetVector memref_writes; + SetVector value_reads; + + void analyze(TaskflowHyperblockOp hyperblock, Block *task_body) { + DenseSet task_block_args; + for (Value arg : task_body->getArguments()) { + task_block_args.insert(arg); + } - void analyze(TaskflowHyperblockOp hyperblock) { hyperblock.walk([&](Operation *op) { if (auto load = dyn_cast(op)) { - reads.insert(load.getMemRef()); + memref_reads.insert(load.getMemRef()); } else if (auto store = dyn_cast(op)) { - writes.insert(store.getMemRef()); + memref_writes.insert(store.getMemRef()); + } + + for (Value operand : op->getOperands()) { + if (task_block_args.contains(operand)) { + value_reads.insert(operand); + } } }); } SetVector getAllMemRefs() const { SetVector all; - all.insert(reads.begin(), reads.end()); - all.insert(writes.begin(), writes.end()); + all.insert(memref_reads.begin(), memref_reads.end()); + all.insert(memref_writes.begin(), memref_writes.end()); return all; } + + SetVector getAllValues() const { return value_reads; } }; //===----------------------------------------------------------------------===// @@ -106,10 +121,19 @@ class BlockArgResolver { public: explicit BlockArgResolver(TaskflowTaskOp task) { Block *body = &task.getBody().front(); - auto inputs = task.getMemoryInputs(); - auto args = body->getArguments(); - for (auto [input, arg] : llvm::zip(inputs, args)) { + // Resolves memory inputs. + auto mem_inputs = task.getMemoryInputs(); + auto mem_args = body->getArguments().take_front(mem_inputs.size()); + for (auto [input, arg] : llvm::zip(mem_inputs, mem_args)) { + blockArgToSource[arg] = input; + sourceToBlockArg[input] = arg; + } + + // Resolves value inputs. + auto val_inputs = task.getValueInputs(); + auto val_args = body->getArguments().drop_front(mem_inputs.size()); + for (auto [input, arg] : llvm::zip(val_inputs, val_args)) { blockArgToSource[arg] = input; sourceToBlockArg[input] = arg; } @@ -130,66 +154,100 @@ class BlockArgResolver { DenseMap sourceToBlockArg; }; -//===----------------------------------------------------------------------===// -// Atomic Task Builder -//===----------------------------------------------------------------------===// +//---------------------------------------------------------------------- +// Atomic Task Builder. +//---------------------------------------------------------------------- class AtomicTaskBuilder { public: AtomicTaskBuilder(OpBuilder &builder, Location loc, unsigned global_task_idx, - DenseMap &memref_to_latest_version) + DenseMap &memref_to_latest_version, + DenseMap &value_to_latest_version) : builder(builder), loc(loc), global_task_idx(global_task_idx), - memref_to_latest_version(memref_to_latest_version) {} + memref_to_latest_version(memref_to_latest_version), + value_to_latest_version(value_to_latest_version) {} TaskflowTaskOp build(TaskflowHyperblockOp hyperblock, - TaskflowTaskOp originalTask) { - MemoryAccessInfo memInfo; - memInfo.analyze(hyperblock); + TaskflowTaskOp original_task) { + AccessInfo mem_info; + mem_info.analyze(hyperblock, &original_task.getBody().front()); - BlockArgResolver resolver(originalTask); + BlockArgResolver resolver(original_task); - // Determine task inputs - SmallVector taskInputs; - DenseMap sourceToInputIdx; + // Determines memref inputs. + SmallVector memref_inputs; + DenseMap source_to_memref_input_idx; - for (Value memref : memInfo.getAllMemRefs()) { + for (Value memref : mem_info.getAllMemRefs()) { Value source = resolver.resolveToSource(memref); - Value inputVal = getLatestVersion(source); + Value inputVal = getLatestMemrefVersion(source); - if (!sourceToInputIdx.count(source)) { - sourceToInputIdx[source] = taskInputs.size(); - taskInputs.push_back(inputVal); + if (!source_to_memref_input_idx.count(source)) { + source_to_memref_input_idx[source] = memref_inputs.size(); + memref_inputs.push_back(inputVal); } } - // Determine task outputs - SmallVector outputTypes; - SmallVector writtenSources; + // Determines value inputs. + SmallVector value_inputs; + DenseMap source_to_value_input_idx; - for (Value memref : memInfo.writes) { + for (Value val : mem_info.getAllValues()) { + Value source = resolver.resolveToSource(val); + Value inputVal = getLatestValueVersion(source); + + if (!source_to_value_input_idx.count(source)) { + source_to_value_input_idx[source] = value_inputs.size(); + value_inputs.push_back(inputVal); + } + } + + // Determines memref outputs. + SmallVector memref_output_types; + SmallVector written_memref_sources; + + for (Value memref : mem_info.memref_writes) { Value source = resolver.resolveToSource(memref); - outputTypes.push_back(source.getType()); - writtenSources.push_back(source); + memref_output_types.push_back(source.getType()); + written_memref_sources.push_back(source); + } + + // Determines value outputs. + SmallVector value_output_types; + SmallVector yielded_value_sources; + + if (!hyperblock.getOutputs().empty()) { + for (Value output : hyperblock.getOutputs()) { + value_output_types.push_back(output.getType()); + // For value outputs, they are source themselves. + yielded_value_sources.push_back(output); + } } - // Create task + // Creates task. std::string taskName = "Task_" + std::to_string(global_task_idx); auto newTask = builder.create( - loc, outputTypes, TypeRange{}, taskInputs, ValueRange{}, - builder.getStringAttr(taskName)); + loc, memref_output_types, value_output_types, memref_inputs, + value_inputs, builder.getStringAttr(taskName)); - // Create task body + // Creates task body. Block *taskBody = new Block(); newTask.getBody().push_back(taskBody); - for (Value input : taskInputs) { + // Adds memref input arguments. + for (Value input : memref_inputs) { + taskBody->addArgument(input.getType(), loc); + } + // Adds value input arguments. + for (Value input : value_inputs) { taskBody->addArgument(input.getType(), loc); } - // Build value mapping + // Builds value mapping. IRMapping mapping; - for (auto [source, idx] : sourceToInputIdx) { + // Maps memref inputs. + for (auto [source, idx] : source_to_memref_input_idx) { BlockArgument newArg = taskBody->getArgument(idx); mapping.map(source, newArg); @@ -198,32 +256,76 @@ class AtomicTaskBuilder { } } - // Clone counters and hyperblock + // Maps value inputs. + size_t value_arg_offset = memref_inputs.size(); + for (auto [source, idx] : source_to_value_input_idx) { + BlockArgument newArg = taskBody->getArgument(value_arg_offset + idx); + mapping.map(source, newArg); + + if (Value origArg = resolver.getBlockArg(source)) { + mapping.map(origArg, newArg); + } + } + + // Clones counters and hyperblock. OpBuilder taskBuilder(taskBody, taskBody->begin()); cloneCounters(taskBuilder, hyperblock, mapping); cloneHyperblock(taskBuilder, hyperblock, mapping); - // Create yield - SmallVector yieldOperands; - for (Value memref : memInfo.writes) { - yieldOperands.push_back(mapping.lookupOrDefault(memref)); + // Creates yield. + SmallVector memref_yield_operands; + for (Value memref : mem_info.memref_writes) { + memref_yield_operands.push_back(mapping.lookupOrDefault(memref)); + } + + SmallVector value_yield_operands; + // If this hyperblock has value outputs, we need to yield them from the + // mapped hyperblock. + if (!hyperblock.getOutputs().empty()) { + // Finds the cloned hyperblock op. + TaskflowHyperblockOp cloned_hb = nullptr; + for (Operation &op : taskBody->getOperations()) { + if (auto hb = dyn_cast(op)) { + cloned_hb = hb; + break; + } + if (cloned_hb) { + for (Value output : cloned_hb.getOutputs()) { + value_yield_operands.push_back(output); + } + } + } } + taskBuilder.setInsertionPointToEnd(taskBody); - taskBuilder.create(loc, yieldOperands, ValueRange{}); + taskBuilder.create(loc, memref_yield_operands, + value_yield_operands); + + // Updates latest versions. + auto memref_outputs = newTask.getMemoryOutputs(); + for (auto [source, output] : + llvm::zip(written_memref_sources, memref_outputs)) { + this->memref_to_latest_version[source] = output; + } - // Update latest versions - auto outputs = newTask.getMemoryOutputs(); - for (auto [source, output] : llvm::zip(writtenSources, outputs)) { - memref_to_latest_version[source] = output; + auto value_outputs = newTask.getValueOutputs(); + for (auto [source, output] : + llvm::zip(yielded_value_sources, value_outputs)) { + this->value_to_latest_version[source] = output; } return newTask; } private: - Value getLatestVersion(Value source) { - auto it = memref_to_latest_version.find(source); - return it != memref_to_latest_version.end() ? it->second : source; + Value getLatestMemrefVersion(Value source) { + auto it = this->memref_to_latest_version.find(source); + return it != this->memref_to_latest_version.end() ? it->second : source; + } + + Value getLatestValueVersion(Value source) { + auto it = this->value_to_latest_version.find(source); + return it != this->value_to_latest_version.end() ? it->second : source; } void cloneCounters(OpBuilder &taskBuilder, TaskflowHyperblockOp hyperblock, @@ -290,11 +392,12 @@ class AtomicTaskBuilder { Location loc; unsigned global_task_idx; DenseMap &memref_to_latest_version; + DenseMap &value_to_latest_version; }; -//===----------------------------------------------------------------------===// -// Pass Implementation -//===----------------------------------------------------------------------===// +//---------------------------------------------------------------------- +// Pass Implementation. +//---------------------------------------------------------------------- struct CanonicalizeTaskPass : public PassWrapper> { @@ -308,8 +411,9 @@ struct CanonicalizeTaskPass } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); } void runOnOperation() override { @@ -322,6 +426,7 @@ struct CanonicalizeTaskPass unsigned global_task_idx = 0; for (TaskflowTaskOp original_task : tasks_to_process) { + // Collects hyperblocks within the original task. SmallVector hyperblocks; original_task.walk( [&](TaskflowHyperblockOp hb) { hyperblocks.push_back(hb); }); @@ -329,62 +434,81 @@ struct CanonicalizeTaskPass assert(!hyperblocks.empty() && "Expected at least one hyperblock in the task"); + // If there's only one hyperblock, it is already canonical. if (hyperblocks.size() == 1) { continue; } - //===----------------------------------------------------------------===// - // Step 1: Build mapping from original task's memory outputs to their + //---------------------------------------------------------------- + // Step 1: Builds mapping from original task's memory outputs to their // corresponding source memrefs (the original inputs). - //===----------------------------------------------------------------===// + //---------------------------------------------------------------- - // Get the yield operation to find which memrefs are yielded + // Get the yield operation to find which memrefs are yielded. auto yield_op = cast( original_task.getBody().front().getTerminator()); - auto original_outputs = original_task.getMemoryOutputs(); + auto original_mem_outputs = original_task.getMemoryOutputs(); + auto original_val_outputs = original_task.getValueOutputs(); auto yielded_memrefs = yield_op.getMemoryResults(); + auto yielded_values = yield_op.getValueResults(); - // Map: yielded block argument -> original task output + // Map: yielded -> original task output. DenseMap yielded_to_output; for (auto [yielded, output] : - llvm::zip(yielded_memrefs, original_outputs)) { + llvm::zip(yielded_memrefs, original_mem_outputs)) { + yielded_to_output[yielded] = output; + } + for (auto [yielded, output] : + llvm::zip(yielded_values, original_val_outputs)) { yielded_to_output[yielded] = output; } - // Map: original input memref -> original task output (if it's yielded) - // This tells us which original outputs correspond to which input memrefs + // Map: original input memref -> original task output (if it's yielded). + // This tells us which original outputs correspond to which input memrefs. Block *orig_body = &original_task.getBody().front(); auto orig_mem_inputs = original_task.getMemoryInputs(); + auto orig_val_inputs = original_task.getValueInputs(); + DenseMap source_to_original_output; - for (auto [input, arg] : - llvm::zip(orig_mem_inputs, orig_body->getArguments())) { + for (auto [input, arg] : llvm::zip( + orig_mem_inputs, + orig_body->getArguments().take_front(orig_mem_inputs.size()))) { if (yielded_to_output.count(arg)) { source_to_original_output[input] = yielded_to_output[arg]; } } - //===----------------------------------------------------------------===// - // Step 2: Create atomic tasks for each hyperblock. - //===----------------------------------------------------------------===// - + //---------------------------------------------------------------- + // Step 2: Creates atomic tasks for each hyperblock. + //---------------------------------------------------------------- + // Records the mapping from source memref to the latest version after + // executing each atomic task. DenseMap memref_to_latest_version; + DenseMap value_to_latest_version; OpBuilder builder(original_task); for (size_t i = 0; i < hyperblocks.size(); ++i) { - AtomicTaskBuilder task_builder(builder, original_task.getLoc(), - global_task_idx++, - memref_to_latest_version); + AtomicTaskBuilder task_builder( + builder, original_task.getLoc(), global_task_idx++, + memref_to_latest_version, value_to_latest_version); task_builder.build(hyperblocks[i], original_task); } - //===----------------------------------------------------------------===// - // Step 3: Replace uses of original task outputs with the latest versions. - //===----------------------------------------------------------------===// + //---------------------------------------------------------------- + // Step 3: Replaces uses of original task outputs with the latest + // versions. + //---------------------------------------------------------------- for (auto [source, original_output] : source_to_original_output) { + Value latest = nullptr; if (memref_to_latest_version.count(source)) { - Value latest = memref_to_latest_version[source]; + latest = memref_to_latest_version[source]; + } else if (value_to_latest_version.count(source)) { + latest = value_to_latest_version[source]; + } + + if (latest) { original_output.replaceAllUsesWith(latest); } } From beb4ce81cbd0a6bbb33fd7d9e1eb1337778c6b39 Mon Sep 17 00:00:00 2001 From: ShangkunLI Date: Sun, 18 Jan 2026 20:31:52 +0800 Subject: [PATCH 6/8] enable canonicalize task pass functionality --- .../Transforms/CanonicalizeTaskPass.cpp | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp b/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp index 8ebfdce..bc1665f 100644 --- a/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp +++ b/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp @@ -3,7 +3,6 @@ #include "TaskflowDialect/TaskflowPasses.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -11,15 +10,12 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/IRMapping.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/IR/Unit.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/LogicalResult.h" -#include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::taskflow; @@ -426,6 +422,7 @@ struct CanonicalizeTaskPass unsigned global_task_idx = 0; for (TaskflowTaskOp original_task : tasks_to_process) { + OpBuilder builder(original_task); // Collects hyperblocks within the original task. SmallVector hyperblocks; original_task.walk( @@ -436,6 +433,8 @@ struct CanonicalizeTaskPass // If there's only one hyperblock, it is already canonical. if (hyperblocks.size() == 1) { + std::string task_name = "Task_" + std::to_string(global_task_idx++); + original_task.setTaskNameAttr(builder.getStringAttr(task_name)); continue; } @@ -471,6 +470,7 @@ struct CanonicalizeTaskPass DenseMap source_to_original_output; + // Maps memref inputs. for (auto [input, arg] : llvm::zip( orig_mem_inputs, orig_body->getArguments().take_front(orig_mem_inputs.size()))) { @@ -479,6 +479,15 @@ struct CanonicalizeTaskPass } } + // Maps value inputs. + for (auto [input, arg] : llvm::zip( + orig_val_inputs, + orig_body->getArguments().drop_front(orig_mem_inputs.size()))) { + if (yielded_to_output.count(arg)) { + source_to_original_output[input] = yielded_to_output[arg]; + } + } + //---------------------------------------------------------------- // Step 2: Creates atomic tasks for each hyperblock. //---------------------------------------------------------------- @@ -486,7 +495,6 @@ struct CanonicalizeTaskPass // executing each atomic task. DenseMap memref_to_latest_version; DenseMap value_to_latest_version; - OpBuilder builder(original_task); for (size_t i = 0; i < hyperblocks.size(); ++i) { AtomicTaskBuilder task_builder( From 5281b979358ef1356f32790d7c09d281b16b1eaf Mon Sep 17 00:00:00 2001 From: ShangkunLI Date: Sun, 18 Jan 2026 21:35:49 +0800 Subject: [PATCH 7/8] add test & comments --- .../Transforms/CanonicalizeTaskPass.cpp | 220 +++++++------- .../irregular-loop/irregular-loop.mlir | 216 +++++++++++++ .../taskflow/multi-nested/multi-nested.mlir | 285 ++++++++++++------ .../parallel-nested/parallel-nested.mlir | 56 +++- 4 files changed, 573 insertions(+), 204 deletions(-) create mode 100644 test/multi-cgra/taskflow/irregular-loop/irregular-loop.mlir diff --git a/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp b/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp index bc1665f..151226c 100644 --- a/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp +++ b/lib/TaskflowDialect/Transforms/CanonicalizeTaskPass.cpp @@ -22,12 +22,15 @@ using namespace mlir::taskflow; namespace { //---------------------------------------------------------------------- -// Memory and Value Access Info +// Memory and Value Access Info. //---------------------------------------------------------------------- -// This struct analyzes memory accesses within a hyperblock. +// This struct analyzes accesses information within a hyperblock. struct AccessInfo { + // Set of read memrefs. SetVector memref_reads; + // Set of written memrefs. SetVector memref_writes; + // Set of read values. SetVector value_reads; void analyze(TaskflowHyperblockOp hyperblock, Block *task_body) { @@ -38,14 +41,14 @@ struct AccessInfo { hyperblock.walk([&](Operation *op) { if (auto load = dyn_cast(op)) { - memref_reads.insert(load.getMemRef()); + this->memref_reads.insert(load.getMemRef()); } else if (auto store = dyn_cast(op)) { - memref_writes.insert(store.getMemRef()); + this->memref_writes.insert(store.getMemRef()); } for (Value operand : op->getOperands()) { if (task_block_args.contains(operand)) { - value_reads.insert(operand); + this->value_reads.insert(operand); } } }); @@ -53,18 +56,18 @@ struct AccessInfo { SetVector getAllMemRefs() const { SetVector all; - all.insert(memref_reads.begin(), memref_reads.end()); - all.insert(memref_writes.begin(), memref_writes.end()); + all.insert(this->memref_reads.begin(), this->memref_reads.end()); + all.insert(this->memref_writes.begin(), this->memref_writes.end()); return all; } - SetVector getAllValues() const { return value_reads; } + SetVector getAllValues() const { return this->value_reads; } }; -//===----------------------------------------------------------------------===// -// Counter Collector -//===----------------------------------------------------------------------===// - +//---------------------------------------------------------------------- +// Counter Collector. +//---------------------------------------------------------------------- +// This class is used to collects all counters needed by a hyperblock. class CounterCollector { public: void collect(TaskflowHyperblockOp hyperblock) { @@ -73,8 +76,10 @@ class CounterCollector { } } + // Gets the collected counters sorted by their depth. SmallVector getSortedCounters() const { - SmallVector result(counters.begin(), counters.end()); + SmallVector result(this->counters.begin(), + this->counters.end()); llvm::sort(result, [this](TaskflowCounterOp a, TaskflowCounterOp b) { return getDepth(a) < getDepth(b); }); @@ -82,22 +87,25 @@ class CounterCollector { } private: + // Collects counters recursively. void collectRecursively(Value idx) { - auto counter = idx.getDefiningOp(); - if (!counter) + TaskflowCounterOp counter = idx.getDefiningOp(); + if (!counter) { return; - counters.insert(counter); + } + this->counters.insert(counter); if (Value parent = counter.getParentIndex()) { collectRecursively(parent); } } + // Gets the depth of a counter. size_t getDepth(TaskflowCounterOp counter) const { size_t depth = 0; Value parent = counter.getParentIndex(); while (parent) { depth++; - if (auto p = parent.getDefiningOp()) { + if (TaskflowCounterOp p = parent.getDefiningOp()) { parent = p.getParentIndex(); } else { break; @@ -109,10 +117,18 @@ class CounterCollector { SetVector counters; }; -//===----------------------------------------------------------------------===// -// Block Argument Resolver -//===----------------------------------------------------------------------===// - +//---------------------------------------------------------------------- +// Block Argument Resolver. +//---------------------------------------------------------------------- +// This class resolves the input arguments of a task block to their source +// values. +// For example: +// taskflow.task(%buf_input, %val_input) { +// ^bb0(%arg0: memref, %arg1: i32): // ← block arguments +// // %arg0 corresponds to %buf_input +// // %arg1 corresponds to %val_input +// } +// resolveToSource(%arg0) -> %buf_input class BlockArgResolver { public: explicit BlockArgResolver(TaskflowTaskOp task) { @@ -122,38 +138,42 @@ class BlockArgResolver { auto mem_inputs = task.getMemoryInputs(); auto mem_args = body->getArguments().take_front(mem_inputs.size()); for (auto [input, arg] : llvm::zip(mem_inputs, mem_args)) { - blockArgToSource[arg] = input; - sourceToBlockArg[input] = arg; + this->block_arg_to_source[arg] = input; + this->source_to_block_arg[input] = arg; } // Resolves value inputs. auto val_inputs = task.getValueInputs(); auto val_args = body->getArguments().drop_front(mem_inputs.size()); for (auto [input, arg] : llvm::zip(val_inputs, val_args)) { - blockArgToSource[arg] = input; - sourceToBlockArg[input] = arg; + this->block_arg_to_source[arg] = input; + this->source_to_block_arg[input] = arg; } } + // Gets the source value for a given block argument. Value resolveToSource(Value val) const { - auto it = blockArgToSource.find(val); - return it != blockArgToSource.end() ? it->second : val; + auto it = this->block_arg_to_source.find(val); + return it != this->block_arg_to_source.end() ? it->second : val; } + // Gets the block argument for a given source value. Value getBlockArg(Value source) const { - auto it = sourceToBlockArg.find(source); - return it != sourceToBlockArg.end() ? it->second : Value(); + auto it = this->source_to_block_arg.find(source); + return it != this->source_to_block_arg.end() ? it->second : Value(); } private: - DenseMap blockArgToSource; - DenseMap sourceToBlockArg; + // Maps block argument to its source value. + DenseMap block_arg_to_source; + // Maps source value to its block argument. + DenseMap source_to_block_arg; }; //---------------------------------------------------------------------- // Atomic Task Builder. //---------------------------------------------------------------------- - +// This class builds an atomic task from a hyperblock. class AtomicTaskBuilder { public: AtomicTaskBuilder(OpBuilder &builder, Location loc, unsigned global_task_idx, @@ -165,8 +185,8 @@ class AtomicTaskBuilder { TaskflowTaskOp build(TaskflowHyperblockOp hyperblock, TaskflowTaskOp original_task) { - AccessInfo mem_info; - mem_info.analyze(hyperblock, &original_task.getBody().front()); + AccessInfo access_info; + access_info.analyze(hyperblock, &original_task.getBody().front()); BlockArgResolver resolver(original_task); @@ -174,13 +194,13 @@ class AtomicTaskBuilder { SmallVector memref_inputs; DenseMap source_to_memref_input_idx; - for (Value memref : mem_info.getAllMemRefs()) { + for (Value memref : access_info.getAllMemRefs()) { Value source = resolver.resolveToSource(memref); - Value inputVal = getLatestMemrefVersion(source); + Value input_memref = getLatestMemrefVersion(source); if (!source_to_memref_input_idx.count(source)) { source_to_memref_input_idx[source] = memref_inputs.size(); - memref_inputs.push_back(inputVal); + memref_inputs.push_back(input_memref); } } @@ -188,21 +208,22 @@ class AtomicTaskBuilder { SmallVector value_inputs; DenseMap source_to_value_input_idx; - for (Value val : mem_info.getAllValues()) { + for (Value val : access_info.getAllValues()) { Value source = resolver.resolveToSource(val); - Value inputVal = getLatestValueVersion(source); + Value input_val = getLatestValueVersion(source); if (!source_to_value_input_idx.count(source)) { source_to_value_input_idx[source] = value_inputs.size(); - value_inputs.push_back(inputVal); + value_inputs.push_back(input_val); } } // Determines memref outputs. SmallVector memref_output_types; + // The source memrefs of the written memrefs. SmallVector written_memref_sources; - for (Value memref : mem_info.memref_writes) { + for (Value memref : access_info.memref_writes) { Value source = resolver.resolveToSource(memref); memref_output_types.push_back(source.getType()); written_memref_sources.push_back(source); @@ -220,23 +241,23 @@ class AtomicTaskBuilder { } } - // Creates task. - std::string taskName = "Task_" + std::to_string(global_task_idx); - auto newTask = builder.create( - loc, memref_output_types, value_output_types, memref_inputs, - value_inputs, builder.getStringAttr(taskName)); + // Creates a new task. + std::string task_name = "Task_" + std::to_string(this->global_task_idx); + auto new_task = builder.create( + this->loc, memref_output_types, value_output_types, memref_inputs, + value_inputs, builder.getStringAttr(task_name)); - // Creates task body. - Block *taskBody = new Block(); - newTask.getBody().push_back(taskBody); + // Creates the task body. + Block *task_body = new Block(); + new_task.getBody().push_back(task_body); // Adds memref input arguments. for (Value input : memref_inputs) { - taskBody->addArgument(input.getType(), loc); + task_body->addArgument(input.getType(), this->loc); } // Adds value input arguments. for (Value input : value_inputs) { - taskBody->addArgument(input.getType(), loc); + task_body->addArgument(input.getType(), this->loc); } // Builds value mapping. @@ -244,33 +265,33 @@ class AtomicTaskBuilder { // Maps memref inputs. for (auto [source, idx] : source_to_memref_input_idx) { - BlockArgument newArg = taskBody->getArgument(idx); - mapping.map(source, newArg); + BlockArgument new_arg = task_body->getArgument(idx); + mapping.map(source, new_arg); - if (Value origArg = resolver.getBlockArg(source)) { - mapping.map(origArg, newArg); + if (Value orig_arg = resolver.getBlockArg(source)) { + mapping.map(orig_arg, new_arg); } } // Maps value inputs. size_t value_arg_offset = memref_inputs.size(); for (auto [source, idx] : source_to_value_input_idx) { - BlockArgument newArg = taskBody->getArgument(value_arg_offset + idx); - mapping.map(source, newArg); + BlockArgument new_arg = task_body->getArgument(value_arg_offset + idx); + mapping.map(source, new_arg); - if (Value origArg = resolver.getBlockArg(source)) { - mapping.map(origArg, newArg); + if (Value orig_arg = resolver.getBlockArg(source)) { + mapping.map(orig_arg, new_arg); } } // Clones counters and hyperblock. - OpBuilder taskBuilder(taskBody, taskBody->begin()); - cloneCounters(taskBuilder, hyperblock, mapping); - cloneHyperblock(taskBuilder, hyperblock, mapping); + OpBuilder task_builder(task_body, task_body->begin()); + cloneCounters(task_builder, hyperblock, mapping); + cloneHyperblock(task_builder, hyperblock, mapping); // Creates yield. SmallVector memref_yield_operands; - for (Value memref : mem_info.memref_writes) { + for (Value memref : access_info.memref_writes) { memref_yield_operands.push_back(mapping.lookupOrDefault(memref)); } @@ -280,7 +301,7 @@ class AtomicTaskBuilder { if (!hyperblock.getOutputs().empty()) { // Finds the cloned hyperblock op. TaskflowHyperblockOp cloned_hb = nullptr; - for (Operation &op : taskBody->getOperations()) { + for (Operation &op : task_body->getOperations()) { if (auto hb = dyn_cast(op)) { cloned_hb = hb; break; @@ -293,24 +314,24 @@ class AtomicTaskBuilder { } } - taskBuilder.setInsertionPointToEnd(taskBody); - taskBuilder.create(loc, memref_yield_operands, - value_yield_operands); + task_builder.setInsertionPointToEnd(task_body); + task_builder.create(this->loc, memref_yield_operands, + value_yield_operands); // Updates latest versions. - auto memref_outputs = newTask.getMemoryOutputs(); + auto memref_outputs = new_task.getMemoryOutputs(); for (auto [source, output] : llvm::zip(written_memref_sources, memref_outputs)) { this->memref_to_latest_version[source] = output; } - auto value_outputs = newTask.getValueOutputs(); + auto value_outputs = new_task.getValueOutputs(); for (auto [source, output] : llvm::zip(yielded_value_sources, value_outputs)) { this->value_to_latest_version[source] = output; } - return newTask; + return new_task; } private: @@ -324,21 +345,21 @@ class AtomicTaskBuilder { return it != this->value_to_latest_version.end() ? it->second : source; } - void cloneCounters(OpBuilder &taskBuilder, TaskflowHyperblockOp hyperblock, + void cloneCounters(OpBuilder &task_builder, TaskflowHyperblockOp hyperblock, IRMapping &mapping) { CounterCollector collector; collector.collect(hyperblock); for (TaskflowCounterOp counter : collector.getSortedCounters()) { - taskBuilder.clone(*counter.getOperation(), mapping); + task_builder.clone(*counter.getOperation(), mapping); } } - void cloneHyperblock(OpBuilder &taskBuilder, TaskflowHyperblockOp hyperblock, + void cloneHyperblock(OpBuilder &task_builder, TaskflowHyperblockOp hyperblock, IRMapping &mapping) { - SmallVector mappedIndices; + SmallVector mapped_indices; for (Value idx : hyperblock.getIndices()) { - mappedIndices.push_back(mapping.lookupOrDefault(idx)); + mapped_indices.push_back(mapping.lookupOrDefault(idx)); } SmallVector mapped_iter_args; @@ -346,41 +367,41 @@ class AtomicTaskBuilder { mapped_iter_args.push_back(mapping.lookupOrDefault(arg)); } - SmallVector outputTypes(hyperblock.getOutputs().getTypes()); - auto newHB = taskBuilder.create( - loc, outputTypes, mappedIndices, mapped_iter_args); + SmallVector output_types(hyperblock.getOutputs().getTypes()); + auto newHB = task_builder.create( + this->loc, output_types, mapped_indices, mapped_iter_args); - Block *newBody = new Block(); - newHB.getBody().push_back(newBody); + Block *new_body = new Block(); + newHB.getBody().push_back(new_body); - for (Value idx : mappedIndices) { - newBody->addArgument(idx.getType(), loc); + for (Value idx : mapped_indices) { + new_body->addArgument(idx.getType(), this->loc); } for (Value arg : mapped_iter_args) { - newBody->addArgument(arg.getType(), loc); + new_body->addArgument(arg.getType(), this->loc); } - Block *oldBody = &hyperblock.getBody().front(); - for (auto [oldArg, newArg] : - llvm::zip(oldBody->getArguments(), newBody->getArguments())) { - mapping.map(oldArg, newArg); + Block *old_body = &hyperblock.getBody().front(); + for (auto [old_arg, new_arg] : + llvm::zip(old_body->getArguments(), new_body->getArguments())) { + mapping.map(old_arg, new_arg); } - OpBuilder hbBuilder(newBody, newBody->begin()); - for (Operation &op : oldBody->without_terminator()) { - hbBuilder.clone(op, mapping); + OpBuilder hb_builder(new_body, new_body->begin()); + for (Operation &op : old_body->without_terminator()) { + hb_builder.clone(op, mapping); } if (auto yield = - dyn_cast(oldBody->getTerminator())) { - SmallVector yieldOps; + dyn_cast(old_body->getTerminator())) { + SmallVector yield_ops; for (Value v : yield.getOutputs()) { - yieldOps.push_back(mapping.lookupOrDefault(v)); + yield_ops.push_back(mapping.lookupOrDefault(v)); } - hbBuilder.create(loc, yieldOps); + hb_builder.create(this->loc, yield_ops); } else { - hbBuilder.create(loc, ValueRange{}); + hb_builder.create(this->loc, ValueRange{}); } } @@ -442,8 +463,7 @@ struct CanonicalizeTaskPass // Step 1: Builds mapping from original task's memory outputs to their // corresponding source memrefs (the original inputs). //---------------------------------------------------------------- - - // Get the yield operation to find which memrefs are yielded. + // Gets the yield operation to find which memrefs are yielded. auto yield_op = cast( original_task.getBody().front().getTerminator()); auto original_mem_outputs = original_task.getMemoryOutputs(); @@ -507,7 +527,6 @@ struct CanonicalizeTaskPass // Step 3: Replaces uses of original task outputs with the latest // versions. //---------------------------------------------------------------- - for (auto [source, original_output] : source_to_original_output) { Value latest = nullptr; if (memref_to_latest_version.count(source)) { @@ -521,10 +540,9 @@ struct CanonicalizeTaskPass } } - //===----------------------------------------------------------------===// + //---------------------------------------------------------------- // Step 4: Erase the original task. - //===----------------------------------------------------------------===// - + //---------------------------------------------------------------- original_task.erase(); } } diff --git a/test/multi-cgra/taskflow/irregular-loop/irregular-loop.mlir b/test/multi-cgra/taskflow/irregular-loop/irregular-loop.mlir new file mode 100644 index 0000000..6ce8e5e --- /dev/null +++ b/test/multi-cgra/taskflow/irregular-loop/irregular-loop.mlir @@ -0,0 +1,216 @@ +// RUN: mlir-neura-opt %s --convert-affine-to-taskflow \ +// RUN: -o %t.taskflow.mlir +// RUN: FileCheck %s --input-file=%t.taskflow.mlir --check-prefixes=TASKFLOW + +// RUN: mlir-neura-opt %s --convert-affine-to-taskflow \ +// RUN: --construct-hyperblock-from-task \ +// RUN: -o %t.hyperblock.mlir +// RUN: FileCheck %s --input-file=%t.hyperblock.mlir --check-prefixes=HYPERBLOCK + +// RUN: mlir-neura-opt %s --convert-affine-to-taskflow \ +// RUN: --construct-hyperblock-from-task \ +// RUN: --canonicalize-task \ +// RUN: -o %t.canonicalized.mlir +// RUN: FileCheck %s --input-file=%t.canonicalized.mlir --check-prefixes=CANONICALIZE + +#set = affine_set<(d0, d1) : (d0 - 3 == 0, d1 - 7 == 0)> +module attributes {} { + func.func @_Z21irregularLoopExample1v() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c2_i32 = arith.constant 2 : i32 + %c8_i32 = arith.constant 8 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref + %alloca_0 = memref.alloca() : memref<4x8xi32> + %0 = affine.for %arg0 = 0 to 5 iter_args(%arg1 = %c0_i32) -> (i32) { + %2 = arith.index_cast %arg0 : index to i32 + %3 = arith.addi %arg1, %2 : i32 + affine.yield %3 : i32 + } + affine.for %arg0 = 0 to 4 { + %2 = arith.index_cast %arg0 : index to i32 + %3 = arith.muli %2, %c8_i32 : i32 + affine.for %arg1 = 0 to 8 { + %4 = arith.index_cast %arg1 : index to i32 + %5 = arith.addi %3, %4 : i32 + affine.store %5, %alloca_0[%arg0, %arg1] : memref<4x8xi32> + } + affine.for %arg1 = 0 to 8 { + %4 = affine.load %alloca_0[%arg0, %arg1] : memref<4x8xi32> + %5 = arith.addi %4, %0 : i32 + affine.if #set(%arg0, %arg1) { + affine.store %5, %alloca[] : memref + %6 = arith.muli %5, %c2_i32 : i32 + affine.store %6, %alloca[] : memref + } + } + } + %1 = affine.load %alloca[] : memref + return %1 : i32 + } +} + +// TASKFLOW: #set = affine_set<(d0, d1) : (d0 - 3 == 0, d1 - 7 == 0)> +// TASKFLOW-NEXT: module { +// TASKFLOW-NEXT: func.func @_Z21irregularLoopExample1v() -> i32 attributes {llvm.linkage = #llvm.linkage} { +// TASKFLOW-NEXT: %c2_i32 = arith.constant 2 : i32 +// TASKFLOW-NEXT: %c8_i32 = arith.constant 8 : i32 +// TASKFLOW-NEXT: %c0_i32 = arith.constant 0 : i32 +// TASKFLOW-NEXT: %alloca = memref.alloca() : memref +// TASKFLOW-NEXT: %alloca_0 = memref.alloca() : memref<4x8xi32> +// TASKFLOW-NEXT: %value_outputs = "taskflow.task"(%c0_i32) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_0"}> ({ +// TASKFLOW-NEXT: ^bb0(%arg0: i32): +// TASKFLOW-NEXT: %1 = affine.for %arg1 = 0 to 5 iter_args(%arg2 = %arg0) -> (i32) { +// TASKFLOW-NEXT: %2 = arith.index_cast %arg1 : index to i32 +// TASKFLOW-NEXT: %3 = arith.addi %arg2, %2 : i32 +// TASKFLOW-NEXT: affine.yield %3 : i32 +// TASKFLOW-NEXT: } +// TASKFLOW-NEXT: "taskflow.yield"(%1) <{operandSegmentSizes = array}> : (i32) -> () +// TASKFLOW-NEXT: }) : (i32) -> i32 +// TASKFLOW-NEXT: %memory_outputs:2 = "taskflow.task"(%alloca_0, %alloca, %c8_i32, %value_outputs, %c2_i32) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_1"}> ({ +// TASKFLOW-NEXT: ^bb0(%arg0: memref<4x8xi32>, %arg1: memref, %arg2: i32, %arg3: i32, %arg4: i32): +// TASKFLOW-NEXT: affine.for %arg5 = 0 to 4 { +// TASKFLOW-NEXT: %1 = arith.index_cast %arg5 : index to i32 +// TASKFLOW-NEXT: %2 = arith.muli %1, %arg2 : i32 +// TASKFLOW-NEXT: affine.for %arg6 = 0 to 8 { +// TASKFLOW-NEXT: %3 = arith.index_cast %arg6 : index to i32 +// TASKFLOW-NEXT: %4 = arith.addi %2, %3 : i32 +// TASKFLOW-NEXT: affine.store %4, %arg0[%arg5, %arg6] : memref<4x8xi32> +// TASKFLOW-NEXT: } +// TASKFLOW-NEXT: affine.for %arg6 = 0 to 8 { +// TASKFLOW-NEXT: %3 = affine.load %arg0[%arg5, %arg6] : memref<4x8xi32> +// TASKFLOW-NEXT: %4 = arith.addi %3, %arg3 : i32 +// TASKFLOW-NEXT: affine.if #set(%arg5, %arg6) { +// TASKFLOW-NEXT: affine.store %4, %arg1[] : memref +// TASKFLOW-NEXT: %5 = arith.muli %4, %arg4 : i32 +// TASKFLOW-NEXT: affine.store %5, %arg1[] : memref +// TASKFLOW-NEXT: } +// TASKFLOW-NEXT: } +// TASKFLOW-NEXT: } +// TASKFLOW-NEXT: "taskflow.yield"(%arg0, %arg1) <{operandSegmentSizes = array}> : (memref<4x8xi32>, memref) -> () +// TASKFLOW-NEXT: }) : (memref<4x8xi32>, memref, i32, i32, i32) -> (memref<4x8xi32>, memref) +// TASKFLOW-NEXT: %0 = affine.load %memory_outputs#1[] : memref +// TASKFLOW-NEXT: return %0 : i32 +// TASKFLOW-NEXT: } +// TASKFLOW-NEXT: } + +// HYPERBLOCK: module { +// HYPERBLOCK-NEXT: func.func @_Z21irregularLoopExample1v() -> i32 attributes {llvm.linkage = #llvm.linkage} { +// HYPERBLOCK-NEXT: %c2_i32 = arith.constant 2 : i32 +// HYPERBLOCK-NEXT: %c8_i32 = arith.constant 8 : i32 +// HYPERBLOCK-NEXT: %c0_i32 = arith.constant 0 : i32 +// HYPERBLOCK-NEXT: %alloca = memref.alloca() : memref +// HYPERBLOCK-NEXT: %alloca_0 = memref.alloca() : memref<4x8xi32> +// HYPERBLOCK-NEXT: %value_outputs = "taskflow.task"(%c0_i32) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_0"}> ({ +// HYPERBLOCK-NEXT: ^bb0(%arg0: i32): +// HYPERBLOCK-NEXT: %1 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 5 : index} : index +// HYPERBLOCK-NEXT: %2 = "taskflow.hyperblock"(%1, %arg0) <{operandSegmentSizes = array}> ({ +// HYPERBLOCK-NEXT: ^bb0(%arg1: index, %arg2: i32): +// HYPERBLOCK-NEXT: %3 = arith.index_cast %arg1 : index to i32 +// HYPERBLOCK-NEXT: %4 = arith.addi %arg2, %3 : i32 +// HYPERBLOCK-NEXT: taskflow.hyperblock.yield outputs(%4 : i32) +// HYPERBLOCK-NEXT: }) : (index, i32) -> i32 +// HYPERBLOCK-NEXT: "taskflow.yield"(%2) <{operandSegmentSizes = array}> : (i32) -> () +// HYPERBLOCK-NEXT: }) : (i32) -> i32 +// HYPERBLOCK-NEXT: %memory_outputs:2 = "taskflow.task"(%alloca_0, %alloca, %c8_i32, %value_outputs, %c2_i32) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_1"}> ({ +// HYPERBLOCK-NEXT: ^bb0(%arg0: memref<4x8xi32>, %arg1: memref, %arg2: i32, %arg3: i32, %arg4: i32): +// HYPERBLOCK-NEXT: %1 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 4 : index} : index +// HYPERBLOCK-NEXT: %2 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 8 : index} : index +// HYPERBLOCK-NEXT: %3 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 8 : index} : index +// HYPERBLOCK-NEXT: "taskflow.hyperblock"(%1, %2) <{operandSegmentSizes = array}> ({ +// HYPERBLOCK-NEXT: ^bb0(%arg5: index, %arg6: index): +// HYPERBLOCK-NEXT: %4 = arith.index_cast %arg5 : index to i32 +// HYPERBLOCK-NEXT: %5 = arith.muli %4, %arg2 : i32 +// HYPERBLOCK-NEXT: %6 = arith.index_cast %arg6 : index to i32 +// HYPERBLOCK-NEXT: %7 = arith.addi %5, %6 : i32 +// HYPERBLOCK-NEXT: memref.store %7, %arg0[%arg5, %arg6] : memref<4x8xi32> +// HYPERBLOCK-NEXT: taskflow.hyperblock.yield +// HYPERBLOCK-NEXT: }) : (index, index) -> () +// HYPERBLOCK-NEXT: "taskflow.hyperblock"(%1, %3) <{operandSegmentSizes = array}> ({ +// HYPERBLOCK-NEXT: ^bb0(%arg5: index, %arg6: index): +// HYPERBLOCK-NEXT: %4 = memref.load %arg0[%arg5, %arg6] : memref<4x8xi32> +// HYPERBLOCK-NEXT: %5 = arith.addi %4, %arg3 : i32 +// HYPERBLOCK-NEXT: %c0 = arith.constant 0 : index +// HYPERBLOCK-NEXT: %c-3 = arith.constant -3 : index +// HYPERBLOCK-NEXT: %6 = arith.addi %arg5, %c-3 : index +// HYPERBLOCK-NEXT: %7 = arith.cmpi eq, %6, %c0 : index +// HYPERBLOCK-NEXT: %c-7 = arith.constant -7 : index +// HYPERBLOCK-NEXT: %8 = arith.addi %arg6, %c-7 : index +// HYPERBLOCK-NEXT: %9 = arith.cmpi eq, %8, %c0 : index +// HYPERBLOCK-NEXT: %10 = arith.andi %7, %9 : i1 +// HYPERBLOCK-NEXT: scf.if %10 { +// HYPERBLOCK-NEXT: memref.store %5, %arg1[] : memref +// HYPERBLOCK-NEXT: %11 = arith.muli %5, %arg4 : i32 +// HYPERBLOCK-NEXT: memref.store %11, %arg1[] : memref +// HYPERBLOCK-NEXT: } +// HYPERBLOCK-NEXT: taskflow.hyperblock.yield +// HYPERBLOCK-NEXT: }) : (index, index) -> () +// HYPERBLOCK-NEXT: "taskflow.yield"(%arg0, %arg1) <{operandSegmentSizes = array}> : (memref<4x8xi32>, memref) -> () +// HYPERBLOCK-NEXT: }) : (memref<4x8xi32>, memref, i32, i32, i32) -> (memref<4x8xi32>, memref) +// HYPERBLOCK-NEXT: %0 = affine.load %memory_outputs#1[] : memref +// HYPERBLOCK-NEXT: return %0 : i32 +// HYPERBLOCK-NEXT: } +// HYPERBLOCK-NEXT: } + +// CANONICALIZE: module { +// CANONICALIZE-NEXT: func.func @_Z21irregularLoopExample1v() -> i32 attributes {llvm.linkage = #llvm.linkage} { +// CANONICALIZE-NEXT: %c2_i32 = arith.constant 2 : i32 +// CANONICALIZE-NEXT: %c8_i32 = arith.constant 8 : i32 +// CANONICALIZE-NEXT: %c0_i32 = arith.constant 0 : i32 +// CANONICALIZE-NEXT: %alloca = memref.alloca() : memref +// CANONICALIZE-NEXT: %alloca_0 = memref.alloca() : memref<4x8xi32> +// CANONICALIZE-NEXT: %value_outputs = "taskflow.task"(%c0_i32) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_0"}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg0: i32): +// CANONICALIZE-NEXT: %1 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 5 : index} : index +// CANONICALIZE-NEXT: %2 = "taskflow.hyperblock"(%1, %arg0) <{operandSegmentSizes = array}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg1: index, %arg2: i32): +// CANONICALIZE-NEXT: %3 = arith.index_cast %arg1 : index to i32 +// CANONICALIZE-NEXT: %4 = arith.addi %arg2, %3 : i32 +// CANONICALIZE-NEXT: taskflow.hyperblock.yield outputs(%4 : i32) +// CANONICALIZE-NEXT: }) : (index, i32) -> i32 +// CANONICALIZE-NEXT: "taskflow.yield"(%2) <{operandSegmentSizes = array}> : (i32) -> () +// CANONICALIZE-NEXT: }) : (i32) -> i32 +// CANONICALIZE-NEXT: %memory_outputs = "taskflow.task"(%alloca_0, %c8_i32, %alloca_0) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_1"}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg0: memref<4x8xi32>, %arg1: i32, %arg2: memref<4x8xi32>): +// CANONICALIZE-NEXT: %1 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 4 : index} : index +// CANONICALIZE-NEXT: %2 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 8 : index} : index +// CANONICALIZE-NEXT: "taskflow.hyperblock"(%1, %2) <{operandSegmentSizes = array}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg3: index, %arg4: index): +// CANONICALIZE-NEXT: %3 = arith.index_cast %arg3 : index to i32 +// CANONICALIZE-NEXT: %4 = arith.muli %3, %arg1 : i32 +// CANONICALIZE-NEXT: %5 = arith.index_cast %arg4 : index to i32 +// CANONICALIZE-NEXT: %6 = arith.addi %4, %5 : i32 +// CANONICALIZE-NEXT: memref.store %6, %arg2[%arg3, %arg4] : memref<4x8xi32> +// CANONICALIZE-NEXT: taskflow.hyperblock.yield +// CANONICALIZE-NEXT: }) : (index, index) -> () +// CANONICALIZE-NEXT: "taskflow.yield"(%arg2) <{operandSegmentSizes = array}> : (memref<4x8xi32>) -> () +// CANONICALIZE-NEXT: }) : (memref<4x8xi32>, i32, memref<4x8xi32>) -> memref<4x8xi32> +// CANONICALIZE-NEXT: %memory_outputs_1 = "taskflow.task"(%memory_outputs, %alloca, %alloca_0, %value_outputs, %alloca, %c2_i32) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_2"}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg0: memref<4x8xi32>, %arg1: memref, %arg2: memref<4x8xi32>, %arg3: i32, %arg4: memref, %arg5: i32): +// CANONICALIZE-NEXT: %1 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 4 : index} : index +// CANONICALIZE-NEXT: %2 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 8 : index} : index +// CANONICALIZE-NEXT: "taskflow.hyperblock"(%1, %2) <{operandSegmentSizes = array}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg6: index, %arg7: index): +// CANONICALIZE-NEXT: %3 = memref.load %arg2[%arg6, %arg7] : memref<4x8xi32> +// CANONICALIZE-NEXT: %4 = arith.addi %3, %arg3 : i32 +// CANONICALIZE-NEXT: %c0 = arith.constant 0 : index +// CANONICALIZE-NEXT: %c-3 = arith.constant -3 : index +// CANONICALIZE-NEXT: %5 = arith.addi %arg6, %c-3 : index +// CANONICALIZE-NEXT: %6 = arith.cmpi eq, %5, %c0 : index +// CANONICALIZE-NEXT: %c-7 = arith.constant -7 : index +// CANONICALIZE-NEXT: %7 = arith.addi %arg7, %c-7 : index +// CANONICALIZE-NEXT: %8 = arith.cmpi eq, %7, %c0 : index +// CANONICALIZE-NEXT: %9 = arith.andi %6, %8 : i1 +// CANONICALIZE-NEXT: scf.if %9 { +// CANONICALIZE-NEXT: memref.store %4, %arg4[] : memref +// CANONICALIZE-NEXT: %10 = arith.muli %4, %arg5 : i32 +// CANONICALIZE-NEXT: memref.store %10, %arg4[] : memref +// CANONICALIZE-NEXT: } +// CANONICALIZE-NEXT: taskflow.hyperblock.yield +// CANONICALIZE-NEXT: }) : (index, index) -> () +// CANONICALIZE-NEXT: "taskflow.yield"(%arg4) <{operandSegmentSizes = array}> : (memref) -> () +// CANONICALIZE-NEXT: }) : (memref<4x8xi32>, memref, memref<4x8xi32>, i32, memref, i32) -> memref +// CANONICALIZE-NEXT: %0 = affine.load %memory_outputs_1[] : memref +// CANONICALIZE-NEXT: return %0 : i32 +// CANONICALIZE-NEXT: } +// CANONICALIZE-NEXT: } + diff --git a/test/multi-cgra/taskflow/multi-nested/multi-nested.mlir b/test/multi-cgra/taskflow/multi-nested/multi-nested.mlir index ac2881c..c5f75f2 100644 --- a/test/multi-cgra/taskflow/multi-nested/multi-nested.mlir +++ b/test/multi-cgra/taskflow/multi-nested/multi-nested.mlir @@ -1,9 +1,17 @@ // RUN: mlir-neura-opt %s --convert-affine-to-taskflow \ -// RUN: | FileCheck %s --check-prefixes=TASKFLOW +// RUN: -o %t.taskflow.mlir +// RUN: FileCheck %s --input-file=%t.taskflow.mlir --check-prefixes=TASKFLOW // RUN: mlir-neura-opt %s --convert-affine-to-taskflow \ // RUN: --construct-hyperblock-from-task \ -// RUN: | FileCheck %s --check-prefixes=HYPERBLOCK +// RUN: -o %t.hyperblock.mlir +// RUN: FileCheck %s --input-file=%t.hyperblock.mlir --check-prefixes=HYPERBLOCK + +// RUN: mlir-neura-opt %s --convert-affine-to-taskflow \ +// RUN: --construct-hyperblock-from-task \ +// RUN: --canonicalize-task \ +// RUN: -o %t.canonicalized.mlir +// RUN: FileCheck %s --input-file=%t.canonicalized.mlir --check-prefixes=CANONICALIZE module attributes {} { func.func @_Z21pureNestedLoopExamplePA8_A6_iPA8_A5_iS4_PA7_iPA9_iPiS9_S9_S9_S9_(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: memref, %arg6: memref, %arg7: memref, %arg8: memref, %arg9: memref) -> i32 attributes {llvm.linkage = #llvm.linkage} { @@ -44,98 +52,183 @@ module attributes {} { } } -// TASKFLOW: module { -// TASKFLOW-NEXT: func.func @_Z21pureNestedLoopExamplePA8_A6_iPA8_A5_iS4_PA7_iPA9_iPiS9_S9_S9_S9_(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: memref, %arg6: memref, %arg7: memref, %arg8: memref, %arg9: memref) -> i32 attributes {llvm.linkage = #llvm.linkage} { -// TASKFLOW-NEXT: %memory_outputs:5 = "taskflow.task"(%arg0, %arg1, %arg2, %arg5, %arg6, %arg9, %arg3, %arg4, %arg7, %arg8) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_0"}> ({ -// TASKFLOW-NEXT: ^bb0(%arg10: memref, %arg11: memref, %arg12: memref, %arg13: memref, %arg14: memref, %arg15: memref, %arg16: memref, %arg17: memref, %arg18: memref, %arg19: memref): -// TASKFLOW-NEXT: affine.for %arg20 = 0 to 4 { -// TASKFLOW-NEXT: affine.for %arg21 = 0 to 8 { -// TASKFLOW-NEXT: affine.for %arg22 = 0 to 6 { -// TASKFLOW-NEXT: %1 = affine.load %arg10[%arg20, %arg21, %arg22] : memref -// TASKFLOW-NEXT: affine.store %1, %arg13[%arg22] : memref -// TASKFLOW-NEXT: } -// TASKFLOW-NEXT: affine.for %arg22 = 0 to 5 { -// TASKFLOW-NEXT: %1 = affine.load %arg11[%arg20, %arg21, %arg22] : memref -// TASKFLOW-NEXT: %2 = affine.load %arg12[%arg20, %arg21, %arg22] : memref -// TASKFLOW-NEXT: %3 = arith.addi %1, %2 : i32 -// TASKFLOW-NEXT: affine.store %3, %arg14[%arg22] : memref -// TASKFLOW-NEXT: } -// TASKFLOW-NEXT: affine.for %arg22 = 0 to 6 { -// TASKFLOW-NEXT: %1 = affine.load %arg13[%arg22] : memref -// TASKFLOW-NEXT: %2 = affine.load %arg14[%arg22] : memref -// TASKFLOW-NEXT: %3 = arith.addi %1, %2 : i32 -// TASKFLOW-NEXT: %4 = affine.load %arg15[0] : memref -// TASKFLOW-NEXT: %5 = arith.addi %4, %3 : i32 -// TASKFLOW-NEXT: affine.store %5, %arg15[0] : memref -// TASKFLOW-NEXT: } -// TASKFLOW-NEXT: } -// TASKFLOW-NEXT: affine.for %arg21 = 0 to 7 { -// TASKFLOW-NEXT: %1 = affine.load %arg16[%arg20, %arg21] : memref -// TASKFLOW-NEXT: affine.store %1, %arg18[%arg21] : memref -// TASKFLOW-NEXT: } -// TASKFLOW-NEXT: affine.for %arg21 = 0 to 9 { -// TASKFLOW-NEXT: %1 = affine.load %arg17[%arg20, %arg21] : memref -// TASKFLOW-NEXT: %2 = affine.load %arg18[%arg21] : memref -// TASKFLOW-NEXT: %3 = arith.addi %1, %2 : i32 -// TASKFLOW-NEXT: affine.store %3, %arg19[%arg21] : memref -// TASKFLOW-NEXT: } -// TASKFLOW-NEXT: } -// TASKFLOW-NEXT: "taskflow.yield"(%arg13, %arg14, %arg15, %arg18, %arg19) <{operandSegmentSizes = array}> : (memref, memref, memref, memref, memref) -> () -// TASKFLOW-NEXT: }) : (memref, memref, memref, memref, memref, memref, memref, memref, memref, memref) -> (memref, memref, memref, memref, memref) -// TASKFLOW-NEXT: %0 = affine.load %arg9[0] : memref -// TASKFLOW-NEXT: return %0 : i32 -// TASKFLOW-NEXT: } -// TASKFLOW-NEXT:} +// TASKFLOW: module { +// TASKFLOW-NEXT: func.func @_Z21pureNestedLoopExamplePA8_A6_iPA8_A5_iS4_PA7_iPA9_iPiS9_S9_S9_S9_(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: memref, %arg6: memref, %arg7: memref, %arg8: memref, %arg9: memref) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// TASKFLOW-NEXT: %memory_outputs:5 = "taskflow.task"(%arg0, %arg1, %arg2, %arg5, %arg6, %arg9, %arg3, %arg4, %arg7, %arg8) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_0"}> ({ +// TASKFLOW-NEXT: ^bb0(%arg10: memref, %arg11: memref, %arg12: memref, %arg13: memref, %arg14: memref, %arg15: memref, %arg16: memref, %arg17: memref, %arg18: memref, %arg19: memref): +// TASKFLOW-NEXT: affine.for %arg20 = 0 to 4 { +// TASKFLOW-NEXT: affine.for %arg21 = 0 to 8 { +// TASKFLOW-NEXT: affine.for %arg22 = 0 to 6 { +// TASKFLOW-NEXT: %1 = affine.load %arg10[%arg20, %arg21, %arg22] : memref +// TASKFLOW-NEXT: affine.store %1, %arg13[%arg22] : memref +// TASKFLOW-NEXT: } +// TASKFLOW-NEXT: affine.for %arg22 = 0 to 5 { +// TASKFLOW-NEXT: %1 = affine.load %arg11[%arg20, %arg21, %arg22] : memref +// TASKFLOW-NEXT: %2 = affine.load %arg12[%arg20, %arg21, %arg22] : memref +// TASKFLOW-NEXT: %3 = arith.addi %1, %2 : i32 +// TASKFLOW-NEXT: affine.store %3, %arg14[%arg22] : memref +// TASKFLOW-NEXT: } +// TASKFLOW-NEXT: affine.for %arg22 = 0 to 6 { +// TASKFLOW-NEXT: %1 = affine.load %arg13[%arg22] : memref +// TASKFLOW-NEXT: %2 = affine.load %arg14[%arg22] : memref +// TASKFLOW-NEXT: %3 = arith.addi %1, %2 : i32 +// TASKFLOW-NEXT: %4 = affine.load %arg15[0] : memref +// TASKFLOW-NEXT: %5 = arith.addi %4, %3 : i32 +// TASKFLOW-NEXT: affine.store %5, %arg15[0] : memref +// TASKFLOW-NEXT: } +// TASKFLOW-NEXT: } +// TASKFLOW-NEXT: affine.for %arg21 = 0 to 7 { +// TASKFLOW-NEXT: %1 = affine.load %arg16[%arg20, %arg21] : memref +// TASKFLOW-NEXT: affine.store %1, %arg18[%arg21] : memref +// TASKFLOW-NEXT: } +// TASKFLOW-NEXT: affine.for %arg21 = 0 to 9 { +// TASKFLOW-NEXT: %1 = affine.load %arg17[%arg20, %arg21] : memref +// TASKFLOW-NEXT: %2 = affine.load %arg18[%arg21] : memref +// TASKFLOW-NEXT: %3 = arith.addi %1, %2 : i32 +// TASKFLOW-NEXT: affine.store %3, %arg19[%arg21] : memref +// TASKFLOW-NEXT: } +// TASKFLOW-NEXT: } +// TASKFLOW-NEXT: "taskflow.yield"(%arg13, %arg14, %arg15, %arg18, %arg19) <{operandSegmentSizes = array}> : (memref, memref, memref, memref, memref) -> () +// TASKFLOW-NEXT: }) : (memref, memref, memref, memref, memref, memref, memref, memref, memref, memref) -> (memref, memref, memref, memref, memref) +// TASKFLOW-NEXT: %0 = affine.load %memory_outputs#2[0] : memref +// TASKFLOW-NEXT: return %0 : i32 +// TASKFLOW-NEXT: } +// TASKFLOW-NEXT: } + +// HYPERBLOCK: module { +// HYPERBLOCK-NEXT: func.func @_Z21pureNestedLoopExamplePA8_A6_iPA8_A5_iS4_PA7_iPA9_iPiS9_S9_S9_S9_(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: memref, %arg6: memref, %arg7: memref, %arg8: memref, %arg9: memref) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// HYPERBLOCK-NEXT: %memory_outputs:5 = "taskflow.task"(%arg0, %arg1, %arg2, %arg5, %arg6, %arg9, %arg3, %arg4, %arg7, %arg8) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_0"}> ({ +// HYPERBLOCK-NEXT: ^bb0(%arg10: memref, %arg11: memref, %arg12: memref, %arg13: memref, %arg14: memref, %arg15: memref, %arg16: memref, %arg17: memref, %arg18: memref, %arg19: memref): +// HYPERBLOCK-NEXT: %1 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 4 : index} : index +// HYPERBLOCK-NEXT: %2 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 8 : index} : index +// HYPERBLOCK-NEXT: %3 = taskflow.counter parent(%2 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 6 : index} : index +// HYPERBLOCK-NEXT: %4 = taskflow.counter parent(%2 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 5 : index} : index +// HYPERBLOCK-NEXT: %5 = taskflow.counter parent(%2 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 6 : index} : index +// HYPERBLOCK-NEXT: %6 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 7 : index} : index +// HYPERBLOCK-NEXT: %7 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 9 : index} : index +// HYPERBLOCK-NEXT: "taskflow.hyperblock"(%1, %2, %3) <{operandSegmentSizes = array}> ({ +// HYPERBLOCK-NEXT: ^bb0(%arg20: index, %arg21: index, %arg22: index): +// HYPERBLOCK-NEXT: %8 = memref.load %arg10[%arg20, %arg21, %arg22] : memref +// HYPERBLOCK-NEXT: memref.store %8, %arg13[%arg22] : memref +// HYPERBLOCK-NEXT: taskflow.hyperblock.yield +// HYPERBLOCK-NEXT: }) : (index, index, index) -> () +// HYPERBLOCK-NEXT: "taskflow.hyperblock"(%1, %2, %4) <{operandSegmentSizes = array}> ({ +// HYPERBLOCK-NEXT: ^bb0(%arg20: index, %arg21: index, %arg22: index): +// HYPERBLOCK-NEXT: %8 = memref.load %arg11[%arg20, %arg21, %arg22] : memref +// HYPERBLOCK-NEXT: %9 = memref.load %arg12[%arg20, %arg21, %arg22] : memref +// HYPERBLOCK-NEXT: %10 = arith.addi %8, %9 : i32 +// HYPERBLOCK-NEXT: memref.store %10, %arg14[%arg22] : memref +// HYPERBLOCK-NEXT: taskflow.hyperblock.yield +// HYPERBLOCK-NEXT: }) : (index, index, index) -> () +// HYPERBLOCK-NEXT: "taskflow.hyperblock"(%5) <{operandSegmentSizes = array}> ({ +// HYPERBLOCK-NEXT: ^bb0(%arg20: index): +// HYPERBLOCK-NEXT: %8 = memref.load %arg13[%arg20] : memref +// HYPERBLOCK-NEXT: %9 = memref.load %arg14[%arg20] : memref +// HYPERBLOCK-NEXT: %10 = arith.addi %8, %9 : i32 +// HYPERBLOCK-NEXT: %c0 = arith.constant 0 : index +// HYPERBLOCK-NEXT: %11 = memref.load %arg15[%c0] : memref +// HYPERBLOCK-NEXT: %12 = arith.addi %11, %10 : i32 +// HYPERBLOCK-NEXT: %c0_0 = arith.constant 0 : index +// HYPERBLOCK-NEXT: memref.store %12, %arg15[%c0_0] : memref +// HYPERBLOCK-NEXT: taskflow.hyperblock.yield +// HYPERBLOCK-NEXT: }) : (index) -> () +// HYPERBLOCK-NEXT: "taskflow.hyperblock"(%1, %6) <{operandSegmentSizes = array}> ({ +// HYPERBLOCK-NEXT: ^bb0(%arg20: index, %arg21: index): +// HYPERBLOCK-NEXT: %8 = memref.load %arg16[%arg20, %arg21] : memref +// HYPERBLOCK-NEXT: memref.store %8, %arg18[%arg21] : memref +// HYPERBLOCK-NEXT: taskflow.hyperblock.yield +// HYPERBLOCK-NEXT: }) : (index, index) -> () +// HYPERBLOCK-NEXT: "taskflow.hyperblock"(%1, %7) <{operandSegmentSizes = array}> ({ +// HYPERBLOCK-NEXT: ^bb0(%arg20: index, %arg21: index): +// HYPERBLOCK-NEXT: %8 = memref.load %arg17[%arg20, %arg21] : memref +// HYPERBLOCK-NEXT: %9 = memref.load %arg18[%arg21] : memref +// HYPERBLOCK-NEXT: %10 = arith.addi %8, %9 : i32 +// HYPERBLOCK-NEXT: memref.store %10, %arg19[%arg21] : memref +// HYPERBLOCK-NEXT: taskflow.hyperblock.yield +// HYPERBLOCK-NEXT: }) : (index, index) -> () +// HYPERBLOCK-NEXT: "taskflow.yield"(%arg13, %arg14, %arg15, %arg18, %arg19) <{operandSegmentSizes = array}> : (memref, memref, memref, memref, memref) -> () +// HYPERBLOCK-NEXT: }) : (memref, memref, memref, memref, memref, memref, memref, memref, memref, memref) -> (memref, memref, memref, memref, memref) +// HYPERBLOCK-NEXT: %0 = affine.load %memory_outputs#2[0] : memref +// HYPERBLOCK-NEXT: return %0 : i32 +// HYPERBLOCK-NEXT: } +// HYPERBLOCK-NEXT: } -// HYPERBLOCK: module { -// HYPERBLOCK-NEXT: func.func @_Z21pureNestedLoopExamplePA8_A6_iPA8_A5_iS4_PA7_iPA9_iPiS9_S9_S9_S9_(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: memref, %arg6: memref, %arg7: memref, %arg8: memref, %arg9: memref) -> i32 attributes {llvm.linkage = #llvm.linkage} { -// HYPERBLOCK-NEXT: %memory_outputs:5 = "taskflow.task"(%arg0, %arg1, %arg2, %arg5, %arg6, %arg9, %arg3, %arg4, %arg7, %arg8) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_0"}> ({ -// HYPERBLOCK-NEXT: ^bb0(%arg10: memref, %arg11: memref, %arg12: memref, %arg13: memref, %arg14: memref, %arg15: memref, %arg16: memref, %arg17: memref, %arg18: memref, %arg19: memref): -// HYPERBLOCK-NEXT: %1 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 4 : index} : index -// HYPERBLOCK-NEXT: %2 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 8 : index} : index -// HYPERBLOCK-NEXT: %3 = taskflow.counter parent(%2 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 6 : index} : index -// HYPERBLOCK-NEXT: %4 = taskflow.counter parent(%2 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 5 : index} : index -// HYPERBLOCK-NEXT: %5 = taskflow.counter parent(%2 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 6 : index} : index -// HYPERBLOCK-NEXT: %6 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 7 : index} : index -// HYPERBLOCK-NEXT: %7 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 9 : index} : index -// HYPERBLOCK-NEXT: taskflow.hyperblock indices(%1, %2, %3 : index, index, index) { -// HYPERBLOCK-NEXT: ^bb0(%arg20: index, %arg21: index, %arg22: index): -// HYPERBLOCK-NEXT: %8 = memref.load %arg10[%arg20, %arg21, %arg22] : memref -// HYPERBLOCK-NEXT: memref.store %8, %arg13[%arg22] : memref -// HYPERBLOCK-NEXT: } -> () -// HYPERBLOCK-NEXT: taskflow.hyperblock indices(%1, %2, %4 : index, index, index) { -// HYPERBLOCK-NEXT: ^bb0(%arg20: index, %arg21: index, %arg22: index): -// HYPERBLOCK-NEXT: %8 = memref.load %arg11[%arg20, %arg21, %arg22] : memref -// HYPERBLOCK-NEXT: %9 = memref.load %arg12[%arg20, %arg21, %arg22] : memref -// HYPERBLOCK-NEXT: %10 = arith.addi %8, %9 : i32 -// HYPERBLOCK-NEXT: memref.store %10, %arg14[%arg22] : memref -// HYPERBLOCK-NEXT: } -> () -// HYPERBLOCK-NEXT: taskflow.hyperblock indices(%5 : index) { -// HYPERBLOCK-NEXT: ^bb0(%arg20: index): -// HYPERBLOCK-NEXT: %8 = memref.load %arg13[%arg20] : memref -// HYPERBLOCK-NEXT: %9 = memref.load %arg14[%arg20] : memref -// HYPERBLOCK-NEXT: %10 = arith.addi %8, %9 : i32 -// HYPERBLOCK-NEXT: %c0 = arith.constant 0 : index -// HYPERBLOCK-NEXT: %11 = memref.load %arg15[%c0] : memref -// HYPERBLOCK-NEXT: %12 = arith.addi %11, %10 : i32 -// HYPERBLOCK-NEXT: %c0_0 = arith.constant 0 : index -// HYPERBLOCK-NEXT: memref.store %12, %arg15[%c0_0] : memref -// HYPERBLOCK-NEXT: } -> () -// HYPERBLOCK-NEXT: taskflow.hyperblock indices(%1, %6 : index, index) { -// HYPERBLOCK-NEXT: ^bb0(%arg20: index, %arg21: index): -// HYPERBLOCK-NEXT: %8 = memref.load %arg16[%arg20, %arg21] : memref -// HYPERBLOCK-NEXT: memref.store %8, %arg18[%arg21] : memref -// HYPERBLOCK-NEXT: } -> () -// HYPERBLOCK-NEXT: taskflow.hyperblock indices(%1, %7 : index, index) { -// HYPERBLOCK-NEXT: ^bb0(%arg20: index, %arg21: index): -// HYPERBLOCK-NEXT: %8 = memref.load %arg17[%arg20, %arg21] : memref -// HYPERBLOCK-NEXT: %9 = memref.load %arg18[%arg21] : memref -// HYPERBLOCK-NEXT: %10 = arith.addi %8, %9 : i32 -// HYPERBLOCK-NEXT: memref.store %10, %arg19[%arg21] : memref -// HYPERBLOCK-NEXT: } -> () -// HYPERBLOCK-NEXT: "taskflow.yield"(%arg13, %arg14, %arg15, %arg18, %arg19) <{operandSegmentSizes = array}> : (memref, memref, memref, memref, memref) -> () -// HYPERBLOCK-NEXT: }) : (memref, memref, memref, memref, memref, memref, memref, memref, memref, memref) -> (memref, memref, memref, memref, memref) -// HYPERBLOCK-NEXT: %0 = affine.load %arg9[0] : memref -// HYPERBLOCK-NEXT: return %0 : i32 -// HYPERBLOCK-NEXT: } -// HYPERBLOCK-NEXT:} \ No newline at end of file +// CANONICALIZE: module { +// CANONICALIZE-NEXT: func.func @_Z21pureNestedLoopExamplePA8_A6_iPA8_A5_iS4_PA7_iPA9_iPiS9_S9_S9_S9_(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: memref, %arg6: memref, %arg7: memref, %arg8: memref, %arg9: memref) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// CANONICALIZE-NEXT: %memory_outputs = "taskflow.task"(%arg0, %arg5, %arg0, %arg5) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_0"}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg10: memref, %arg11: memref, %arg12: memref, %arg13: memref): +// CANONICALIZE-NEXT: %1 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 4 : index} : index +// CANONICALIZE-NEXT: %2 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 8 : index} : index +// CANONICALIZE-NEXT: %3 = taskflow.counter parent(%2 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 6 : index} : index +// CANONICALIZE-NEXT: "taskflow.hyperblock"(%1, %2, %3) <{operandSegmentSizes = array}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg14: index, %arg15: index, %arg16: index): +// CANONICALIZE-NEXT: %4 = memref.load %arg12[%arg14, %arg15, %arg16] : memref +// CANONICALIZE-NEXT: memref.store %4, %arg13[%arg16] : memref +// CANONICALIZE-NEXT: taskflow.hyperblock.yield +// CANONICALIZE-NEXT: }) : (index, index, index) -> () +// CANONICALIZE-NEXT: "taskflow.yield"(%arg13) <{operandSegmentSizes = array}> : (memref) -> () +// CANONICALIZE-NEXT: }) : (memref, memref, memref, memref) -> memref +// CANONICALIZE-NEXT: %memory_outputs_0 = "taskflow.task"(%arg1, %arg2, %arg6, %arg1, %arg2, %arg6) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_1"}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg10: memref, %arg11: memref, %arg12: memref, %arg13: memref, %arg14: memref, %arg15: memref): +// CANONICALIZE-NEXT: %1 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 4 : index} : index +// CANONICALIZE-NEXT: %2 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 8 : index} : index +// CANONICALIZE-NEXT: %3 = taskflow.counter parent(%2 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 5 : index} : index +// CANONICALIZE-NEXT: "taskflow.hyperblock"(%1, %2, %3) <{operandSegmentSizes = array}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg16: index, %arg17: index, %arg18: index): +// CANONICALIZE-NEXT: %4 = memref.load %arg13[%arg16, %arg17, %arg18] : memref +// CANONICALIZE-NEXT: %5 = memref.load %arg14[%arg16, %arg17, %arg18] : memref +// CANONICALIZE-NEXT: %6 = arith.addi %4, %5 : i32 +// CANONICALIZE-NEXT: memref.store %6, %arg15[%arg18] : memref +// CANONICALIZE-NEXT: taskflow.hyperblock.yield +// CANONICALIZE-NEXT: }) : (index, index, index) -> () +// CANONICALIZE-NEXT: "taskflow.yield"(%arg15) <{operandSegmentSizes = array}> : (memref) -> () +// CANONICALIZE-NEXT: }) : (memref, memref, memref, memref, memref, memref) -> memref +// CANONICALIZE-NEXT: %memory_outputs_1 = "taskflow.task"(%memory_outputs, %memory_outputs_0, %arg9, %arg5, %arg6, %arg9) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_2"}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg10: memref, %arg11: memref, %arg12: memref, %arg13: memref, %arg14: memref, %arg15: memref): +// CANONICALIZE-NEXT: %1 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 4 : index} : index +// CANONICALIZE-NEXT: %2 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 8 : index} : index +// CANONICALIZE-NEXT: %3 = taskflow.counter parent(%2 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 6 : index} : index +// CANONICALIZE-NEXT: "taskflow.hyperblock"(%3) <{operandSegmentSizes = array}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg16: index): +// CANONICALIZE-NEXT: %4 = memref.load %arg13[%arg16] : memref +// CANONICALIZE-NEXT: %5 = memref.load %arg14[%arg16] : memref +// CANONICALIZE-NEXT: %6 = arith.addi %4, %5 : i32 +// CANONICALIZE-NEXT: %c0 = arith.constant 0 : index +// CANONICALIZE-NEXT: %7 = memref.load %arg15[%c0] : memref +// CANONICALIZE-NEXT: %8 = arith.addi %7, %6 : i32 +// CANONICALIZE-NEXT: %c0_4 = arith.constant 0 : index +// CANONICALIZE-NEXT: memref.store %8, %arg15[%c0_4] : memref +// CANONICALIZE-NEXT: taskflow.hyperblock.yield +// CANONICALIZE-NEXT: }) : (index) -> () +// CANONICALIZE-NEXT: "taskflow.yield"(%arg15) <{operandSegmentSizes = array}> : (memref) -> () +// CANONICALIZE-NEXT: }) : (memref, memref, memref, memref, memref, memref) -> memref +// CANONICALIZE-NEXT: %memory_outputs_2 = "taskflow.task"(%arg3, %arg7, %arg3, %arg7) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_3"}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg10: memref, %arg11: memref, %arg12: memref, %arg13: memref): +// CANONICALIZE-NEXT: %1 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 4 : index} : index +// CANONICALIZE-NEXT: %2 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 7 : index} : index +// CANONICALIZE-NEXT: "taskflow.hyperblock"(%1, %2) <{operandSegmentSizes = array}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg14: index, %arg15: index): +// CANONICALIZE-NEXT: %3 = memref.load %arg12[%arg14, %arg15] : memref +// CANONICALIZE-NEXT: memref.store %3, %arg13[%arg15] : memref +// CANONICALIZE-NEXT: taskflow.hyperblock.yield +// CANONICALIZE-NEXT: }) : (index, index) -> () +// CANONICALIZE-NEXT: "taskflow.yield"(%arg13) <{operandSegmentSizes = array}> : (memref) -> () +// CANONICALIZE-NEXT: }) : (memref, memref, memref, memref) -> memref +// CANONICALIZE-NEXT: %memory_outputs_3 = "taskflow.task"(%arg4, %memory_outputs_2, %arg8, %arg4, %arg7, %arg8) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_4"}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg10: memref, %arg11: memref, %arg12: memref, %arg13: memref, %arg14: memref, %arg15: memref): +// CANONICALIZE-NEXT: %1 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 4 : index} : index +// CANONICALIZE-NEXT: %2 = taskflow.counter parent(%1 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 9 : index} : index +// CANONICALIZE-NEXT: "taskflow.hyperblock"(%1, %2) <{operandSegmentSizes = array}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg16: index, %arg17: index): +// CANONICALIZE-NEXT: %3 = memref.load %arg13[%arg16, %arg17] : memref +// CANONICALIZE-NEXT: %4 = memref.load %arg14[%arg17] : memref +// CANONICALIZE-NEXT: %5 = arith.addi %3, %4 : i32 +// CANONICALIZE-NEXT: memref.store %5, %arg15[%arg17] : memref +// CANONICALIZE-NEXT: taskflow.hyperblock.yield +// CANONICALIZE-NEXT: }) : (index, index) -> () +// CANONICALIZE-NEXT: "taskflow.yield"(%arg15) <{operandSegmentSizes = array}> : (memref) -> () +// CANONICALIZE-NEXT: }) : (memref, memref, memref, memref, memref, memref) -> memref +// CANONICALIZE-NEXT: %0 = affine.load %memory_outputs_1[0] : memref +// CANONICALIZE-NEXT: return %0 : i32 +// CANONICALIZE-NEXT: } +// CANONICALIZE-NEXT: } \ No newline at end of file diff --git a/test/multi-cgra/taskflow/parallel-nested/parallel-nested.mlir b/test/multi-cgra/taskflow/parallel-nested/parallel-nested.mlir index ab4360e..ee37c83 100644 --- a/test/multi-cgra/taskflow/parallel-nested/parallel-nested.mlir +++ b/test/multi-cgra/taskflow/parallel-nested/parallel-nested.mlir @@ -1,9 +1,17 @@ // RUN: mlir-neura-opt %s --convert-affine-to-taskflow \ -// RUN: | FileCheck %s --check-prefixes=TASKFLOW +// RUN: -o %t.taskflow.mlir +// RUN: FileCheck %s --input-file=%t.taskflow.mlir --check-prefixes=TASKFLOW // RUN: mlir-neura-opt %s --convert-affine-to-taskflow \ // RUN: --construct-hyperblock-from-task \ -// RUN: | FileCheck %s --check-prefixes=HYPERBLOCK +// RUN: -o %t.hyperblock.mlir +// RUN: FileCheck %s --input-file=%t.hyperblock.mlir --check-prefixes=HYPERBLOCK + +// RUN: mlir-neura-opt %s --convert-affine-to-taskflow \ +// RUN: --construct-hyperblock-from-task \ +// RUN: --canonicalize-task \ +// RUN: -o %t.canonicalized.mlir +// RUN: FileCheck %s --input-file=%t.canonicalized.mlir --check-prefixes=CANONICALIZE module { // Example: Parallel nested loops scenario @@ -68,27 +76,61 @@ module { // HYPERBLOCK-NEXT: %memory_outputs = "taskflow.task"(%arg0, %arg4) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_0"}> ({ // HYPERBLOCK-NEXT: ^bb0(%arg5: memref<16xf32>, %arg6: f32): // HYPERBLOCK-NEXT: %0 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 16 : index} : index -// HYPERBLOCK-NEXT: taskflow.hyperblock indices(%0 : index) { +// HYPERBLOCK-NEXT: "taskflow.hyperblock"(%0) <{operandSegmentSizes = array}> ({ // HYPERBLOCK-NEXT: ^bb0(%arg7: index): // HYPERBLOCK-NEXT: %1 = memref.load %arg5[%arg7] : memref<16xf32> // HYPERBLOCK-NEXT: %2 = arith.mulf %1, %arg6 : f32 // HYPERBLOCK-NEXT: memref.store %2, %arg5[%arg7] : memref<16xf32> -// HYPERBLOCK-NEXT: } -> () +// HYPERBLOCK-NEXT: taskflow.hyperblock.yield +// HYPERBLOCK-NEXT: }) : (index) -> () // HYPERBLOCK-NEXT: "taskflow.yield"(%arg5) <{operandSegmentSizes = array}> : (memref<16xf32>) -> () // HYPERBLOCK-NEXT: }) : (memref<16xf32>, f32) -> memref<16xf32> // HYPERBLOCK-NEXT: %memory_outputs_0 = "taskflow.task"(%arg1, %arg2, %arg3) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_1"}> ({ // HYPERBLOCK-NEXT: ^bb0(%arg5: memref<8x8xf32>, %arg6: memref<8x8xf32>, %arg7: memref<8x8xf32>): // HYPERBLOCK-NEXT: %0 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 8 : index} : index // HYPERBLOCK-NEXT: %1 = taskflow.counter parent(%0 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 8 : index} : index -// HYPERBLOCK-NEXT: taskflow.hyperblock indices(%0, %1 : index, index) { +// HYPERBLOCK-NEXT: "taskflow.hyperblock"(%0, %1) <{operandSegmentSizes = array}> ({ // HYPERBLOCK-NEXT: ^bb0(%arg8: index, %arg9: index): // HYPERBLOCK-NEXT: %2 = memref.load %arg5[%arg8, %arg9] : memref<8x8xf32> // HYPERBLOCK-NEXT: %3 = memref.load %arg6[%arg8, %arg9] : memref<8x8xf32> // HYPERBLOCK-NEXT: %4 = arith.mulf %2, %3 : f32 // HYPERBLOCK-NEXT: memref.store %4, %arg7[%arg8, %arg9] : memref<8x8xf32> -// HYPERBLOCK-NEXT: } -> () +// HYPERBLOCK-NEXT: taskflow.hyperblock.yield +// HYPERBLOCK-NEXT: }) : (index, index) -> () // HYPERBLOCK-NEXT: "taskflow.yield"(%arg7) <{operandSegmentSizes = array}> : (memref<8x8xf32>) -> () // HYPERBLOCK-NEXT: }) : (memref<8x8xf32>, memref<8x8xf32>, memref<8x8xf32>) -> memref<8x8xf32> // HYPERBLOCK-NEXT: return // HYPERBLOCK-NEXT: } -// HYPERBLOCK-NEXT: } \ No newline at end of file +// HYPERBLOCK-NEXT: } + +// CANONICALIZE: module { +// CANONICALIZE-NEXT: func.func @parallel_nested_example(%arg0: memref<16xf32>, %arg1: memref<8x8xf32>, %arg2: memref<8x8xf32>, %arg3: memref<8x8xf32>, %arg4: f32) { +// CANONICALIZE-NEXT: %memory_outputs = "taskflow.task"(%arg0, %arg4) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_0"}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg5: memref<16xf32>, %arg6: f32): +// CANONICALIZE-NEXT: %0 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 16 : index} : index +// CANONICALIZE-NEXT: "taskflow.hyperblock"(%0) <{operandSegmentSizes = array}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg7: index): +// CANONICALIZE-NEXT: %1 = memref.load %arg5[%arg7] : memref<16xf32> +// CANONICALIZE-NEXT: %2 = arith.mulf %1, %arg6 : f32 +// CANONICALIZE-NEXT: memref.store %2, %arg5[%arg7] : memref<16xf32> +// CANONICALIZE-NEXT: taskflow.hyperblock.yield +// CANONICALIZE-NEXT: }) : (index) -> () +// CANONICALIZE-NEXT: "taskflow.yield"(%arg5) <{operandSegmentSizes = array}> : (memref<16xf32>) -> () +// CANONICALIZE-NEXT: }) : (memref<16xf32>, f32) -> memref<16xf32> +// CANONICALIZE-NEXT: %memory_outputs_0 = "taskflow.task"(%arg1, %arg2, %arg3) <{operandSegmentSizes = array, resultSegmentSizes = array, task_name = "Task_1"}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg5: memref<8x8xf32>, %arg6: memref<8x8xf32>, %arg7: memref<8x8xf32>): +// CANONICALIZE-NEXT: %0 = taskflow.counter attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 8 : index} : index +// CANONICALIZE-NEXT: %1 = taskflow.counter parent(%0 : index) attributes {lower_bound = 0 : index, step = 1 : index, upper_bound = 8 : index} : index +// CANONICALIZE-NEXT: "taskflow.hyperblock"(%0, %1) <{operandSegmentSizes = array}> ({ +// CANONICALIZE-NEXT: ^bb0(%arg8: index, %arg9: index): +// CANONICALIZE-NEXT: %2 = memref.load %arg5[%arg8, %arg9] : memref<8x8xf32> +// CANONICALIZE-NEXT: %3 = memref.load %arg6[%arg8, %arg9] : memref<8x8xf32> +// CANONICALIZE-NEXT: %4 = arith.mulf %2, %3 : f32 +// CANONICALIZE-NEXT: memref.store %4, %arg7[%arg8, %arg9] : memref<8x8xf32> +// CANONICALIZE-NEXT: taskflow.hyperblock.yield +// CANONICALIZE-NEXT: }) : (index, index) -> () +// CANONICALIZE-NEXT: "taskflow.yield"(%arg7) <{operandSegmentSizes = array}> : (memref<8x8xf32>) -> () +// CANONICALIZE-NEXT: }) : (memref<8x8xf32>, memref<8x8xf32>, memref<8x8xf32>) -> memref<8x8xf32> +// CANONICALIZE-NEXT: return +// CANONICALIZE-NEXT: } +// CANONICALIZE-NEXT: } \ No newline at end of file From 26bd047c05f726381143464ee69dbb927fa3e810 Mon Sep 17 00:00:00 2001 From: ShangkunLI Date: Mon, 19 Jan 2026 14:03:35 +0800 Subject: [PATCH 8/8] update comments --- .../Transforms/ConstructHyperblockFromTaskPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/TaskflowDialect/Transforms/ConstructHyperblockFromTaskPass.cpp b/lib/TaskflowDialect/Transforms/ConstructHyperblockFromTaskPass.cpp index 69f0de4..763e615 100644 --- a/lib/TaskflowDialect/Transforms/ConstructHyperblockFromTaskPass.cpp +++ b/lib/TaskflowDialect/Transforms/ConstructHyperblockFromTaskPass.cpp @@ -203,7 +203,7 @@ static void extractHyperblocksInfoFromRegion( SmallVector loop_indices = parent_indices; loop_indices.push_back(loop_info->counter_index); - // 分析哪些 current_ops 被这个循环使用 + // Analyzes which of the current_ops are used by this loop. DenseSet values_used_in_loop; for_op.walk([&](Operation *nested_op) { for (Value operand : nested_op->getOperands()) {