From a89b09325ef31ef53e9062b4571eae8b84aed98a Mon Sep 17 00:00:00 2001 From: ShouyangDong Date: Mon, 10 Nov 2025 10:58:31 +0800 Subject: [PATCH] add sm90 related passes --- python/tvm/tir/transform/transform.py | 58 + .../eliminate_storage_sync_for_mbarrier.cc | 119 ++ src/tir/transforms/inject_ptx_async_copy.cc | 4 + src/tir/transforms/inject_tma_barrier.cc | 599 ++++++++ src/tir/transforms/lower_hopper_intrin.cc | 171 +++ src/tir/transforms/lower_shared_tmem.cc | 319 ++++ .../transforms/warp_specialized_rewriter.cc | 1302 +++++++++++++++++ .../transforms/warp_specialized_rewriter.h | 99 ++ src/tir/transforms/wgmma_sync_rewriter.cc | 270 ++++ 9 files changed, 2941 insertions(+) create mode 100644 src/tir/transforms/eliminate_storage_sync_for_mbarrier.cc create mode 100644 src/tir/transforms/inject_tma_barrier.cc create mode 100644 src/tir/transforms/lower_hopper_intrin.cc create mode 100644 src/tir/transforms/lower_shared_tmem.cc create mode 100644 src/tir/transforms/warp_specialized_rewriter.cc create mode 100644 src/tir/transforms/warp_specialized_rewriter.h create mode 100644 src/tir/transforms/wgmma_sync_rewriter.cc diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index d8531401d49d..1eb90c1b74d2 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -1212,3 +1212,61 @@ def UseAssumeToReduceBranches(): The result pass """ return _ffi_api.UseAssumeToReduceBranches() # type: ignore + +def LowerHopperIntrin(): + """LowerHopperIntrin + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return (_ffi_api.LowerHopperIntrin() if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f + ) # type: ignore + + +def WarpSpecializedPipeline(): + """WarpSpecializedPipeline + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.WarpSpecializedPipeline() # type: ignore + + +def RewriteWgmmaSync(): + """RewriteWgmmaSync + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.RewriteWgmmaSync() # type: ignore + +def InjectTmaBarrier(): + """InjectTmaBarrier + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectTmaBarrier() # type: ignore + +def EliminateStorageSyncForMBarrier(): + """EliminateStorageSyncForMBarrier + """ + return _ffi_api.EliminateStorageSyncForMBarrier() # type: ignore + +def LowerSharedBarrier(): + """LowerSharedBarrier + """ + return _ffi_api.LowerSharedBarrier() # type: ignore + +def LowerSharedTmem(): + """LowerSharedTmem + """ + return _ffi_api.LowerSharedTmem() # type: ignore diff --git a/src/tir/transforms/eliminate_storage_sync_for_mbarrier.cc b/src/tir/transforms/eliminate_storage_sync_for_mbarrier.cc new file mode 100644 index 000000000000..f9a899f27f58 --- /dev/null +++ b/src/tir/transforms/eliminate_storage_sync_for_mbarrier.cc @@ -0,0 +1,119 @@ +/*! 
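+ * \brief Drop tvm_storage_sync calls that are made redundant by
+ *        mbarrier-based synchronization (sm90+).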
+ * \file eliminate_storage_sync_for_mbarrier.cc + */ +#include "../op/builtin.h" +#include "./storage_access.h" +#include "arith/ir_mutator_with_analyzer.h" +#include "arith/ir_visitor_with_analyzer.h" +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +using namespace tir; +using arith::IRMutatorWithAnalyzer; +using arith::IRVisitorWithAnalyzer; + +class Eliminator : public IRMutatorWithAnalyzer { +public: + static Stmt Substitute(const Stmt &stmt, bool skip_thread_partition = false) { + arith::Analyzer analyzer; + Eliminator transformer(&analyzer); + return transformer.VisitStmt(stmt); + } + + Eliminator(arith::Analyzer *analyzer) : IRMutatorWithAnalyzer(analyzer) { + im_mbarrier_for_ = false; + in_mbarrier_region_ = false; + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == "thread_extent") { + if (const auto *var = op->node.as()) { + if (var->name_hint == "threadIdx.x") { + thread_extent_ = op; + } + } + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + Stmt VisitStmt_(const EvaluateNode *op) final { + const CallNode *call = nullptr; + if (op->value->IsInstance()) { + call = op->value.as(); + if (call->op.same_as(builtin::tvm_storage_sync())) { + // Skip storage sync if we're in a region with mbarrier operations + // and we're not in a for loop with mbarrier operations + if (in_mbarrier_region_ || im_mbarrier_for_) { + return Stmt(); + } + } else if (call->op.same_as(builtin::ptx_arrive_barrier()) || + call->op.same_as(builtin::ptx_wait_barrier())) { + in_mbarrier_region_ = true; + } + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + Stmt VisitStmt_(const IfThenElseNode *op) final { + bool old_in_mbarrier = in_mbarrier_region_; + Stmt then_case = VisitStmt(op->then_case); + + Stmt ret; + if (op->else_case.defined()) { + in_mbarrier_region_ = old_in_mbarrier; + Stmt else_case = VisitStmt(op->else_case.value()); + in_mbarrier_region_ = old_in_mbarrier || in_mbarrier_region_; + ret = IfThenElse(VisitExpr(op->condition), then_case, else_case); + } else { + in_mbarrier_region_ = old_in_mbarrier || in_mbarrier_region_; + ret = IfThenElse(VisitExpr(op->condition), then_case, Stmt()); + } + return ret; + } + + Stmt VisitStmt_(const ForNode *op) final { + PostOrderVisit(tvm::ffi::GetRef(op), [&](const ObjectRef &node) { + if (const auto *call = node.as()) { + if (call->op.same_as(create_list_of_mbarrier()) || + call->op.same_as(mbarrier_wait_parity()) || + call->op.same_as(builtin::ptx_arrive_barrier()) || + call->op.same_as(builtin::ptx_cp_async_barrier())) { + im_mbarrier_for_ = true; + } + } + }); + auto stmt = IRMutatorWithAnalyzer::VisitStmt_(op); + im_mbarrier_for_ = false; + return stmt; + } + +private: + bool im_mbarrier_for_; + bool in_mbarrier_region_; + const AttrStmtNode *thread_extent_{nullptr}; +}; +using namespace tir::transform; + +namespace transform { + +tvm::transform::Pass EliminateStorageSyncForMBarrier() { + auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) { + auto *n = f.CopyOnWrite(); + n->body = Eliminator::Substitute(n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.transform.EliminateStorageSyncForMBarrier", + {}); +} + +} // namespace transform +} // namespace tir +} // namespace tvm \ No newline at end of file diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index 5d23e854be02..5b5208ae1855 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ 
b/src/tir/transforms/inject_ptx_async_copy.cc @@ -169,10 +169,14 @@ class PTXAsyncCopyInjector : public StmtMutator { if (auto* b = call->args[2].as()) { if (auto* f = b->value.as()) { else_value_is_zero = f->value == 0.0f; + } else if (auto *i = b->value.as()) { + else_value_is_zero = i->value == 0; } } if (auto* f = call->args[2].as()) { else_value_is_zero = f->value == 0.0f; + } else if (auto *i = call->args[2].as()) { + else_value_is_zero = i->value == 0; } if (else_value_is_zero) { return InjectPTX(load, store, true, call->args[0]); diff --git a/src/tir/transforms/inject_tma_barrier.cc b/src/tir/transforms/inject_tma_barrier.cc new file mode 100644 index 000000000000..2e64accca65f --- /dev/null +++ b/src/tir/transforms/inject_tma_barrier.cc @@ -0,0 +1,599 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tma_barrier_rewriter.cc + * \brief Rewrite TMA barriers for cuda GPU (sm90+) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/builtin.h" +#include "./common/attr.h" +#include "./common/collector.h" +#include "arith/ir_mutator_with_analyzer.h" +#include "arith/ir_visitor_with_analyzer.h" + +namespace tvm { +namespace tir { +namespace transform { + +using arith::IRMutatorWithAnalyzer; +using arith::IRVisitorWithAnalyzer; + +class TmaTraitsCollector : public StmtExprVisitor { +public: + TmaTraitsCollector() { Initialize(); } + + void Initialize() { + bulk_copy_bytes = 0; + loop_extents = 1; + } + + void Collect(const Stmt &stmt) { VisitStmt(stmt); } + + PrimExpr BulkCopyBytes() { return bulk_copy_bytes; } + +private: + void VisitExpr_(const CallNode *call) final { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + auto arg0 = call->args[0].as(); + if (call->op.same_as(tma_load()) && arg0 && + !arg0.value()->op.same_as(create_tma_descriptor())) { + // 1D TMA load has tvm_access_ptr of shared tensor in its args[0] + bulk_copy_bytes = call->args[3] * loop_extents; + } else { + Call access_ptr = Downcast(call->args[2]); + ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); + int type_bytes = access_ptr->args[0]->dtype.bytes(); + bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes; + } + } + StmtExprVisitor::VisitExpr_(call); + } + + void VisitStmt_(const ForNode *op) final { + PrimExpr old_loop_evtents = loop_extents; + loop_extents *= op->extent; + StmtExprVisitor::VisitStmt_(op); + loop_extents = old_loop_evtents; + } + + PrimExpr bulk_copy_bytes = 0; + PrimExpr loop_extents = 1; +}; + +class TmaExpectTxRewriter : public IRMutatorWithAnalyzer { +public: + static PrimFunc Rewrite(PrimFunc f, arith::Analyzer *analyzer) { + TmaExpectTxRewriter rewriter(analyzer); + f.CopyOnWrite()->body = 
rewriter(f->body); + return f; + } + +private: + bool inside_tma_block_{false}; + bool visited_tma_load_{false}; + IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"), + IterVarType::kDataPar); + + PrimExpr makeGetBarrier(PrimExpr barrier_id) { + return Call(DataType::Handle(), get_mbarrier(), {std::move(barrier_id)}); + } + + Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) { + auto call = Call(DataType::Handle(), mbarrier_expect_tx(), + {makeGetBarrier(std::move(barrier_id)), std::move(bytes)}); + return Evaluate(call); + } + + TmaExpectTxRewriter(arith::Analyzer *analyzer) + : IRMutatorWithAnalyzer(analyzer) {} + + Stmt VisitStmt_(const AttrStmtNode *op) final { + + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + ICHECK(iv->dom->extent.as()); + thread_var_ = iv; + } + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + Stmt VisitStmt_(const IfThenElseNode *op) { + // Check if this is the TMA block + bool flag = false; + if (op->condition.as()) { + flag = op->condition.as()->op.same_as(tl_shuffle_elect()); + } + if (op->condition.as() || flag) { + Stmt ret = IRMutatorWithAnalyzer::VisitStmt_(op); + + if (visited_tma_load_) { + auto then_case = op->then_case; + TmaTraitsCollector collector; + collector.Collect(then_case); + + Array stmts; + if (!is_zero(collector.BulkCopyBytes())) { + auto expect_tx = makeExpectTX(0, collector.BulkCopyBytes()); + stmts.push_back(expect_tx); + } + stmts.push_back(then_case); + if (stmts.size() == 1) { + return IfThenElse(op->condition, stmts[0], op->else_case); + } else { + auto seq_stmt = SeqStmt(stmts); + return IfThenElse(op->condition, seq_stmt, op->else_case); + } + } + visited_tma_load_ = false; + return ret; + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const CallNode *op) { + if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { + auto arg0 = op->args[0].as(); + bool is_1d_tma_load = + arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && + op->op.same_as(tma_load()); + visited_tma_load_ = true; + Array new_args = op->args; + new_args.Set(is_1d_tma_load ? 
2 : 1, + Call(DataType::Handle(), get_mbarrier(), + {IntImm(DataType::Int(32), 0)})); + return Call(op->dtype, op->op, new_args); + } + return IRMutatorWithAnalyzer::VisitExpr_(op); + } +}; + +class TmaBarrierCollector : public IRVisitorWithAnalyzer { +public: + TmaBarrierCollector(Map buffer_data_to_buffer) + : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {} + + Map tma_op_to_barrier_id() { + return tma_op_to_barrier_id_; + } + Map barrier_id_to_range() { return barrier_id_to_range_; } + +private: + void UpdateBarrierRange(const PrimExpr &barrier_id, const IntImm &extent) { + if (barrier_id_to_range_.count(barrier_id)) { + auto old_extent = barrier_id_to_range_[barrier_id]; + ICHECK_EQ(old_extent->value, extent->value) + << "barrier_id: " << barrier_id << " has different extent"; + barrier_id_to_range_.Set(barrier_id, extent); + } else { + barrier_id_to_range_.Set(barrier_id, extent); + } + } + + void VisitStmt_(const EvaluateNode *op) final { + if (const auto *call = op->value.as()) { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + pending_tma_ops_.push_back(tvm::ffi::GetRef(call)); + } else if (call->op.same_as(mbarrier_expect_tx())) { + pending_tma_ops_.push_back(tvm::ffi::GetRef(call)); + } else if (call->op.same_as(builtin::ptx_arrive_barrier())) { + PrimExpr barrier_id = call->args[0]; + for (const auto &tma_call : pending_tma_ops_) { + tma_op_to_barrier_id_.Set(tma_call, barrier_id); + } + auto const_int_bound = analyzer_.const_int_bound(thread_var_); + auto extent = + const_int_bound->max_value - const_int_bound->min_value + 1; + UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); + pending_tma_ops_.clear(); + } else if (call->op.same_as(builtin::ptx_wait_barrier())) { + PrimExpr barrier_id = call->args[0]; + auto const_int_bound = analyzer_.const_int_bound(thread_var_); + auto extent = + const_int_bound->max_value - const_int_bound->min_value + 1; + UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); + } + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AttrStmtNode *op) { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + thread_var_ = iv; + } + } + IRVisitorWithAnalyzer::VisitStmt_(op); + } + + IterVar thread_var_; + std::vector pending_tma_ops_; + Map tma_op_to_barrier_id_; + Map barrier_id_to_range_; + Map buffer_data_to_buffer_; +}; + +class TmaSequenceCollector : public IRVisitorWithAnalyzer { +public: + TmaSequenceCollector(Map tma_op_to_barrier_id) + : tma_op_to_barrier_id_(std::move(tma_op_to_barrier_id)) {} + + std::vector GetSequence() { + std::vector clear_zero_list(expect_tx_count_, false); + int zero_idx = -1; + int zero_count = 0; + + for (auto v : sequence) { + if (v == 0) { + zero_count += 1; + zero_idx += 1; + } else { + if (zero_count == 1) { + clear_zero_list[zero_idx] = expect_[zero_idx] && !has_simt_copy_; + if (clear_zero_list[zero_idx] == false) { + int begin = int_sets_[zero_idx].min().as()->value; + int end = int_sets_[zero_idx].max().as()->value; + for (int i = begin; i <= end; ++i) { + restore_barrier_ids_.push_back(i); + } + } + } else { + for (int i{zero_idx}; i > zero_idx - zero_count; --i) { + int begin = int_sets_[i].min().as()->value; + int end = int_sets_[i].max().as()->value; + for (int i = begin; i <= end; ++i) { + restore_barrier_ids_.push_back(i); + } + } + } + zero_count = 0; + } + } + + return clear_zero_list; + } + + std::vector GetRestoreBarrierIds() { return 
restore_barrier_ids_; } + + void VisitStmt_(const ForNode *op) final { + var_int_set_.Set(op->loop_var, + arith::IntSet::FromMinExtent(op->min, op->extent)); + IRVisitorWithAnalyzer::VisitStmt_(op); + } + + void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(mbarrier_expect_tx())) { + auto call_ref = tvm::ffi::GetRef(op); + if (tma_op_to_barrier_id_.count(call_ref)) { + PrimExpr e = tma_op_to_barrier_id_[call_ref].as()->args[0]; + auto int_set = arith::EvalSet(e, var_int_set_); + expect_.push_back(if_depth_ == 1); + sequence.push_back(0); + int_sets_.push_back(int_set); + expect_tx_count_ += 1; + } + } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { + sequence.push_back(1); + } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { + has_simt_copy_ = true; + } + IRVisitorWithAnalyzer::VisitExpr_(op); + } + + void VisitStmt_(const IfThenElseNode *op) final { + if_depth_ += 1; + + IRVisitorWithAnalyzer::VisitStmt(op->then_case); + + if (op->else_case) { + IRVisitorWithAnalyzer::VisitStmt(op->else_case.value()); + } + if_depth_ -= 1; + } + + std::vector sequence; + int expect_tx_count_{0}; + std::vector expect_; + bool has_simt_copy_{false}; + std::vector restore_barrier_ids_; + int if_depth_{0}; + Map tma_op_to_barrier_id_; + arith::Analyzer *analyzer_{}; + Map var_int_set_; + std::vector int_sets_; +}; + +class BarrierCreationRewriter : public StmtExprMutator { +public: + BarrierCreationRewriter(std::vector restore_barrier_ids, + PrimExpr producer_thread_extent, + int ensure_min_count = 0, + PrimExpr default_barrier_thread_count = 1) + : restore_barrier_ids_(std::move(restore_barrier_ids)), + producer_thread_extent_(std::move(producer_thread_extent)), + ensure_min_count_(ensure_min_count), + default_barrier_thread_count_(std::move(default_barrier_thread_count)) { + } + + PrimExpr VisitExpr_(const CallNode *op) { + if (op->op.same_as(create_list_of_mbarrier())) { + size_t cur_n = op->args.size(); + size_t need_n = + std::max(cur_n, static_cast(ensure_min_count_)); + + // Mark barriers to restore across the full needed length, not just the + // original length, so newly appended entries can be restored as well. 
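+      // (A restored slot is re-armed with producer_thread_extent_ as its
+      // arrive count; appended slots that are not restored fall back to
+      // default_barrier_thread_count_.)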
+ std::vector replace(need_n, false); + for (auto &id : restore_barrier_ids_) { + if (id >= 0 && static_cast(id) < replace.size()) { + replace[id] = true; + } + } + + Array new_args; + new_args.reserve(need_n); + + // Preserve/override existing entries + for (size_t i{0}; i < cur_n; ++i) { + if (replace[i]) { + new_args.push_back(producer_thread_extent_); + } else { + new_args.push_back(op->args[i]); + } + } + // Append additional barriers if required + for (size_t i = cur_n; i < need_n; ++i) { + if (replace[i]) { + new_args.push_back(producer_thread_extent_); + } else { + new_args.push_back(default_barrier_thread_count_); + } + } + + return Call(op->dtype, op->op, new_args); + } else { + return StmtExprMutator::VisitExpr_(op); + } + } + +private: + std::vector restore_barrier_ids_; + PrimExpr producer_thread_extent_; + int ensure_min_count_{0}; + PrimExpr default_barrier_thread_count_{1}; +}; + +// we trust mbarrier_wait_parity to be correct +class TmaBarrierRewriter : public IRMutatorWithAnalyzer { +public: + TmaBarrierRewriter(arith::Analyzer *analyzer, + Map tma_op_to_barrier_id, + Map barrier_id_to_range, + bool has_create_list_of_mbarrier) + : IRMutatorWithAnalyzer(analyzer), + tma_op_to_barrier_id_(std::move(tma_op_to_barrier_id)), + barrier_id_to_range_(std::move(barrier_id_to_range)), + has_create_list_of_mbarrier_(has_create_list_of_mbarrier) {} + + static PrimFunc Rewrite(PrimFunc f, arith::Analyzer *analyzer) { + auto buffer_lca = DetectBufferAccessLCA(f); + Map buffer_data_to_buffer_; + for (auto [buffer, _] : buffer_lca) + buffer_data_to_buffer_.Set(buffer->data, buffer); + f = TmaExpectTxRewriter::Rewrite(f, analyzer); + TmaBarrierCollector collector(buffer_data_to_buffer_); + collector(f->body); + bool has_create_list_of_mbarrier = false; + PostOrderVisit(f->body, [&](const ObjectRef &node) { + if (const auto *call = node.as()) { + if (call->op.same_as(create_list_of_mbarrier())) { + has_create_list_of_mbarrier = true; + } else if (call->op.same_as(builtin::ptx_init_barrier_thread_count())) { + has_create_list_of_mbarrier = true; + } + } + }); + TmaBarrierRewriter rewriter(analyzer, collector.tma_op_to_barrier_id(), + collector.barrier_id_to_range(), + has_create_list_of_mbarrier); + f.CopyOnWrite()->body = rewriter(f->body); + // Compute the minimum number of barriers actually referenced in the body + // after TMA barrier rewrites (e.g., get_mbarrier(0) inserted for TMA). + struct GetMbarrierMaxIdxCollector : public StmtExprVisitor { + int max_idx{-1}; + void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(get_mbarrier())) { + if (op->args.size() == 1) { + if (const auto *imm = op->args[0].as()) { + max_idx = std::max(max_idx, static_cast(imm->value)); + } + } + } + StmtExprVisitor::VisitExpr_(op); + } + }; + + GetMbarrierMaxIdxCollector max_idx_collector; + max_idx_collector(f->body); + int ensure_min_count = max_idx_collector.max_idx + 1; // 0-based -> count + + // For simple TMA-only producers, default barrier arrive count should be 1 + // (only the elected leader performs the TMA arrive/expect). 
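+    // Illustrative effect, assuming two barriers with barrier 0 listed in
+    // restore_barrier_ids_ and a producer extent of 128 threads:
+    //   create_list_of_mbarrier(1, 1) -> create_list_of_mbarrier(128, 1)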
+ auto barrier_creation_rewriter = BarrierCreationRewriter( + rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_, + ensure_min_count, Integer(1)); + f.CopyOnWrite()->body = barrier_creation_rewriter(f->body); + return f; + } + +private: + Stmt VisitStmt_(const BlockNode *op) { + auto block = tvm::ffi::GetRef(op); + if (!has_create_list_of_mbarrier_ && !barrier_id_to_range_.empty() && + op->name_hint == MainBlockName) { + ICHECK(false) << "Please declare create_list_of_mbarrier."; + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + Stmt VisitStmt_(const IfThenElseNode *op) { + if (first_if) { + if (op->condition.as()) { + producer_thread_extent_ = + thread_var_->dom->extent - op->condition.as()->b; + } + TmaSequenceCollector collector(tma_op_to_barrier_id_); + collector(op->then_case); + clear_expect_list_ = collector.GetSequence(); + restore_barrier_ids_ = collector.GetRestoreBarrierIds(); + first_if = false; + + is_producer_ = true; + + auto then_case = StmtExprMutator::VisitStmt(op->then_case); + + is_producer_ = false; + Stmt else_case; + if (op->else_case.defined()) + else_case = StmtExprMutator::VisitStmt(op->else_case.value()); + return IfThenElse(op->condition, then_case, else_case); + } + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == "kWarpSpecializationScope") { + has_warp_specialization_ = true; + first_if = true; + } else if (op->attr_key == tir::attr::thread_extent && + Downcast(op->node)->thread_tag == "threadIdx.x") { + thread_var_ = Downcast(op->node); + } + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const CallNode *op) { + if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { + auto call_ref = tvm::ffi::GetRef(op); + if (!tma_op_to_barrier_id_.count(call_ref)) { + // For 1D TMA loads, promote raw integer barrier id to get_mbarrier(id) + // so codegen can emit mbarrier[index]. This handles degenerate + // producer-only kernels where no arrive() is seen and mapping is empty. 
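+        // e.g. a literal barrier slot i in args[2] becomes get_mbarrier(i);
+        // descriptor-based (multi-dim) loads are left untouched here.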
+ auto arg0 = op->args[0].as(); + bool is_1d_tma_load = + arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && + !arg0.value()->op.same_as(create_tma_im2col_descriptor()); + if (is_1d_tma_load && op->args.size() >= 3) { + if (const auto *imm = op->args[2].as()) { + Array new_args = op->args; + new_args.Set(2, Call(DataType::Handle(), get_mbarrier(), + {IntImm(DataType::Int(32), + static_cast(imm->value))})); + return Call(op->dtype, op->op, new_args); + } + } + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + auto barrier_id = tma_op_to_barrier_id_[call_ref]; + auto new_args = op->args; + auto arg0 = op->args[0].as(); + auto is_1d_tma_load = + arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && + !arg0.value()->op.same_as(create_tma_im2col_descriptor()); + if (is_1d_tma_load) { + new_args.Set(2, barrier_id); + } else { + new_args.Set(1, barrier_id); + } + return Call(op->dtype, op->op, new_args); + } else if (op->op.same_as(mbarrier_expect_tx())) { + auto call_ref = tvm::ffi::GetRef(op); + if (!tma_op_to_barrier_id_.count(call_ref)) { + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + auto barrier_id = tma_op_to_barrier_id_[call_ref]; + auto new_args = op->args; + new_args.Set(0, barrier_id); + if (!has_warp_specialization_) + clear_arrive_ = false; + else + clear_arrive_ = clear_expect_list_[cur_expect_idx_++]; + if (clear_arrive_) { + return Call(op->dtype, builtin::ptx_arrive_barrier_expect_tx(), + new_args); + } + return Call(op->dtype, op->op, new_args); + } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { + if (clear_arrive_) { + clear_arrive_ = false; + return 0; + } + // by default, all threads must wait. + auto new_args = op->args; + return Call(op->dtype, op->op, new_args); + } + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + Map tma_op_to_barrier_id_; + Map barrier_id_to_range_; + bool has_create_list_of_mbarrier_; + bool clear_arrive_{false}; + bool first_if{false}, has_warp_specialization_{false}, is_producer_{false}; + IterVar thread_var_; + int tma_expect_tx_{0}, cur_expect_idx_{0}; + std::vector clear_expect_list_; + std::vector restore_barrier_ids_; + PrimExpr producer_thread_extent_; +}; + +tvm::transform::Pass InjectTmaBarrier() { + auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { + // Check if function only uses threadIdx.x before proceeding + if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) { + LOG(WARNING) << "InjectTmaBarrier will be disabled because the program " + "uses thread tags other than threadIdx.x\n" + << "If you want to use TMA barrier, please refactor " + "your program to use threadIdx.x only"; + // Return original function unchanged if other thread tags are found + return f; + } + arith::Analyzer analyzer; + return TmaBarrierRewriter::Rewrite(f, &analyzer); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.transform.InjectTmaBarrier", {}); +} +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_hopper_intrin.cc b/src/tir/transforms/lower_hopper_intrin.cc new file mode 100644 index 000000000000..2392f1cbe00a --- /dev/null +++ b/src/tir/transforms/lower_hopper_intrin.cc @@ -0,0 +1,171 @@ +/*! 
+ * \file lower hopper intrin.cc + * \brief Lower Hopper intrinsics cuda GPU(sm90+) + */ + +#include +#include +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "../runtime/runtime.h" + +namespace tvm { +namespace tir { +namespace transform { + +#if (CUDA_MAJOR_VERSION >= 12) +class LowerHopperIntrin : public StmtExprMutator { +public: + static PrimFunc Substitute(PrimFunc &f, bool disable_shuffle_elect) { + PrimFuncNode *fptr = f.CopyOnWrite(); + LowerHopperIntrin substituter(disable_shuffle_elect); + fptr->body = substituter.VisitStmt(f->body); + Map> init_desc_arg_map; + for (const auto &[call, var] : substituter.desc_map_) { + // Should allocate 128 bytes for TensorMap on stack + Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(), + {StringImm("arg_value"), 16}); + Array init_desc_args; + if (call->op.same_as(create_tma_descriptor())) { + init_desc_args.push_back(StringImm(tvm_tensormap_create_tiled)); + } else if (call->op.same_as(create_tma_im2col_descriptor())) { + init_desc_args.push_back(StringImm(tvm_tensormap_create_im2col)); + } else { + CHECK(0) << call->op; + } + init_desc_args.push_back(var); + init_desc_args.insert(init_desc_args.end(), call->args.begin(), + call->args.end()); + // add to function attribute + Call init_desc = + Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args); + fptr->body = + LetStmt(var, alloc_desc, SeqStmt({Evaluate(init_desc), fptr->body})); + init_desc_arg_map.Set(var, init_desc_args); + } + f = WithAttr(std::move(f), "tma_descriptor_args", init_desc_arg_map); + return f; + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + // Insert the prefetch TMA descriptor statement TO the beginning of the + // kernel + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + auto body = StmtExprMutator::VisitStmt(op->body); + if (prefetch_calls_.empty() && init_mbarrier_calls_.empty()) { + return AttrStmt(op->node, op->attr_key, op->value, body); + } else { + Array stmt_seq; + if (!init_mbarrier_calls_.empty()) { + auto alloc_mbarrier = + Evaluate(Call(DataType::Handle(), builtin::create_barriers(), + {static_cast(init_mbarrier_calls_.size())})); + stmt_seq.push_back(alloc_mbarrier); + } + + auto stmts = prefetch_calls_; + stmts.insert(stmts.end(), init_mbarrier_calls_.begin(), + init_mbarrier_calls_.end()); + PrimExpr condition; + if (!disable_shuffle_elect_) { + condition = Call(DataType::Bool(), tl_shuffle_elect(), {0}); + } else { + condition = EQ(iv->var, 0); + } + auto stmt_ = IfThenElse(condition, + stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]); + stmt_seq.push_back(stmt_); + if (!init_mbarrier_calls_.empty()) { + // Note from FlashAttention: + // Helps with visibility of barrier init operations across warps / + // cta / cluster Available as a separate function so as to batch + // inits across barriers and fence once Note : It must be composed + // with an appropriate sync instruction with the right scope to + // ensure visibility eg. 
__syncthreads() or a cluster_arrive() + + // cluster_wait() + Stmt mem_fence = Evaluate(Call( + DataType::Handle(), tvm::tl::ptx_fence_barrier_init(), {})); + stmt_seq.push_back(mem_fence); + Stmt mem_sync = + Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(), + {StringImm("shared")})); + stmt_seq.push_back(mem_sync); + } + stmt_seq.push_back(body); + + prefetch_calls_.clear(); + init_mbarrier_calls_.clear(); + return AttrStmt(op->node, op->attr_key, op->value, SeqStmt(stmt_seq)); + } + } + } + return StmtExprMutator::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const CallNode *call) final { + if (call->op.same_as(create_tma_descriptor()) || + call->op.same_as(create_tma_im2col_descriptor())) { + Var var; + auto iter = desc_map_.find(tvm::ffi::GetRef(call)); + if (iter != desc_map_.end()) { + var = iter->second; + } else { + String name = call->args[2].as().value()->name_hint; + var = Var(name + "_desc", + PointerType(PrimType(cuTensorMapType()), "grid_constant")); + desc_map_[tvm::ffi::GetRef(call)] = var; + prefetch_calls_.push_back( + Evaluate(Call(DataType::Handle(), builtin::call_extern(), + {StringImm("tl::prefetch_tma_descriptor"), var}))); + } + return var; + } else if (call->op.same_as(create_list_of_mbarrier())) { + ICHECK(init_mbarrier_calls_.empty()); + int num_barriers = static_cast(call->args.size()); + for (int i = 0; i < num_barriers; i++) { + PrimExpr mbarrier = Call(DataType::Handle(), get_mbarrier(), {i}); + init_mbarrier_calls_.push_back(Evaluate( + Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(), + {mbarrier, call->args[i]}))); + } + return 0; + } else { + return StmtExprMutator::VisitExpr_(call); + } + } + +private: + Array prefetch_calls_; + Array init_mbarrier_calls_; + std::unordered_map desc_map_; + LowerHopperIntrin(bool disable_shuffle_elect) + : disable_shuffle_elect_(disable_shuffle_elect) {} + bool disable_shuffle_elect_; +}; + +using namespace tir::transform; + +tvm::transform::Pass LowerHopperIntrin() { + auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) { + bool disable_shuffle_elect = + ctx->GetConfig(kDisableShuffleElect, Bool(false)).value(); + return LowerHopperIntrin::Substitute(f, disable_shuffle_elect); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.transform.LowerHopperIntrin", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.transform.LowerHopperIntrin", LowerHopperIntrin); +} +#endif // (CUDA_MAJOR_VERSION >= 12) + +} // namespace transform +} // namespace tir +} // namespace tvm \ No newline at end of file diff --git a/src/tir/transforms/lower_shared_tmem.cc b/src/tir/transforms/lower_shared_tmem.cc new file mode 100644 index 000000000000..3b1f876732a5 --- /dev/null +++ b/src/tir/transforms/lower_shared_tmem.cc @@ -0,0 +1,319 @@ +/*! 
+ * \file lower_shared_tmem.cc + * \brief Convert shared.tmem buffers to plain shared + ptx init, and do + * coordinate translation (from logical address to physical address) + */ +#include "../op/builtin.h" +#include "../target/utils.h" +#include "tvm/ir/type.h" +#include "tvm/tir/builtin.h" +#include "tvm/tir/expr.h" +#include "tvm/tir/stmt.h" +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +class SharedTmemRewriter : public StmtExprMutator { +public: + static Stmt Rewrite(Stmt body) { + SharedTmemRewriter rewriter; + return rewriter(body); + } + +private: + Stmt VisitStmt_(const BlockNode *op) final { + Block block = tvm::ffi::GetRef(op); + Array alloc_buffers = op->alloc_buffers; + if (op->annotations.count(attr::kLayoutMap)) { + auto layout_map = op->annotations.Get(attr::kLayoutMap); + ICHECK(layout_map) << "layout map is not defined"; + layout_map_ = layout_map->as>().value(); + } + + // Record the mapping from buffer data var to buffer for later lookup + for (auto buffer : alloc_buffers) { + buffer_map_.insert({buffer->data, buffer}); + } + for (auto match_buffer : op->match_buffers) { + buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer}); + } + + Array tmem_buffers; + + for (const auto &[data, buffer] : buffer_map_) { + const auto *ptr_type = + buffer->data->type_annotation.as(); + auto storage_scope = ptr_type->storage_scope; + ICHECK(ptr_type) << "Buffer Var's type annotation must be of PointerType"; + if (storage_scope == "shared.tmem") { + tmem_buffers.push_back(buffer); + } + } + + if (tmem_buffers.empty()) { + return StmtExprMutator::VisitStmt_(op); + } + + ICHECK(thread_var_.defined()) << "thread_var_ is not defined"; + + for (auto buffer : tmem_buffers) { + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + + /* + Transform the tmem buffers to new allocations + transform: + tmem_buf0 = T.alloc_buffer((128, 128,), "uint64", + scope="shared.tmem") + tmem_buf1 = T.alloc_buffer((128, 128,), "uint64", + scope="shared.tmem") + + into: + tmem_buf0 = T.alloc_buffer((1,), "uint64", scope="shared.tmem_addr") + tmem_buf1 = T.alloc_buffer((1,), "uint64", scope="shared.tmem_addr") + + if tx == 0: + T.ptx_init_tensor_memory(tmem_buf0[0], 128) + T.ptx_init_tensor_memory(tmem_buf1[0], 128) + */ + // 1. create new data vars + Array new_data_vars; + for (auto buffer : tmem_buffers) { + auto data = buffer->data; + if (var_remap_.count(data)) + continue; + auto new_data = + Var(data->name_hint, PointerType(PrimType(tmem_dtype_), "shared")); + var_remap_.Set(data, new_data); + new_data_vars.push_back(new_data); + } + + // 2. create new buffers + Array new_buffers; + for (auto buffer : tmem_buffers) { + auto data = buffer->data; + ICHECK(var_remap_.find(data) != var_remap_.end()) + << "data not found in var_remap_"; + auto new_data = var_remap_.at(data); + auto new_buffer = Buffer(new_data, tmem_dtype_, Array({1}), + Array({1}), PrimExpr(0), buffer->name, + buffer->data_alignment, buffer->offset_factor, + buffer->buffer_type); + new_buffers.push_back(new_buffer); + buffer_remap_.Set(buffer, new_buffer); + buffer_data_to_buffer_.Set(new_data, new_buffer); + } + + // remove the tmem buffers + alloc_buffers.MutateByApply([this](Buffer buf) { + if (buffer_remap_.find(buf) != buffer_remap_.end()) { + return buffer_remap_.at(buf); + } + return buf; + }); + if (!alloc_buffers.same_as(op->alloc_buffers)) { + block.CopyOnWrite()->alloc_buffers = alloc_buffers; + } else { + return StmtExprMutator::VisitStmt_(op); + } + + // 3. 
create init & dealloc calls for new buffers + std::vector init_mtmem_calls_; + std::vector dealloc_tmem_calls_; + for (auto buffer : tmem_buffers) { + auto data = buffer->data; + auto old_buffer = buffer_data_to_buffer_.at(data); + auto new_buffer = buffer_remap_.at(old_buffer); + + // Tmem physical coord range analysis + ICHECK(old_buffer->shape.size() == 2); + + auto analyzer = std::make_shared(); + arith::ConstIntBound phy_col_bounds = + analyzer->const_int_bound(old_buffer->shape[1]); + int num_cols_required = phy_col_bounds->max_value; + ICHECK(num_cols_required <= 512) + << "The number of columns required for tmem buffer " + << old_buffer->name << " is " << num_cols_required + << ", which exceeds the maximum of 512 columns"; + + int num_cols_allocated = 32; // Align num_cols_allocated to power of 2 + for (; num_cols_allocated < num_cols_required; num_cols_allocated *= 2) + ; + + auto new_buffer_access = new_buffer.access_ptr(1, DataType::Handle(), 1, + PrimExpr(0), PrimExpr(1)); + auto alloc_call = Call(DataType::Handle(), tl::ptx_init_tensor_memory(), + {new_buffer_access, PrimExpr(num_cols_allocated)}); + init_mtmem_calls_.push_back(Evaluate(alloc_call)); + auto dealloc_call = + Call(DataType::Handle(), tl::ptx_deallocate_tensor_memory(), + {new_buffer_access, PrimExpr(num_cols_allocated)}); + dealloc_tmem_calls_.push_back(Evaluate(dealloc_call)); + } + auto compare_by_buffer_name = [&](const Stmt &a, const Stmt &b) { + auto call_a = a.as()->value.as(); + auto call_b = b.as()->value.as(); + auto num_cols_a = call_a->args[1].as()->value; + auto num_cols_b = call_b->args[1].as()->value; + return num_cols_a > num_cols_b; + }; + std::sort(init_mtmem_calls_.begin(), init_mtmem_calls_.end(), + compare_by_buffer_name); + + Array new_body; + auto target = Target::Current(); + auto warp_size = TargetGetWarpSize(target); + auto thread_var_div_warp_size = + FloorDiv(thread_var_->var, IntImm(thread_var_->var->dtype, warp_size)); + new_body.push_back(IfThenElse(EQ(thread_var_div_warp_size, 0), + init_mtmem_calls_.size() > 1 + ? SeqStmt(init_mtmem_calls_) + : init_mtmem_calls_.back(), + Stmt())); + new_body.push_back( + Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(), + {StringImm("shared")}))); + new_body.push_back(block->body); + new_body.push_back(IfThenElse(EQ(thread_var_div_warp_size, 0), + dealloc_tmem_calls_.size() > 1 + ? 
SeqStmt(dealloc_tmem_calls_) + : dealloc_tmem_calls_.back(), + Stmt())); + + auto block_ptr = block.CopyOnWrite(); + block_ptr->annotations.erase(attr::kLayoutMap); + block_ptr->body = SeqStmt(new_body); + + return StmtExprMutator::VisitStmt_(block.get()); + } + + PrimExpr GetTmemOffset(const Buffer &buffer, const Array &indices) { + ICHECK(buffer->shape.size() == 2); + ICHECK(indices.size() == 2); + ICHECK(layout_map_.defined()); + ICHECK(layout_map_.count(buffer)) + << "The layout of tmem buffer " << buffer->name + << " is not defined in the layout map"; + auto layout = layout_map_[buffer]; + ICHECK(layout.defined()); + Array tmem_phy_coords = layout->Forward(indices); + PrimExpr result = + tmem_phy_coords[0] << 16 | + tmem_phy_coords + [1]; // https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-memory-addressing + return result; + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + // Translate tmem[logical_row, logical_col] to tmem[0] + tmem_offset + // Where + // - (logical_row, logical_col) is the logical address in the tmem buffer + // - tmem[0] is the base address allocated for the tmem buffer + // - tmem_offset = tmem_phy_coords[0]<<16 | tmem_phy_coords[1] + // where tmem_phy_coords = layout.Forward(logical_row, logical_col) + // is the physical address in the tmem buffer + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto buffer = load->buffer; + auto indices = load->indices; + + if (buffer_remap_.count(buffer)) { + auto new_buffer = buffer_remap_[load->buffer]; + return BufferLoad(new_buffer, {0}) + GetTmemOffset(buffer, indices); + } else if (var_remap_.count(buffer->data)) { + auto new_buffer = Buffer( + var_remap_[buffer->data], tmem_dtype_, buffer->shape, buffer->strides, + buffer->elem_offset, buffer->name, buffer->data_alignment, + buffer->offset_factor, buffer->buffer_type); + return BufferLoad(new_buffer, {0}) + GetTmemOffset(buffer, indices); + } + return load; + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto buffer = store->buffer; + ICHECK(buffer.scope() != "shared.tmem") + << "We should never directly store data into tmem!"; + return store; + } + + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + ICHECK_EQ(op->args.size(), 5U); + Var buffer_data = Downcast(op->args[1]); + if (!var_remap_.count(buffer_data)) { + return StmtExprMutator::VisitExpr_(op); + } + Var new_data = var_remap_[buffer_data]; + return Call( + op->dtype, op->op, + {op->args[0], new_data, op->args[2], op->args[3], op->args[4]}); + } + auto expr = StmtExprMutator::VisitExpr_(op); + return expr; + } + PrimExpr VisitExpr_(const VarNode *op) final { + Var var = tvm::ffi::GetRef(op); + if (var_remap_.count(var)) { + return var_remap_[var]; + } + return var; + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + ICHECK(iv->dom->extent.as()); + thread_var_ = iv; + } + } + return StmtExprMutator::VisitStmt_(op); + } + + // Datatypes for tmem + const DataType tmem_dtype_ = DataType::UInt(32); + // This is a workaround for cpu backend, + // we need to define a thread_var for the serial loop. 
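+  // thread_var_ is captured from the "threadIdx.x" thread_extent attribute
+  // in VisitStmt_(const AttrStmtNode *) above.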
+ IterVar thread_var_; + Map var_remap_; + Map buffer_data_to_buffer_; + Map buffer_remap_; + // Mapping from data Var of a Buffer to Buffer, for lookup + std::unordered_map buffer_map_; + Map layout_map_; +}; + +PrimFunc LowerSharedTmem(PrimFunc f) { + auto target = f->GetAttr(tvm::attr::kTarget); + ICHECK(target.defined()) << "LowerSharedTmem: Require the target attribute"; + SharedTmemRewriter rewriter; + f.CopyOnWrite()->body = rewriter.Rewrite(f->body); + return f; +} + +namespace transform { +using namespace tir::transform; + +tvm::transform::Pass LowerSharedTmem() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return tl::LowerSharedTmem(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.transform.LowerSharedTmem", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.transform.LowerSharedTmem", LowerSharedTmem); +} + +} // namespace transform +} // namespace tir +} // namespace tvm \ No newline at end of file diff --git a/src/tir/transforms/warp_specialized_rewriter.cc b/src/tir/transforms/warp_specialized_rewriter.cc new file mode 100644 index 000000000000..5d05f4d72082 --- /dev/null +++ b/src/tir/transforms/warp_specialized_rewriter.cc @@ -0,0 +1,1302 @@ +/*! + * \file warp_specialized_rewriter.cc + * \brief Warp specialized Pipeline for cuda GPU (sm90+) +*/ + +#include "warp_specialized_rewriter.h" + +namespace tvm { +namespace tir { +namespace transform { +using namespace runtime; +using arith::IRVisitorWithAnalyzer; + +struct LoopInfo { + Var loop_var; + PrimExpr extent; + PrimExpr min; +}; + +enum class Role : uint8_t { kConsumer, kProducer, kBoth }; + +class ProducerBufferDetector : public StmtExprVisitor { +public: + ProducerBufferDetector( + std::unordered_set cur_producer_buffers) + : cur_producer_buffers_(std::move(cur_producer_buffers)) {} + + void clear() { has_producer_buffer_ = false; } + + void VisitExpr_(const CallNode *call) final { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + has_producer_buffer_ = true; + } + StmtExprVisitor::VisitExpr_(call); + } + + void VisitExpr_(const BufferLoadNode *op) final { + if (cur_producer_buffers_.count(op->buffer.get())) { + has_producer_buffer_ = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + bool has_producer_buffer_ = false; + std::unordered_set cur_producer_buffers_; +}; + +class ProducerUsedBufferFinder : public StmtExprVisitor { +public: + auto FindProducerusedBuffer(const Stmt &stmt) { + producer_buffers_.clear(); + std::unordered_set last_producer_buffers_; + for (;;) { + VisitStmt(stmt); + if (producer_buffers_ == last_producer_buffers_) { + break; + } + last_producer_buffers_ = producer_buffers_; + } + return producer_buffers_; + } + + void InsertBuffer(const PrimExpr &expr) { + // Find the buffer that is used in the condition + VarUseDefAnalyzer usage(Array{}); + usage(expr); + for (const auto &buffer : usage.buffer_use_count_) { + producer_buffers_.insert(buffer.first); + } + } + + void VisitStmt_(const IfThenElseNode *op) final { + ProducerBufferDetector producer_buffer_detector(producer_buffers_); + producer_buffer_detector(op->then_case); + if (op->else_case.defined()) { + producer_buffer_detector(op->else_case.value()); + } + if (producer_buffer_detector.has_producer_buffer_) { + InsertBuffer(op->condition); + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const ForNode *op) final { + ProducerBufferDetector producer_buffer_detector(producer_buffers_); + 
producer_buffer_detector(op->body); + if (producer_buffer_detector.has_producer_buffer_) { + InsertBuffer(op->min); + InsertBuffer(op->extent); + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const BufferStoreNode *op) final { + if (producer_buffers_.count(op->buffer.get())) { + InsertBuffer(op->value); + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { + for (auto arg : op->args) { + if (auto buffer_load = arg.as()) { + producer_buffers_.insert(buffer_load->buffer.get()); + } + } + } + } + +private: + std::unordered_set producer_buffers_; +}; + +class WarpSpecializedRoleMarker : public StmtVisitor { +public: + WarpSpecializedRoleMarker(Map buffer_data_to_buffer) + : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {} + + void Prepare(const Stmt &stmt) { + ProducerUsedBufferFinder finder; + producer_buffers_ = finder.FindProducerusedBuffer(stmt); + } + + Role GetRole(const StmtNode *stmt) const { + auto it = map_.find(stmt); + ICHECK(it != map_.end()); + return it->second; + } + + Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); } + + void VisitStmt_(const EvaluateNode *op) final { + Role role = Role::kConsumer; + if (auto call = op->value.as()) { + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + role = Role::kProducer; + has_bulk_copy_ = true; + } + if (call->op.same_as(loop_break())) { + role = Role::kBoth; + } + } + SetRole(op, role); + } + + void VisitStmt_(const BufferStoreNode *op) final { + auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer->data)); + bool is_shared_store = scope.rank == StorageRank::kShared; + if (producer_buffers_.count(op->buffer.get())) { + SetRole(op, Role::kBoth); + return; + } + if (!is_shared_store) { + SetRole(op, Role::kConsumer); + return; + } + + // Check reads from global + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", + /*body*/ tvm::ffi::GetRef(op)); + auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + auto reads = access[0]; + Role role = Role::kProducer; + if (reads.empty()) + role = Role::kConsumer; + for (auto read : reads) { + if (read->buffer.scope() != "global") { + role = Role::kConsumer; + break; + } + } + if (role == Role::kProducer) + has_simt_copy_ = true; + SetRole(op, role); + } + + void VisitStmt_(const SeqStmtNode *op) final { + StmtVisitor::VisitStmt_(op); + auto role = GetRole(op->seq[0]); + for (auto stmt : op->seq) { + if (role != GetRole(stmt)) { + role = Role::kBoth; + break; + } + } + SetRole(op, role); + } + + void VisitStmt_(const IfThenElseNode *op) final { + StmtVisitor::VisitStmt_(op); + auto role = GetRole(op->then_case); + if (op->else_case.defined()) { + auto role_else = GetRole(op->else_case.value()); + if (role != role_else) + role = Role::kBoth; + } + SetRole(op, role); + } + + void VisitStmt_(const BlockRealizeNode *op) final { + StmtVisitor::VisitStmt_(op); + SetRole(op, GetRole(op->block)); + } + + void VisitStmt_(const AllocateNode *op) final { + StmtVisitor::VisitStmt_(op); + Role role = Role::kConsumer; + SetRole(op, role); + } + + template void HandleBodyStmt(const NodeType *op) { + StmtVisitor::VisitStmt_(op); + SetRole(op, GetRole(op->body)); + } + + void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); } + void VisitStmt_(const WhileNode *op) final { HandleBodyStmt(op); } + void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); } + void 
VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); } + void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); } + void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); } + + bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; } + + bool HasSimtCopy() { return has_simt_copy_; } + +private: + void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; } + Map buffer_data_to_buffer_; + std::unordered_map map_; + bool has_simt_copy_ = false; + bool has_bulk_copy_ = false; + std::unordered_set producer_buffers_; +}; + +static PrimExpr makeGetBarrier(PrimExpr barrier_id) { + return Call(DataType::Handle(), get_mbarrier(), {std::move(barrier_id)}); +} + +static Stmt makeArriveBarrier(PrimExpr barrier_id, int cta_id = -1, + const PrimExpr &pred = 1) { + Array args = {makeGetBarrier(std::move(barrier_id))}; + if (cta_id != -1) { + args.push_back(cta_id); + args.push_back(pred); + } + return Evaluate( + Call(DataType::Handle(), builtin::ptx_arrive_barrier(), args)); +} + +static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) { + auto call = Call(DataType::Handle(), builtin::ptx_cp_async_barrier(), + {makeGetBarrier(std::move(barrier_id))}); + return Evaluate(call); +} + +static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) { + auto call = Call(DataType::Handle(), mbarrier_wait_parity(), + {makeGetBarrier(std::move(barrier_id)), std::move(parity)}); + return Evaluate(call); +} + +class ProducerTraitsCollector : public StmtExprVisitor { +public: + ProducerTraitsCollector() { Clear(); } + + void Clear() { has_simt_copy = false; } + + void Collect(const Stmt &stmt) { VisitStmt(stmt); } + + bool HasSimtCopy() { return has_simt_copy; } + +private: + void VisitStmt_(const IfThenElseNode *op) final { + bool old_in_if_cond = in_if_cond_; + in_if_cond_ = true; + VisitExpr(op->condition); + in_if_cond_ = old_in_if_cond; + + VisitStmt(op->then_case); + if (op->else_case.defined()) { + VisitStmt(op->else_case.value()); + } + } + + void VisitExpr_(const BufferLoadNode *op) final { + if (!in_if_cond_) { + has_simt_copy = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + bool has_simt_copy{}; + bool in_if_cond_ = false; +}; + +// Rewrite the producer Stmt to use the correct barrier index +class MbarrierRewriter : public StmtExprMutator { +public: + static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) { + MbarrierRewriter rewriter; + rewriter.producer_barrier_idx_ = std::move(barrier_id); + return rewriter(std::move(stmt)); + } + +private: + PrimExpr VisitExpr_(const CallNode *op) final { + auto call = Downcast(StmtExprMutator::VisitExpr_(op)); + if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { + auto mbar = makeGetBarrier(producer_barrier_idx_); + auto arg0 = call->args[0].as(); + // Check if this is a 1D TMA load + auto is_1d_tma_load = + arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && + call->op.same_as(tma_load()); + if (is_1d_tma_load) { + call.CopyOnWrite()->args.Set(2, mbar); + } else { + Call access_ptr = Downcast(call->args[2]); + ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); + call.CopyOnWrite()->args.Set(1, mbar); + } + } + return call; + } + PrimExpr producer_barrier_idx_; +}; + +class ThreadIdxRewriter : public StmtExprMutator { +public: + static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced, + PrimExpr thread_extent, bool do_shuffle = false) { + auto rewriter = + ThreadIdxRewriter(std::move(thread_var), std::move(replaced), + std::move(thread_extent), do_shuffle); + return 
rewriter(std::move(stmt)); + } + +private: + ThreadIdxRewriter(Var thread_var, PrimExpr replaced, PrimExpr thread_extent, + bool do_shuffle) + : thread_var_(std::move(thread_var)), replaced_(std::move(replaced)), + thread_extent_(std::move(thread_extent)), do_shuffle_(do_shuffle) {} + + PrimExpr VisitExpr_(const VarNode *var) final { + if (var == thread_var_.get()) { + return replaced_; + } else { + return StmtExprMutator::VisitExpr_(var); + } + } + + Stmt VisitStmt_(const IfThenElseNode *op) final { + auto f_uses_thread_index = [=](const tvm::tir::VarNode *parameter) { + return parameter == thread_var_.get(); + }; + maybe_thread_opt_ = false; + if (!op->else_case.defined() && op->condition.as() && + UsesVar(op->condition, f_uses_thread_index) && + !(UsesVar(op->then_case, f_uses_thread_index))) { + auto eq_op = Downcast(op->condition); + if (eq_op->a.as() == thread_var_.get() || + eq_op->b.as() == thread_var_.get()) { + maybe_thread_opt_ = true; + } + auto then_case = StmtExprMutator::VisitStmt(op->then_case); + maybe_thread_opt_ = do_shuffle_ && maybe_thread_opt_ && has_tma_op_; + has_tma_op_ = false; + if (maybe_thread_opt_) { + return IfThenElse( + Call(DataType::Bool(), tl_shuffle_elect(), {thread_extent_}), + StmtExprMutator::VisitStmt(op->then_case), std::nullopt); + } + } + return StmtExprMutator::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(tl::tma_load()) || + op->op.same_as(tl::tma_load_im2col()) || + op->op.same_as(tl::tma_store())) { + has_tma_op_ = true; + } + return StmtExprMutator::VisitExpr_(op); + } + + Var thread_var_; + PrimExpr replaced_; + PrimExpr thread_extent_; + bool maybe_thread_opt_ = false; + bool do_shuffle_; + bool has_tma_op_ = false; +}; + +Block MakeGroupBlock(const Stmt &stmt, + const Map &annotations) { + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", + /*body*/ stmt, + /*init=*/{}, /*alloc_buffers=*/{}, /*match_buffers=*/{}, + /*annotations=*/annotations); + return block; +} + +struct OpInfo { + int group_size{}, order{}, stage{}; + std::vector group; +}; +struct PipelineInfo { + std::vector op_infos; + + PipelineInfo() = default; + PipelineInfo(const Array> &group_info, + const Array &order_info, + const Array &stage_info) { + int n = static_cast(group_info.size()); + ICHECK(n == static_cast(order_info.size())); + ICHECK(n == static_cast(stage_info.size())); + // int cur_id = 0; + for (int i = 0; i < n; i++) { + OpInfo op_info; + op_info.group_size = group_info[i].size(); + for (int j = 0; j < op_info.group_size; j++) { + op_info.group.push_back(group_info[i][j].as()->value); + } + op_info.order = order_info[i].as()->value; + op_info.stage = stage_info[i].as()->value; + op_infos.push_back(op_info); + } + } + + PipelineInfo(const PipelineInfo &other) { + for (const auto &op_info : other.op_infos) { + op_infos.push_back(op_info); + } + } + + std::pair FindStmt(int stmt_idx) { + for (size_t i = 0; i < op_infos.size(); i++) { + for (size_t j = 0; j < op_infos[i].group.size(); j++) { + if (op_infos[i].group[j] == stmt_idx) { + return std::make_pair(i, j); + } + } + } + return std::make_pair(-1, -1); + } + + void UpdateOrder(int order) { + for (int i = 0; i < static_cast(op_infos.size()); i++) { + if (op_infos[i].order >= order && op_infos[i].order > 0) { + op_infos[i].order++; + } + } + } + + int SplitOp(int stmt_idx) { + auto pair = FindStmt(stmt_idx); + int op_idx = pair.first; + int inner_idx = pair.second; + ICHECK(op_idx != -1); + ICHECK(inner_idx != -1); + OpInfo half0; + 
OpInfo half1; + // The order to do sync + int sync_order = op_infos[op_idx].order + 1; + UpdateOrder(sync_order); + + half0.group_size = inner_idx + 1; + half0.order = op_infos[op_idx].order; + half0.stage = op_infos[op_idx].stage; + for (int i = 0; i <= inner_idx; i++) { + half0.group.push_back(op_infos[op_idx].group[i]); + } + half1.group_size = op_infos[op_idx].group_size - inner_idx - 1; + half1.order = op_infos[op_idx].order + 2; + half1.stage = op_infos[op_idx].stage; + for (int i = inner_idx + 1; i < op_infos[op_idx].group_size; i++) { + half1.group.push_back(op_infos[op_idx].group[i]); + } + op_infos.erase(op_infos.begin() + op_idx); + if (half0.group_size > 0) { + op_infos.insert(op_infos.begin() + op_idx, half0); + } + if (half1.group_size > 0) { + UpdateOrder(half1.order); + op_infos.insert(op_infos.begin() + op_idx + 1, half1); + } + return sync_order; + } + + void PrintPipelineInfo() { + std::cout << "Print op_infos:" << '\n'; + for (size_t i = 0; i < op_infos.size(); i++) { + std::cout << i << " " << op_infos[i].group_size << " " + << op_infos[i].order << " " << op_infos[i].stage << '\n'; + } + std::cout << "End of print" << '\n'; + } +}; + +class GroupOpRewriter : public StmtExprMutator { +public: + GroupOpRewriter(const PipelineInfo &pipeline_info) + : pipeline_info_(pipeline_info) {} + +private: + Stmt VisitStmt_(const ForNode *op) final { + Map annotations; + annotations.Set(String("stmt_group"), Integer(1)); + auto original_node = (op->body).as(); + if (!original_node) { + return tvm::ffi::GetRef(op); + } + Array new_body; + int cur_id = 0; + for (int i = 0; i < static_cast(pipeline_info_.op_infos.size()); i++) { + if (pipeline_info_.op_infos[i].group_size == 0) + continue; + Array block_stmt; + for (int j = 0; + j < static_cast(pipeline_info_.op_infos[i].group_size); j++) { + // ICHECK(group_info_[i][j].as()); + // int index = + // static_cast(group_info_[i][j].as()->value); + ICHECK(original_node->seq[cur_id].as()); + auto block = original_node->seq[cur_id].as(); + // TODO: handle nested seqstmt + block_stmt.push_back(block->body); + cur_id++; + } + new_body.push_back(MakeGroupBlock( + block_stmt.size() == 1 ? block_stmt[0] + // NOLINTNEXTLINE(performance-move-const-arg) + : SeqStmt(std::move(block_stmt)), + annotations)); + } + Array order_anno; + Array stage_anno; + for (const auto &op_info : pipeline_info_.op_infos) { + order_anno.push_back(Integer(op_info.order)); + stage_anno.push_back(Integer(op_info.stage)); + } + Map for_annotations = op->annotations; + for_annotations.erase("tl_pipeline_group"); + for_annotations.Set("software_pipeline_order", order_anno); + for_annotations.Set("software_pipeline_stage", stage_anno); + For new_for = + For(op->loop_var, op->min, op->extent, op->kind, + new_body.size() == 1 ? 
new_body[0] : SeqStmt(std::move(new_body)), + op->thread_binding, for_annotations); + return new_for; + } + + PipelineInfo pipeline_info_; +}; + +class WgMMACollector : public StmtExprVisitor { +public: + WgMMACollector() = default; + + void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(tl_gemm()) || op->op.same_as(tl_gemm_sp())) { + auto op_name = std::string(op->args[0].as()->value); + if (has_wgmma_) { + has_wgmma_ = + op_name.find("false") == std::string::npos && !in_if_scope_; + } + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const IfThenElseNode *op) final { + in_if_scope_ = true; + StmtExprVisitor::VisitStmt(op->then_case); + if (op->else_case.defined()) { + StmtExprVisitor::VisitStmt(op->else_case.value()); + } + in_if_scope_ = false; + } + + static bool HasWgMMA(const Stmt &stmt) { + auto collector = WgMMACollector(); + collector(stmt); + return collector.has_wgmma_; + } + + bool has_wgmma_{true}; + bool in_if_scope_{false}; +}; + +class WSCodeEmitter : public StmtMutator { +public: + WSCodeEmitter(bool is_emitting_producer, const IterVar &thread_iv, + Map buffer_data_to_buffer, + const WarpSpecializedRoleMarker &marker, + bool mbarrier_only = false) + : is_emitting_producer_(is_emitting_producer), + buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), + marker_(marker), thread_var_(thread_iv->var), + mbarrier_only_(mbarrier_only) {} + + /** + * @brief Whether a SIMT-style bulk copy was detected. + * + * Returns true when a simulated SIMT (thread-parallel) copy pattern was + * observed during analysis/emission, which can affect barrier insertion and + * copy emission. + * + * @return true if a SIMT copy was detected; false otherwise. + */ + bool hasSimtCopy() const { return has_simt_copy_; } + +private: + template < + typename NodeType> /** + * @brief Filter a statement by its producer/consumer + * role for emission. + * + * Returns one of: + * - the original statement (unchanged) when this + * emitter should emit it, + * - the result of visiting the statement (to descend + * into it) when mbarrier-only mode requires full + * traversal for non-producer roles, + * - an empty evaluate (`Evaluate(0)`) when the + * statement should be omitted. + * + * The decision is based on the role of `op` as + * reported by `marker_`, the emitter mode + * (`is_emitting_producer_`), and the `mbarrier_only_` + * flag. + * + * @param op The statement node to filter; its role is + * queried via `marker_`. + * @return Stmt The statement to place into the emitted + * IR (possibly transformed or an empty evaluate). 
+ */ + Stmt FilterByRole(const NodeType *op) { + Role role = marker_.GetRole(op); + if (mbarrier_only_) { + if (role != Role::kProducer) + return StmtMutator::VisitStmt_(op); + } + if (role == Role::kBoth) { + return StmtMutator::VisitStmt_(op); + } else if ((role == Role::kProducer) == is_emitting_producer_) { + return tvm::ffi::GetRef(op); + } else { + return Evaluate(0); + } + } + + Stmt VisitStmt_(const SeqStmtNode *op) final { + + bool has_producer = false; + for (auto stmt : op->seq) { + if (marker_.GetRole(stmt) == Role::kProducer) { + has_producer = true; + break; + } + } + bool need_producer_sync = + has_producer && marker_.GetRole(op) == Role::kBoth; + if (!need_producer_sync) + return FilterByRole(op); + + auto seq_transformed = + op->seq.Map([&](const Stmt &stmt) { return VisitStmt(stmt); }); + + auto map = ExtractSyncPattern(op->seq); + + /* + std::cout << "Print ExtractSyncPattern" << std::endl; + for (int i = 0; i < static_cast(op->seq.size()); i++) { + std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " " + << map.release_after[i] << std::endl; + } + std::cout << "Print sync pattern" << std::endl; + for (auto pattern : map.patterns) { + std::cout << pattern.release_idx << " " << pattern.acquire_idx << + std::endl; + } + std::cout << "End of ExtractSyncPattern" << std::endl; + pipeline_info_.PrintPipelineInfo(); + */ + Array new_body; + Map annotations; + annotations.Set(String("stmt_group"), Integer(1)); + + if (is_emitting_producer_) { // producer case + ProducerTraitsCollector collector; + for (int i = 0; i < static_cast(op->seq.size()); i++) { + Array block_stmt = {}; + if (!mbarrier_only_) { + if (marker_.GetRole(op->seq[i]) == Role::kConsumer) + continue; + if (marker_.GetRole(op->seq[i]) == Role::kBoth) { + block_stmt.push_back(seq_transformed[i]); + new_body.push_back( + MakeGroupBlock(block_stmt.size() == 1 + ? block_stmt[0] + // NOLINTNEXTLINE(performance-move-const-arg) + : SeqStmt(std::move(block_stmt)), + annotations)); + continue; + } + } + + for (int pattern_idx : map.acquire[i]) { + PrimExpr acquire_barrier_id = + stage_ + num_barriers_ + num_stages_ * pattern_idx; + PrimExpr parity = map.is_loop_dependency(pattern_idx) + ? bitwise_xor(parity_, 1) + : parity_; + block_stmt.push_back(makeParityWait(acquire_barrier_id, parity)); + } + ICHECK(!map.release[i].empty()); + for (size_t j = 0; j < map.release[i].size(); j++) { + int pattern_idx = map.release[i][j]; + PrimExpr release_barrier_id = + stage_ + num_barriers_ + num_stages_ * pattern_idx; + auto stmt = + MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id); + collector.Collect(stmt); + block_stmt.push_back(stmt); + if (collector.HasSimtCopy()) { + block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id)); + has_simt_copy_ = true; + } + if (map.release_after[i][j]) { + block_stmt.push_back(makeArriveBarrier(release_barrier_id)); + for (int s = 0; s < num_stages_; s++) { + released_barrier_.insert(s + num_barriers_ + + num_stages_ * pattern_idx); + } + } + collector.Clear(); + new_body.push_back( + MakeGroupBlock(block_stmt.size() == 1 + ? 
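// Schematically, each producer statement S in the sequence is emitted as
//   mbarrier_wait_parity(acquire_barrier, parity)   // one wait per acquire
//   S, with its bulk copies retargeted to release_barrier (MbarrierRewriter)
//   cp.async barrier arrive, if a SIMT (per-thread) copy was found in S
//   mbarrier arrive on release_barrier, when release_after is set
// where barrier = stage_ + num_barriers_ + num_stages_ * pattern_idx, i.e.
// one physical mbarrier per (sync pattern, pipeline stage) pair.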
block_stmt[0] + // NOLINTNEXTLINE(performance-move-const-arg) + : SeqStmt(std::move(block_stmt)), + annotations)); + } + } + } else { // consumer case + for (int i = 0; i < static_cast(op->seq.size()); i++) { + Array block_stmt = {}; + if (marker_.GetRole(op->seq[i]) == Role::kProducer) + continue; + for (int pattern_idx : map.acquire[i]) { + PrimExpr acquire_barrier_id = + stage_ + num_barriers_ + num_stages_ * pattern_idx; + PrimExpr parity = map.is_loop_dependency(pattern_idx) + ? bitwise_xor(parity_, 1) + : parity_; + block_stmt.push_back(makeParityWait(acquire_barrier_id, parity)); + } + block_stmt.push_back(seq_transformed[i]); + for (size_t j = 0; j < map.release[i].size(); j++) { + if (map.release_after[i][j]) { + int pattern_idx = map.release[i][j]; + PrimExpr release_barrier_id = + stage_ + num_barriers_ + num_stages_ * pattern_idx; + block_stmt.push_back(makeArriveBarrier(release_barrier_id)); + for (int s = 0; s < num_stages_; s++) { + released_barrier_.insert(s + num_barriers_ + + num_stages_ * pattern_idx); + } + } + } + new_body.push_back(MakeGroupBlock( + block_stmt.size() == 1 ? block_stmt[0] + // NOLINTNEXTLINE(performance-move-const-arg) + : SeqStmt(std::move(block_stmt)), + annotations)); + } + // Filter out the producer stmts + int cur_id = 0; + PipelineInfo new_pipeline_info; + for (int i = 0; i < static_cast(pipeline_info_.op_infos.size()); + i++) { + auto op_info = pipeline_info_.op_infos[i]; + bool is_producer = false; + for (int j = 0; j < op_info.group_size; j++) { + if (marker_.GetRole(op->seq[cur_id]) == Role::kProducer) { + is_producer = true; + } + cur_id++; + } + if (is_producer) { + ICHECK(op_info.group_size == 1); + } else { + new_pipeline_info.op_infos.push_back(op_info); + } + } + pipeline_info_ = new_pipeline_info; + } + + num_barriers_ += map.patterns.size() * num_stages_; + + ICHECK(!new_body.empty()); + return new_body.size() == 1 ? 
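// The consumer branch mirrors the producer emission: wait on each acquire
// barrier with the current parity, run the already-transformed statement,
// then arrive on every release barrier whose release_after flag is set.
// Producer-only statements are dropped from this branch, and the grouped
// pipeline info is filtered so its entries still line up with the statements
// that remain. num_barriers_ then advances by patterns * stages so the next
// SeqStmt allocates a fresh barrier range.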
new_body[0] : SeqStmt(std::move(new_body)); + } + + Stmt VisitStmt_(const ForNode *op) final { + int num_stages = 1; + auto num_stages_anno = op->annotations.Get("num_stages"); + if (num_stages_anno) { + ICHECK(num_stages_anno->as()); + num_stages = static_cast(num_stages_anno->as()->value); + ICHECK(num_stages_ == 1) << "Nested pipeline not supported."; + } + loop_stack_.emplace_back(LoopInfo{op->loop_var, op->extent, op->min}); + + Array> group_info_array; + Array order_info_array; + Array stage_info_array; + + auto group_anno = op->annotations.Get("tl_pipeline_group"); + if (group_anno) { + group_info_array = Downcast>>(group_anno.value()); + } + auto order_anno = op->annotations.Get("tl_pipeline_order"); + if (order_anno) { + order_info_array = Downcast>(order_anno.value()); + } + auto stage_anno = op->annotations.Get("tl_pipeline_stage"); + if (stage_anno) { + stage_info_array = Downcast>(stage_anno.value()); + } + + PipelineInfo pipeline_info(group_info_array, order_info_array, + stage_info_array); + if (!pipeline_info.op_infos.empty()) { + ICHECK(pipeline_info_.op_infos.empty()) + << "Nested pipeline not supported."; + } + + PrimExpr parity_before = std::move(parity_); + PrimExpr stage_before = std::move(stage_); + int num_stages_before = num_stages_; + PipelineInfo pipeline_info_before = pipeline_info_; + + num_stages_ = num_stages; + pipeline_info_ = pipeline_info; + PrimExpr linear_index = loop_stack_[0].loop_var - loop_stack_[0].min; + for (size_t i = 1; i < loop_stack_.size(); ++i) { + linear_index = linear_index * loop_stack_[i].extent + + (loop_stack_[i].loop_var - loop_stack_[i].min); + } + stage_ = FloorMod(linear_index, num_stages); + parity_ = FloorMod( + parity_before * op->extent + FloorDiv(linear_index, num_stages), 2); + auto result = FilterByRole(op); + + Stmt grouped_for_node; + if (result.as() && group_anno && !group_info_array.empty() && + !is_emitting_producer_) { + GroupOpRewriter group_op_rewriter(pipeline_info_); + auto for_node = Downcast(result); + grouped_for_node = group_op_rewriter(for_node); + } + + parity_ = std::move(parity_before); + stage_ = std::move(stage_before); + num_stages_ = num_stages_before; + pipeline_info_ = pipeline_info_before; + + // remove pipeline annotation + auto for_node = result.as(); + if (result.as()) { + auto for_node = Downcast(result); + for_node.CopyOnWrite()->annotations.erase("num_stages"); + if (is_emitting_producer_ || group_info_array.empty()) { + for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order"); + for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage"); + } + if (is_emitting_producer_ || !group_anno || group_info_array.empty()) { + loop_stack_.pop_back(); + return for_node; + } + loop_stack_.pop_back(); + return grouped_for_node; + } + loop_stack_.pop_back(); + return result; + } + + Stmt VisitStmt_(const IfThenElseNode *op) final { return FilterByRole(op); } + Stmt VisitStmt_(const EvaluateNode *op) final { return FilterByRole(op); } + Stmt VisitStmt_(const AttrStmtNode *op) final { return FilterByRole(op); } + Stmt VisitStmt_(const BufferStoreNode *op) final { return FilterByRole(op); } + Stmt VisitStmt_(const LetStmtNode *op) final { return FilterByRole(op); } + Stmt VisitStmt_(const AssertStmtNode *op) final { return FilterByRole(op); } + Stmt VisitStmt_(const BlockNode *op) final { return FilterByRole(op); } + Stmt VisitStmt_(const BlockRealizeNode *op) final { return FilterByRole(op); } + + struct SyncPattern { + int release_idx, acquire_idx; + }; + + struct SyncPatternMap { + std::vector> 
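// For a sequence of N statements, acquire[i] / release[i] hold the indices
// of the sync patterns that statement i must wait on / signal;
// release_after[i][j] is true when the mbarrier arrive is emitted right
// after statement i, and false when the statement merely targets a barrier
// whose arrive is owned by another statement (filled in by the backward scan
// in ExtractSyncPattern). is_loop_dependency() marks patterns whose release
// comes after the acquire in program order, i.e. synchronization carried
// across loop iterations, which is what flips the wait parity.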
acquire; + std::vector> release; + std::vector> release_after; + std::vector patterns; + + void resize(size_t n) { + acquire.resize(n); + release.resize(n); + release_after.resize(n); + } + + bool is_loop_dependency(int pattern_idx) { + return patterns[pattern_idx].release_idx > + patterns[pattern_idx].acquire_idx; + } + }; + + std::vector + CreateBaseSyncPairs(const Array &seq_stmt, + const std::vector &is_producer) { + const int n = seq_stmt.size(); + std::vector> reads, writes; + reads.reserve(n); + writes.reserve(n); + for (int i = 0; i < n; i++) { + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/"", + /*body*/ seq_stmt[i]); + auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_); + std::set read_set, write_set; + for (auto region : access[0]) { + auto var = region->buffer->data; + if (buffer_data_to_buffer_.count(var)) { + read_set.insert(buffer_data_to_buffer_[var].get()); + } else { + read_set.insert(region->buffer.get()); + } + } + for (auto region : access[1]) { + auto var = region->buffer->data; + if (buffer_data_to_buffer_.count(var)) { + write_set.insert(buffer_data_to_buffer_[var].get()); + } else { + write_set.insert(region->buffer.get()); + } + } + reads.push_back(std::move(read_set)); + writes.push_back(std::move(write_set)); + } + + auto intersect_fn = [](const std::set &lhs, + const std::set &rhs) { + for (auto ptr : lhs) + if (rhs.count(ptr)) + return true; + return false; + }; + + std::vector sync_patterns; + // producer_release consumer_acquire, + // inject before the first consumer stmt for each producer + for (int i = 0; i < n; i++) { + for (int j = i + 1; j < n; j++) { + if (is_producer[i] != is_producer[j] && + (intersect_fn(writes[i], reads[j]) || + intersect_fn(reads[i], writes[j]))) { + sync_patterns.push_back({i, j}); + break; + } + } + } + + // consumer_release producer_acquire + // valid when is_loop is true + // inject before the earliest producer stmt for each consumer + bool in_loop = !is_zero(parity_); + if (in_loop) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < i; j++) { + if (is_producer[i] != is_producer[j] && + (intersect_fn(writes[i], reads[j]) || + intersect_fn(reads[i], writes[j]))) { + sync_patterns.push_back({i, j}); + break; + } + } + } + } + + return sync_patterns; + } + + static std::vector + RemoveUnusedSyncPatterns(const std::vector &sync_patterns, + const std::vector &is_producer) { + /* + Simplify multiple release-acquire pairs into one + ------------------ + Produce(A) + Produce(B) + Consume(A, B) + ------------------ + [(0, 2), (1, 2), (2, 0)] -> [(1, 2), (2, 0)] + + Or + ------------------ + Produce(A, B) + Consume(A) + Consume(B) + ------------------ + [(0, 1), (1, 0), (2, 0)] -> [(0, 1), (2, 0)] + */ + int M = sync_patterns.size(); + std::vector removed(M, false); + for (int i = 0; i < M; i++) { + for (int j = 0; j < M; j++) { + if (is_producer[sync_patterns[i].acquire_idx] == + is_producer[sync_patterns[j].acquire_idx] && + sync_patterns[i].acquire_idx >= sync_patterns[j].acquire_idx && + sync_patterns[i].release_idx < sync_patterns[j].release_idx) + removed[i] = true; + } + } + + std::vector sync_pattern_cleaned; + sync_pattern_cleaned.reserve(M); + for (int i = 0; i < M; i++) + if (!removed[i]) + sync_pattern_cleaned.push_back(sync_patterns[i]); + + return sync_pattern_cleaned; + } + + SyncPatternMap ExtractSyncPattern(const Array &seq_stmt) { + size_t num_stmts = seq_stmt.size(); + std::vector is_producer; + is_producer.reserve(num_stmts); + for (auto stmt : seq_stmt) { + 
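// Only producer<->consumer buffer conflicts generate sync patterns in
// CreateBaseSyncPairs; conflicts between two statements of the same role are
// presumably left to the regular storage/thread sync passes. The reverse
// (consumer releases, producer acquires) pairs are only created when parity_
// is a non-trivial expression, i.e. when this SeqStmt sits inside a
// pipelined loop and the producer of iteration k+1 must wait until the
// consumer of iteration k has drained the buffer.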
is_producer.push_back(marker_.GetRole(stmt) == Role::kProducer); + } + + auto sync_patterns_base = CreateBaseSyncPairs(seq_stmt, is_producer); + auto sync_patterns = + RemoveUnusedSyncPatterns(sync_patterns_base, is_producer); + + // for (auto pattern : sync_patterns) { + // std::cout << pattern.release_idx << " " << pattern.acquire_idx << + // std::endl; + // } + + SyncPatternMap map; + map.resize(num_stmts); + map.patterns = sync_patterns; + + for (size_t i = 0; i < sync_patterns.size(); i++) { + int acquire_idx = sync_patterns[i].acquire_idx; + int release_idx = sync_patterns[i].release_idx; + + map.acquire[acquire_idx].push_back(i); + map.release[release_idx].push_back(i); + map.release_after[release_idx].push_back(true); + } + + std::vector cur_consumer_barrier, cur_producer_barrier; + for (int i = num_stmts - 1; i >= 0; i--) { + if (is_producer[i]) { + if (map.release[i].empty()) { + for (auto pattern_idx : cur_producer_barrier) { + map.release[i].push_back(pattern_idx); + map.release_after[i].push_back(false); + } + } else { + for (auto pattern_idx : map.release[i]) { + cur_producer_barrier.push_back(pattern_idx); + } + } + } else { + if (map.release[i].empty()) { + for (auto pattern_idx : cur_consumer_barrier) { + map.release[i].push_back(pattern_idx); + map.release_after[i].push_back(false); + } + } else { + for (auto pattern_idx : map.release[i]) { + cur_consumer_barrier.push_back(pattern_idx); + } + } + } + } + return map; + } + + const bool is_emitting_producer_; + Map buffer_data_to_buffer_; + std::unordered_set released_barrier_; + const WarpSpecializedRoleMarker &marker_; + + int num_barriers_ = 0; + PrimExpr parity_ = 0; + PrimExpr stage_ = 0; + int num_stages_ = 1; + std::vector loop_stack_; + Var thread_var_; + bool mbarrier_only_ = false; + PipelineInfo pipeline_info_; + friend class WarpSpecializedRewriter; + bool has_simt_copy_ = false; +}; + +class WarpSpecializedRewriter : public StmtExprMutator { +public: + WarpSpecializedRewriter(bool disable_warp_specialized, + bool disable_shuffle_elect) + : disable_warp_specialized_(disable_warp_specialized), + disable_shuffle_elect_(disable_shuffle_elect) {} + static PrimFunc Substitute(PrimFunc f, bool disable_warp_specialized, + bool disable_shuffle_elect) { + // Check if function only uses threadIdx.x before proceeding + if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) { + LOG(WARNING) << "WarpSpecialize will be disabled because the program " + "uses thread tags other than threadIdx.x." 
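// The rewrite is all-or-nothing: kernels that also bind threadIdx.y or
// threadIdx.z fall back to the unmodified function, because the
// producer/consumer split below repartitions the flat threadIdx.x range
// only. The Var -> Buffer map assembled from DetectBufferAccessLCA is what
// CreateBaseSyncPairs later consults when it turns read/write overlaps into
// acquire/release pairs.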
+ << "If you want to use warp specialization, please refactor " + "your program to use threadIdx.x only"; + // Return original function unchanged if other thread tags are found + return f; + } + + auto T = WarpSpecializedRewriter(disable_warp_specialized, + disable_shuffle_elect); + T.buffer_lca_ = DetectBufferAccessLCA(f); + for (auto [buffer, _] : T.buffer_lca_) + T.buffer_data_to_buffer_.Set(buffer->data, buffer); + f.CopyOnWrite()->body = T(f->body); + return f; + } + +private: + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent && + Downcast(op->node)->thread_tag == "threadIdx.x") { + thread_iv_ = Downcast(op->node); + need_update_thread_extent_ = false; + AttrStmt attr_stmt = Downcast(StmtExprMutator::VisitStmt_(op)); + if (need_update_thread_extent_) { + thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()}; + attr_stmt.CopyOnWrite()->node = thread_iv_; + attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value(); + } + thread_iv_ = {}; + return attr_stmt; + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + // If users define a thread binding, we will replace the thread binding with + // threadIdx.x We require the thread binding is threadIdx.x, and the extent is + // the same as the thread extent + Stmt VisitStmt_(const ForNode *op) final { + ICHECK(thread_iv_.defined()); + For for_node = Downcast(StmtExprMutator::VisitStmt_(op)); + if (for_node->kind == ForKind::kThreadBinding) { + ICHECK(for_node->thread_binding.defined()); + String thread_tag = for_node->thread_binding.value()->thread_tag; + ICHECK(thread_tag == "threadIdx.x") << "Only support threadIdx.x"; + Var thread_iv = Downcast(for_node->loop_var); + Stmt new_body = + ThreadIdxRewriter::Rewrite(for_node->body, thread_iv, thread_iv_, 0); + return new_body; + } + return for_node; + } + + Stmt VisitStmt_(const BlockRealizeNode *op) final { + BlockRealize block_realize = + Downcast(StmtExprMutator::VisitStmt_(op)); + if (!thread_iv_.defined()) { + return block_realize; + } + + Block block = block_realize->block; + WarpSpecializedRoleMarker marker(buffer_data_to_buffer_); + marker.Prepare(block); + marker(block); + if (!marker.HasProducer()) { + // Cannot detect any producer here, directly return. 
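// No producer (bulk/TMA copy) role was found, so the block is left as-is.
// The interesting path below either emits an mbarrier-only version of the
// body (when warp specialization is disabled via the pass config) or splits
// the body into a producer branch and a consumer branch selected by the
// thread index.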
+ return block_realize; + } + + if (disable_warp_specialized_) { + WSCodeEmitter mbarrier_emitter(true, thread_iv_, buffer_data_to_buffer_, + marker, true); + auto code = mbarrier_emitter(block->body); + int num_barriers = mbarrier_emitter.num_barriers_; + Array barrier_num_threads; + barrier_num_threads.reserve(num_barriers); + PrimExpr arrive_thread_count = thread_iv_->dom->extent; + for (int i = 0; i < num_barriers; i++) { + barrier_num_threads.push_back(arrive_thread_count); + } + Stmt init_barrier = Evaluate(Call( + DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads)); + block.CopyOnWrite()->body = SeqStmt({init_barrier, code}); + block_realize.CopyOnWrite()->block = block; + return block_realize; + } + WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker); + WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker, + false); + Stmt producer_code = producer(block->body); + Stmt consumer_code = consumer(block->body); + PrimExpr consumer_thread_extent = thread_iv_->dom->extent; + PrimExpr producer_thread_extent = thread_iv_->dom->extent; + // Need one warp-group for bulk-copy only case + if (!marker.HasSimtCopy()) + producer_thread_extent = 128; + + updated_thread_extent_ = consumer_thread_extent + producer_thread_extent; + + producer_code = ThreadIdxRewriter::Rewrite( + producer_code, thread_iv_->var, + thread_iv_->var - consumer_thread_extent, producer_thread_extent, + !disable_shuffle_elect_); + consumer_code = ThreadIdxRewriter::Rewrite( + consumer_code, thread_iv_->var, thread_iv_->var, consumer_thread_extent, + !disable_shuffle_elect_); + need_update_thread_extent_ = true; + + ICHECK(producer.num_barriers_ == consumer.num_barriers_) + << producer.num_barriers_ << " " << consumer.num_barriers_; + int num_barriers = consumer.num_barriers_; + Array barrier_num_threads; + barrier_num_threads.reserve(num_barriers); + for (int i = 0; i < num_barriers; i++) { + PrimExpr arrive_thread_count = + producer.released_barrier_.count(i) + ? (producer.hasSimtCopy() ? producer_thread_extent : 1) + : consumer_thread_extent; + barrier_num_threads.push_back(arrive_thread_count); + } + + Stmt init_barrier = Evaluate(Call( + DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads)); + Stmt body = IfThenElse(GE(thread_iv_->var, consumer_thread_extent), + producer_code, consumer_code); + // Add an attr here to handle the partial thread count in ThreadSync pass. 
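// The emitted block body then looks roughly like
//   create_list_of_mbarrier(n_0, ..., n_k);   // per-barrier arrive counts
//   if (threadIdx.x >= consumer_extent) {
//     producer code, with threadIdx.x shifted down by consumer_extent
//     (a single extra warp group of 128 threads when only bulk copies are
//      issued, i.e. no SIMT copy was detected)
//   } else {
//     consumer code
//   }
// wrapped in a kWarpSpecializationScope attribute that carries the
// {producer_extent, consumer_extent} split for the later ThreadSync pass.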
+ Array ws_partition = {Downcast(producer_thread_extent), + Downcast(consumer_thread_extent)}; + body = AttrStmt(ws_partition, attr::kWarpSpecializationScope, 0, body); + + block.CopyOnWrite()->body = SeqStmt({init_barrier, body}); + block_realize.CopyOnWrite()->block = block; + return block_realize; + } + + WarpSpecializedRewriter() = default; + + Map buffer_data_to_buffer_; + Map> buffer_lca_; + Map buffer_remap_; + IterVar thread_iv_; + Optional updated_thread_extent_; + bool need_update_thread_extent_ = false; + bool disable_warp_specialized_ = false; + bool disable_shuffle_elect_ = false; +}; + +using namespace tir::transform; + +tvm::transform::Pass WarpSpecialized() { + auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) { + bool disable_warp_specialized = + ctx->GetConfig(kDisableWarpSpecialized, Bool(false)).value(); + bool disable_shuffle_elect = + ctx->GetConfig(kDisableShuffleElect, Bool(false)).value(); + bool warp_specialized = WarpSpecializedDetector::Detect(f->body); + + if (!warp_specialized) { + return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized, + disable_shuffle_elect); + } else { + auto node = ffi::String("default"); + f.CopyOnWrite()->body = + AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body); + return f; + } + }; + return CreatePrimFuncPass(pass_func, 0, "tir.transform.WarpSpecialized", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tir.transform.transform.WarpSpecialized", WarpSpecialized); +} + +} // namespace transform +} // namespace tir +} // namespace tvm \ No newline at end of file diff --git a/src/tir/transforms/warp_specialized_rewriter.h b/src/tir/transforms/warp_specialized_rewriter.h new file mode 100644 index 000000000000..b8e48e5fc734 --- /dev/null +++ b/src/tir/transforms/warp_specialized_rewriter.h @@ -0,0 +1,99 @@ +/*! 
+ * \file warp_specialized_rewriter.h + * \brief tools for warp-specialized-related analysis and transformation + */ + +#pragma once + +#include "arith/ir_visitor_with_analyzer.h" +#include "tir/analysis/var_use_def_analysis.h" +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/builtin.h" +#include "./common/collector.h" +#include "runtime/thread_storage_scope.h" +#include "tir/transforms/ir_utils.h" + +namespace tvm { +namespace tir { +namespace transform { +using namespace runtime; +using arith::IRVisitorWithAnalyzer; + +class WarpSpecializedDetector : public IRVisitorWithAnalyzer { +public: + // return true means this aws will be disabled + static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) { + WarpSpecializedDetector detector; + detector.VisitStmt(stmt); + if (detector.has_warp_specialization_) { + LOG(WARNING) << "Auto warp specialization will be disabled because warp " + "specialization is manually enabled"; + return true; + } + if (detector.has_tma_op_ && detector.has_mbarrier_op_) { + LOG(WARNING) << "Auto warp specialization will be disabled because TMA " + "and mbarrier are both present"; + return true; + } + return false; + } + + WarpSpecializedDetector() { + has_tma_op_ = false; + has_mbarrier_op_ = false; + has_warp_specialization_ = false; + } + +private: + void VisitStmt_(const EvaluateNode *op) final { + if (const CallNode *call = op->value.as()) { + if (call->op.same_as(create_list_of_mbarrier()) || + call->op.same_as(mbarrier_wait_parity()) || + call->op.same_as(builtin::ptx_arrive_barrier()) || + call->op.same_as(builtin::ptx_cp_async_barrier())) { + has_mbarrier_op_ = true; + } + } + IRVisitorWithAnalyzer::VisitStmt_(op); + } + + void VisitExpr_(const CallNode *op) final { + if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) || + op->op.same_as(set_max_nreg())) { + has_tma_op_ = true; + } + IRVisitorWithAnalyzer::VisitExpr_(op); + } + + void VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == "warp_specialize" && + op->value.as()->value == 1) { + has_warp_specialization_ = true; + } + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + ICHECK(iv->dom->extent.as()); + thread_var_ = iv; + } + } + IRVisitorWithAnalyzer::VisitStmt_(op); + } + + bool has_tma_op_{false}; + IterVar thread_var_; + bool has_mbarrier_op_{false}; + bool has_warp_specialization_{false}; +}; + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/wgmma_sync_rewriter.cc b/src/tir/transforms/wgmma_sync_rewriter.cc new file mode 100644 index 000000000000..f9f6174d9f23 --- /dev/null +++ b/src/tir/transforms/wgmma_sync_rewriter.cc @@ -0,0 +1,270 @@ +/*! 
+ * \file warp_specialized_pipeline.cc + * \brief Warp specialized Pipeline for cuda GPU (sm90+) + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/builtin.h" + +namespace tvm { +namespace tir { +namespace transform { + +bool isGemm(const Stmt &stmt) { + bool is_gemm = false; + if (stmt.as()) { + auto call = Downcast(stmt)->value.as(); + if (call && call->op.same_as(Op::Get("tir.call_extern"))) { + if (call->args[0].as()) { + std::string name = Downcast(call->args[0])->value; + if (name.find("gemm") != std::string::npos) { + is_gemm = true; + } + } + } + } + return is_gemm; +} + +bool isGemmSync(const Stmt &stmt) { + bool is_gemm_sync = false; + if (stmt.as()) { + auto call = Downcast(stmt)->value.as(); + if (call && call->op.same_as(Op::Get("tir.call_extern"))) { + if (call->args[0].as()) { + std::string name = Downcast(call->args[0])->value; + if (name.find("warpgroup_wait") != std::string::npos) { + is_gemm_sync = true; + } + } + } + } + return is_gemm_sync; +} + +bool isArriveBarrier(const Stmt &stmt) { + bool is_arrive_barrier = false; + if (stmt.as()) { + auto call = Downcast(stmt)->value.as(); + if (call && call->op.same_as(Op::Get("tir.ptx_arrive_barrier"))) { + is_arrive_barrier = true; + } + } + return is_arrive_barrier; +} + +class WgmmaSyncRewriter : public StmtExprMutator { +public: + static PrimFunc Substitute(PrimFunc f) { + auto T = WgmmaSyncRewriter(); + T.buffer_lca_ = DetectBufferAccessLCA(f); + for (auto [buffer, _] : T.buffer_lca_) + T.buffer_data_to_buffer_.Set(buffer->data, buffer); + f.CopyOnWrite()->body = T(f->body); + return f; + } + +private: + void CollectWgmmaInfo(const SeqStmtNode *op) { + for (int i = 0; i < static_cast(op->seq.size()); i++) { + auto stmt = op->seq[i]; + if (isGemm(stmt)) { + gemm_stmts_.push_back(stmt); + gemm_stmt_ids_.push_back(i); + bool found_release = false; + for (int j = i + 1; j < static_cast(op->seq.size()); j++) { + auto release_stmt = op->seq[j]; + if (isArriveBarrier(release_stmt)) { + found_release = true; + gemm_release_stmts_.push_back(release_stmt); + break; + } + } + if (!found_release) { + gemm_release_stmts_.push_back(Evaluate(0)); + } + // ICHECK(op->seq.size() > i + 1); + // auto release_stmt = op->seq[i + 1]; + // auto next_call = + // Downcast(release_stmt)->value.as(); + // ICHECK(next_call); + // ICHECK(next_call->op.same_as(Op::Get("tir.ptx_arrive_barrier"))); + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/"", + /*body*/ op->seq[i]); + auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + std::set read_set, write_set; + for (auto region : access[0]) + read_set.insert(region->buffer.get()); + for (auto region : access[1]) + write_set.insert(region->buffer.get()); + gemm_read_buffers_.push_back(read_set); + gemm_write_buffers_.push_back(write_set); + } + } + } + + Stmt VisitStmt_(const ForNode *op) final { + auto order_anno = op->annotations.Get("tl_pipeline_order"); + if (!order_anno) { + return StmtExprMutator::VisitStmt_(op); + } + + CollectWgmmaInfo(op->body.as()); + auto stmt_node = (op->body).as(); + ICHECK(stmt_node); + + auto intersect_fn = [](const std::set &lhs, + const std::set &rhs) { + for (auto ptr : lhs) + if (rhs.count(ptr)) + return true; + return false; + }; + + for (int r = 0; r < static_cast(gemm_stmts_.size()); r++) { + bool found = false; + auto last_stmt = Stmt(); + for (int i = 0; i < static_cast(stmt_node->seq.size()); i++) { + if (stmt_node->seq[i].same_as(gemm_stmts_[r])) { + found = true; + 
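// For each wgmma gemm, the scan below walks forward from the gemm statement
// and remembers the last statement whose reads and writes do not conflict
// with the gemm's operands or accumulator; the matching cute::warpgroup_wait
// is later placed right after that statement, so the wait is deferred as far
// as the data dependences allow, and the arrive barrier that originally
// followed the gemm is re-emitted after the wait.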
last_stmt = stmt_node->seq[i]; + continue; + } + if (!found) + continue; + Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/"", + /*body*/ stmt_node->seq[i]); + auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + std::set read_set, write_set; + for (auto region : access[0]) + read_set.insert(region->buffer.get()); + for (auto region : access[1]) + write_set.insert(region->buffer.get()); + if (intersect_fn(read_set, gemm_write_buffers_[r]) || + intersect_fn(write_set, gemm_read_buffers_[r]) || + intersect_fn(write_set, gemm_write_buffers_[r])) { + break; + } + last_stmt = stmt_node->seq[i]; + } + last_stmts_.push_back(last_stmt); + } + + auto new_seq = Array(); + for (int i = 0; i < static_cast(stmt_node->seq.size()); i++) { + bool remove_ = false; + for (int j = 0; j < static_cast(gemm_stmts_.size()); j++) { + if (stmt_node->seq[i].same_as(gemm_release_stmts_[j])) { + remove_ = true; + continue; + } + } + if (remove_) + continue; + auto stmt = stmt_node->seq[i]; + for (int j = 0; j < static_cast(gemm_stmts_.size()); j++) { + if (stmt_node->seq[i].same_as(gemm_stmts_[j])) { + auto call = Downcast(stmt)->value.as(); + ICHECK(call); + ICHECK(call->op.same_as(Op::Get("tir.call_extern"))); + ICHECK(call->args[0].as()); + std::string name = Downcast(call->args[0])->value; + std::string new_name = name.substr(0, name.size() - 1) + ", -1>"; + auto new_args = Array(); + new_args.push_back(StringImm(new_name)); + for (int k = 1; k < static_cast(call->args.size()); k++) { + new_args.push_back(call->args[k]); + } + stmt = Evaluate( + Call(DataType::Handle(), builtin::call_extern(), new_args)); + break; + } + } + + new_seq.push_back(stmt); + for (int j = 0; j < static_cast(gemm_stmts_.size()); j++) { + if (stmt_node->seq[i].same_as(last_stmts_[j])) { + Array new_args; + new_args.push_back(StringImm("cute::warpgroup_wait<0>")); + new_args.push_back(Integer(j)); + auto new_call = + Call(DataType::Handle(), builtin::call_extern(), new_args); + new_seq.push_back(Evaluate(new_call)); + if (std::count(gemm_release_stmts_.begin(), gemm_release_stmts_.end(), + gemm_release_stmts_[j]) == 1) { + new_seq.push_back(gemm_release_stmts_[j]); + } else { + gemm_release_stmts_[j] = Evaluate(0); + } + } + } + } + + int gemm_count = 0; + int max_sync_index = 0; + for (int i = 0; i < static_cast(new_seq.size()); i++) { + if (isGemm(new_seq[i])) { + gemm_count++; + } else if (isGemmSync(new_seq[i])) { + auto call = Downcast(new_seq[i])->value.as(); + auto sync_index = + static_cast(Downcast(call->args[1])->value); + auto wait_count = gemm_count - sync_index - 1; + if (sync_index > max_sync_index) + max_sync_index = sync_index; + if (sync_index < max_sync_index) { + // new_seq.erase(new_seq.begin() + i); + new_seq.Set(i, Evaluate(0)); + } else { + Array new_args; + std::string call_str = + "cute::warpgroup_wait<" + std::to_string(wait_count) + ">"; + new_args.push_back(StringImm(call_str)); + new_seq.Set(i, Evaluate(Call(DataType::Handle(), + builtin::call_extern(), new_args))); + } + } + } + auto new_for = + For(op->loop_var, op->min, op->extent, op->kind, + new_seq.size() == 1 ? 
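// The placeholder sync recorded above carries the index of the gemm it
// belongs to; this fix-up turns it into cute::warpgroup_wait<K> with
// K = (#gemms issued so far) - gemm_index - 1, the number of younger wgmma
// groups still allowed in flight. Example (made up) with three gemms whose
// first result is only consumed after all three are issued:
//   G0 G1 G2 ... warpgroup_wait<2> ... first use of G0's accumulator
// waits until at most two groups (G1, G2) remain outstanding, which is
// exactly when G0 has completed. A wait for an older gemm that appears after
// a wait for a newer one is redundant and is replaced by Evaluate(0).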
new_seq[0] : SeqStmt(std::move(new_seq)),
+            op->thread_binding, op->annotations);
+    return new_for;
+  }
+
+  WgmmaSyncRewriter() = default;
+
+  Map<Buffer, Optional<Stmt>> buffer_lca_;
+  Map<Var, Buffer> buffer_data_to_buffer_;
+  std::vector<std::set<const BufferNode *>> gemm_read_buffers_;
+  std::vector<std::set<const BufferNode *>> gemm_write_buffers_;
+  std::vector<Stmt> gemm_stmts_;
+  std::vector<Stmt> gemm_release_stmts_;
+  std::vector<Stmt> last_stmts_;
+
+  std::vector<int> gemm_stmt_ids_;
+  friend class WgmmaReleaseCollector;
+};
+
+using namespace tir::transform;
+
+tvm::transform::Pass RewriteWgmmaSync() {
+  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
+    return WgmmaSyncRewriter::Substitute(std::move(f));
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.transform.RewriteWgmmaSync",
+                            {});
+}
+
+} // namespace transform
+} // namespace tir
+} // namespace tvm
\ No newline at end of file
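A minimal standalone sketch (not part of the patch) of the barrier bookkeeping that WSCodeEmitter in warp_specialized_rewriter.cc generates symbolically: iteration i of a loop pipelined with num_stages stages uses barrier slot i % num_stages within each sync pattern, and the wait parity flips each time the slots wrap around. The pattern count, stage count, and barrier base below are illustrative values, not taken from the patch.

#include <cstdio>

int main() {
  // Illustrative values: two sync patterns, two pipeline stages, and no
  // barriers allocated before this loop (num_barriers_ == 0 in the pass).
  const int num_patterns = 2, num_stages = 2, num_barriers = 0, extent = 6;
  for (int i = 0; i < extent; ++i) {
    int stage = i % num_stages;        // FloorMod(linear_index, num_stages)
    int parity = (i / num_stages) % 2; // FloorMod(FloorDiv(linear_index, num_stages), 2)
    for (int p = 0; p < num_patterns; ++p) {
      // Same indexing as the emitter: one mbarrier per (pattern, stage) pair.
      int barrier_id = stage + num_barriers + num_stages * p;
      std::printf("iter=%d pattern=%d -> barrier=%d wait parity=%d\n", i, p,
                  barrier_id, parity);
    }
  }
  return 0;
}

For loop-carried patterns the emitter additionally XORs the parity with 1, since such an acquire waits on a release issued in the previous iteration.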