diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index e0dbc7be50cf..e692464be796 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -210,6 +210,32 @@ class ScheduleRule : public runtime::ObjectRef {
       Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
       Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);
+  /*!
+   * \brief Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate
+   * Hopper tensor core intrinsics
+   * \param intrin_groups A list of groups of tensor core intrinsics. Each map should contain the
+   * keys "init", "load_a", "load_b", "compute" and "store", which specify the tensor intrins for
+   * initialization, loading operand A, loading operand B, tensor core computation, and storing
+   * the result. The values are names of tensor intrinsics registered via TensorIntrin.register(...)
+   * beforehand
+   * \param structure The tiling structure. Recommended:
+   * - 'SSSRRSRS' on GPU
+   * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
+   * - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
+   * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
+   * \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
+   * NullOpt means disable vectorization
+   * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
+   * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
+   * \param use_software_pipeline Whether to use the software pipeline.
+   * \return The schedule rule created
+   */
+  TVM_DLL static ScheduleRule MultiLevelTilingTensorCoreHopper(
+      Array<Map<String, String>> intrin_groups, String structure,
+      Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
+      Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
+      Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);
+
   /*!
    * \brief Extension of MultiLevelTiling for backends with wide vectors.
    * The loop over the innermost spatial axis of the output buffer is always vectorized with the
diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index d330fc713991..b08804495143 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -27,6 +27,7 @@
 from .multi_level_tiling import (
     MultiLevelTiling,
     MultiLevelTilingTensorCore,
+    MultiLevelTilingTensorCoreHopper,
     MultiLevelTilingWideVector,
     MultiLevelTilingWithIntrin,
     ReuseType,
diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
index 19651a2ce18e..7a4278161bba 100644
--- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
+++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -197,6 +197,61 @@ def __init__(
         )
 
 
+@register_object("meta_schedule.MultiLevelTilingTensorCoreHopper")
+class MultiLevelTilingTensorCoreHopper(ScheduleRule):
+    """Extension of MultiLevelTiling for auto-tensorizing with multiple groups of candidate Hopper
+    tensor core intrinsics.
+
+    Parameters
+    ----------
+    intrin_groups : List[Mapping[str, str]]
+        A list of groups of tensor core intrinsics. Each map should contain the keys "init",
+        "load_a", "load_b", "compute" and "store", which specify the tensor intrins for
+        initialization, loading operand A, loading operand B, tensor core computation, and
+        storing the result. The values are names of tensor intrinsics that must be registered
+        via TensorIntrin.register(...) beforehand
+    structure : str
+        The tiling structure. Recommended:
+            - 'SSSRRSRS' on GPU
+    tile_binds : Optional[List[str]]
+        For each level of tiles, which thread axis it is bound to. Recommended:
+            - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
+    max_innermost_factor : Optional[int]
+        The maximum size of the innermost factor. None means no limit
+    vector_load_lens : Optional[List[int]]
+        The length of vector lane in vectorized cooperative fetching.
+        None means disable vectorization
+    reuse_read : Optional[ReuseType]
+        Data reuse configuration for reading. None means no reuse.
+    reuse_write : Optional[ReuseType]
+        Data reuse configuration for writing. None means no reuse.
+    use_software_pipeline : bool
+        Whether to use the software pipeline.
+    """
+
+    def __init__(
+        self,
+        intrin_groups: List[Mapping[str, str]],
+        structure: str,
+        tile_binds: Optional[List[str]] = None,
+        max_innermost_factor: Optional[int] = None,
+        vector_load_lens: Optional[List[int]] = None,
+        reuse_read: Optional[ReuseType] = None,
+        reuse_write: Optional[ReuseType] = None,
+        use_software_pipeline: bool = False,
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.ScheduleRuleMultiLevelTilingTensorCoreHopper,  # type: ignore # pylint: disable=no-member
+            intrin_groups,
+            structure,
+            tile_binds,
+            max_innermost_factor,
+            vector_load_lens,
+            reuse_read.as_dict() if reuse_read is not None else None,
+            reuse_write.as_dict() if reuse_write is not None else None,
+            use_software_pipeline,
+        )
+
+
 @register_object("meta_schedule.MultiLevelTilingWideVector")
 class MultiLevelTilingWideVector(ScheduleRule):
     """Extension of MultiLevelTiling for backends with wide vectors. The loop over the innermost
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 409a1ff10a78..89dcb962dfeb 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -1636,6 +1636,71 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
     return mma_store_desc, mma_store_desc
 
 
+def get_wgmma_intrin_group(
+    load_scope: Literal["shared", "shared.dyn"],
+    store_scope: Literal["global", "shared", "shared.dyn"],
+    in_dtype: str,
+    out_dtype: str,
+    trans_b: bool,
+) -> Dict[str, str]:
+    """Get a group of intrinsics for wgmma tensor core with the given configurations
+
+    Parameters
+    ----------
+    load_scope : Literal["shared", "shared.dyn"]
+        The memory scope of the input buffer.
+
+    store_scope : Literal["global", "shared", "shared.dyn"]
+        The memory scope of the result buffer.
+
+    in_dtype : str
+        The input data type.
+
+    out_dtype : str
+        The output data type.
+
+    trans_b : bool
+        Whether the input matrix B is transposed.
+
+    Returns
+    -------
+    ret : Dict[str, str]
+        A group of tensor intrinsics.
+    """
+    assert load_scope in ["shared", "shared.dyn"]
+    assert store_scope in ["global", "shared", "shared.dyn"]
+    assert in_dtype in ["float16", "int8"]
+    assert out_dtype in ["float16", "float32", "int32"]
+
+    shape = "16x16x16"
+    in_dtype = "f16" if in_dtype == "float16" else "s8"
+    out_dtype = "f16" if out_dtype == "float16" else "f32" if out_dtype == "float32" else "s32"
+    # convert "shared.dyn" to "shared_dyn"
+    load_scope = load_scope.replace(".", "_")
+    store_scope = store_scope.replace(".", "_")
+    trans_a = ""
+    trans_b = "_trans" if trans_b else ""
+
+    # e.g.
wgmma_load_16x16x16_f16_a_shared + load_a_intrin = f"wgmma_load_{shape}_{in_dtype}_a{trans_a}_{load_scope}" + # e.g. wgmma_load_16x16x16_f16_b_trans_shared_dyn + load_b_intrin = f"wgmma_load_{shape}_{in_dtype}_b{trans_b}_{load_scope}" + # e.g. wgmma_sync_16x16x16_f16f16f32_trans + compute_intrin = f"wgmma_sync_{shape}_{in_dtype}{in_dtype}{out_dtype}{trans_b}" + # e.g. wgmma_fill_16x16x16_f16 + init_intrin = f"wgmma_fill_{shape}_{out_dtype}" + # e.g. wgmma_store_16x16x16_f16_shared_dyn + store_intrin = f"wgmma_store_{shape}_{out_dtype}_{store_scope}" + + return { + "init": init_intrin, + "load_a": load_a_intrin, + "load_b": load_b_intrin, + "compute": compute_intrin, + "store": store_intrin, + } + + TensorIntrin.register("mma_init_m16n8k8_f16", *get_mma_init_intrin(16, 8, 8, "float16")) TensorIntrin.register("mma_init_m16n8k8_f32", *get_mma_init_intrin(16, 8, 8, "float32")) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core_hopper.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core_hopper.cc new file mode 100644 index 000000000000..8563da7bd7ef --- /dev/null +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core_hopper.cc @@ -0,0 +1,930 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include +#include +#include + +#include "../utils.h" +#include "./multi_level_tiling.h" + +namespace tvm { +namespace meta_schedule { + +using tir::BlockRV; +using tir::IterVarType; +using tir::LoopRV; +using tir::Schedule; + +struct TensorCoreIntrinGroup { + String init_intrin; + String load_a_intrin; + String load_b_intrin; + String compute_intrin; + String store_intrin; + + /*! \brief Create TensorCoreIntrinGroup from config in a map. The map should contains the + * following keys: + * - init + * - load_a + * - load_b + * - compute + * - store + * The values of the keys should be the names of the corresponding intrinsics and should be + * registered via TensorIntrin.Register beforehand. 
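   *
   * For illustration only, the map returned by
   * get_wgmma_intrin_group("shared", "shared", "float16", "float32", trans_b=True)
   * in python/tvm/tir/tensor_intrin/cuda.py would be:
   *   {"init": "wgmma_fill_16x16x16_f32",
   *    "load_a": "wgmma_load_16x16x16_f16_a_shared",
   *    "load_b": "wgmma_load_16x16x16_f16_b_trans_shared",
   *    "compute": "wgmma_sync_16x16x16_f16f16f32_trans",
   *    "store": "wgmma_store_16x16x16_f32_shared"}
   * The corresponding tensor intrinsics still need to be registered separately.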
+ */ + static TensorCoreIntrinGroup FromConfig(const Map& config); +}; + +TensorCoreIntrinGroup TensorCoreIntrinGroup::FromConfig(const Map& config) { + auto f_initialize_intrin = [&config](String key_name, String* intrin_name) { + CHECK(config.count(key_name)) << "ValueError: " << key_name << " is not set."; + *intrin_name = config.at(key_name); + // Check the existence of the intrin + tir::TensorIntrin::Get(*intrin_name); + }; + TensorCoreIntrinGroup intrin_group; + f_initialize_intrin("init", &intrin_group.init_intrin); + f_initialize_intrin("load_a", &intrin_group.load_a_intrin); + f_initialize_intrin("load_b", &intrin_group.load_b_intrin); + f_initialize_intrin("compute", &intrin_group.compute_intrin); + f_initialize_intrin("store", &intrin_group.store_intrin); + return intrin_group; +} + +class TensorCoreStateNode : public StateNode { + public: + /*! \brief The tensor core intrinsic group. */ + TensorCoreIntrinGroup intrin_group; + /*! \brief The auto tensorization maping info. */ + tir::AutoTensorizeMappingInfo mapping_info{nullptr}; + /*! \brief The Tensor Core reindex block A for Tensor Core computation */ + tir::BlockRV tensor_core_reindex_A; + /*! \brief The Tensor Core reindex block B for Tensor Core computation */ + tir::BlockRV tensor_core_reindex_B; + /*! \brief The Tensor Core reindex store block for Tensor Core computation */ + tir::BlockRV tensor_core_reindex_store; + /*! \brief Flag to indicate its a WGMMA intrin group */ + bool is_hopper; + /*! \brief Flag to indicate whether to use async software pipeline */ + bool use_async; + + State Copy() const final; + + static constexpr const char* _type_key = "meta_schedule.TensorCoreState"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode); +}; + +class TensorCoreState : public State { + public: + explicit TensorCoreState(TensorCoreIntrinGroup intrin_group, + tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, + BlockRV block_rv, bool use_async, Array> tiles = {}); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode); +}; + +TVM_REGISTER_OBJECT_TYPE(TensorCoreStateNode); + +TensorCoreState::TensorCoreState(TensorCoreIntrinGroup intrin_group, + tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, + BlockRV block_rv, bool use_async, Array> tiles) { + ObjectPtr node = make_object(); + node->intrin_group = intrin_group; + node->mapping_info = mapping_info; + node->sch = std::move(sch); + node->block_rv = std::move(block_rv); + node->tiles = std::move(tiles); + node->is_hopper = support::StartsWith(intrin_group.compute_intrin, "hopper_sync"); + node->use_async = use_async; + data_ = std::move(node); +} + +State TensorCoreStateNode::Copy() const { + ObjectPtr node = make_object(*this); + node->sch = sch->Copy(); + return State(node); +} + +/*! + * \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core + * intrinsics. + */ +class MultiLevelTilingTensorCoreHopperNode : public MultiLevelTilingNode { + private: + // SubRule: Add tensorization-related transformations + inline std::vector TransformForTensorization(TensorCoreState state) const; + // Subrule: Transform the layout of the output. This is necessary for efficient cache write the + // output in the shared memory. 
+ std::vector TransformIntermediateOutputLayout(TensorCoreState state); + // Subrule: Add read cache for wgmma + // Basically same with MultiLevelTilingNode::AddReadReuse, but change CacheRead + ComputeAt to + // ReadAt + inline std::vector WGMMAAddReadReuse(TensorCoreState state) const; + // Subrule: Add tensorized load + inline std::vector AddReadReuseTensorCore(TensorCoreState state) const; + // Subrule: Add tensorized store + inline std::vector AddWriteReuseTensorCore(TensorCoreState state) const; + // Subrule: Add software pipeline + inline std::vector AddSoftwarePipeline(TensorCoreState state) const; + // Subrule: split loop for wgmma using sample partitioned tile + inline std::pair, Array> WGMMASplitLoop(const Schedule& sch, + BlockRV block, LoopRV loop, + int n_tiles, + int partition_pos, + int innerpart_factor) const; + // Subrule: tile loop nest for wgmma + // Basically same with MultiLevelTilingNode::TileLoopNest, but change SamplePerfectTile to + // SamplePartitionedTile + inline std::vector WGMMATileLoopNest(TensorCoreState state) const; + + // Override ApplySubRules to apply tensorization-specific sub-rules + std::vector ApplySubRules(std::vector states) final; + + // Override Apply to apply tensorization-specific analysis before applying sub-rules + Array Apply(const Schedule& sch, const BlockRV& block_rv) final; + + // Inherited from ScheduleRuleNode + ScheduleRule Clone() const final { + ObjectPtr n = + make_object(*this); + return ScheduleRule(n); + } + + /*! + * \brief Transform and tensorize with the given tensor intrin + * \param state The state of the meta schedule rule + * \param intrin_name The name of the tensor intrin + * \return The loop to be tensorized. NullOpt if the workload can't be tensorized. + */ + Optional TransformWithTensorIntrin(TensorCoreStateNode* state, + const String& intrin_name) const; + + /*! + * \brief Tile, blockize and annotate for tensorization with the given intrin + * \param block_rv The block to be tensorized + * \param intrin_name The name of the tensor intrin + */ + void TileAndAnnotateTensorize(Schedule* sch, const BlockRV& block_rv, const String& intrin_name, + const String& permuted_layout_annotate_value) const; + + public: + /*! \brief The candidate tensor core intrin groups to apply */ + std::vector intrin_groups; + /*! \brief Whether to use software pipeline */ + bool use_software_pipeline = false; + static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingTensorCoreHopper"; + TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingTensorCoreHopperNode, MultiLevelTilingNode); + + private: +}; + +// Entry of the mega rule; Inherited from ScheduleRuleNode +Array MultiLevelTilingTensorCoreHopperNode::Apply(const Schedule& sch, + const BlockRV& block_rv) { + if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { + return {sch}; + } + + std::unordered_map intrin_group_to_mapping_info; + for (int i = 0, n = intrin_groups.size(); i < n; ++i) { + TensorCoreIntrinGroup intrin_group = intrin_groups[i]; + Optional mapping_info = tir::GetAutoTensorizeMappingInfo( + sch->state(), sch->GetSRef(block_rv), + tir::TensorIntrin::Get(intrin_groups[i].compute_intrin).value()->desc); + if (mapping_info.defined()) { + intrin_group_to_mapping_info.emplace(i, mapping_info.value()); + } + } + + if (intrin_group_to_mapping_info.empty()) { + // No tensor intrinsics can be applied. + return {sch}; + } + + // Save the original schedule so that we can roll back transformations if tensorization + // fail. 
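  // It is returned unchanged below if none of the candidate intrin groups yields a sketch.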
+ Schedule original_sch = sch; + + std::vector initial_states; + for (const auto& kv : intrin_group_to_mapping_info) { + const TensorCoreIntrinGroup& intrin_group = intrin_groups[kv.first]; + const tir::AutoTensorizeMappingInfo& mapping_info = kv.second; + Schedule new_sch = sch->Copy(); + new_sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); + initial_states.push_back(TensorCoreState(intrin_group, mapping_info, new_sch, block_rv, true)); + } + Array results; + for (auto&& state : ApplySubRules(initial_states)) { + TVM_PY_LOG(INFO, logger) << "Sketch " << results.size() << ": tensorizing with " + << state.as()->intrin_group.compute_intrin; + results.push_back(std::move(state->sch)); + } + if (results.empty()) { + TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized."; + return {original_sch}; + } + return results; +} + +std::vector MultiLevelTilingTensorCoreHopperNode::ApplySubRules(std::vector states) { + states = SubRule(std::move(states), [&](State state) { + return TransformForTensorization(Downcast(state)); + }); + states = SubRule(std::move(states), [&](State state) { + TensorCoreState tc_state = Downcast(state); + return tc_state->is_hopper ? WGMMATileLoopNest(tc_state) : TileLoopNest(state, 2); + }); + states = SubRule(std::move(states), [&](State state) { + return TransformIntermediateOutputLayout(Downcast(state)); + }); + states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); + states = SubRule(std::move(states), [&](State state) { + return AddWriteReuseTensorCore(Downcast(state)); + }); + states = SubRule(std::move(states), [&](State state) { + TensorCoreState tc_state = Downcast(state); + return tc_state->is_hopper ? WGMMAAddReadReuse(tc_state) : AddReadReuse(state); + }); + states = SubRule(std::move(states), [&](State state) { + return AddReadReuseTensorCore(Downcast(state)); + }); + states = SubRule(std::move(states), [&](State state) { + return AddSoftwarePipeline(Downcast(state)); + }); + return states; +} + +void MultiLevelTilingTensorCoreHopperNode::TileAndAnnotateTensorize( + Schedule* sch, const BlockRV& block_rv, const String& intrin_name, + const String& permuted_layout_annotate_value) const { + Optional loop = TileWithTensorIntrin(*sch, block_rv, intrin_name).value(); + ICHECK(loop.defined()); + BlockRV blockized_outer = (*sch)->Blockize(loop.value()); + (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name); + if (!permuted_layout_annotate_value.empty()) { + (*sch)->Annotate(blockized_outer, "permuted_layout", permuted_layout_annotate_value); + } +} + +std::vector MultiLevelTilingTensorCoreHopperNode::WGMMAAddReadReuse(TensorCoreState state) const { + const ReuseConfig& config = this->reuse_read_; + if (config.req == ReuseType::kNoReuse) { + return {std::move(state)}; + } + ICHECK(config.req != ReuseType::kMayReuse); + const BlockRV& block_rv = state->block_rv; + std::vector results; + results.reserve(config.levels.size()); + for (int level : config.levels) { + State new_state = state->Copy(); + Schedule& sch = new_state->sch; + const LoopRV& loop_rv = state->tiles[level - 1].back(); + // Enumerate all buffers that are read but not written + std::vector read_buffer_ndims = tir::GetReadBufferNDims(sch->GetSRef(block_rv)); + for (int i = 0, n_reads = read_buffer_ndims.size(); i < n_reads; ++i) { + int buffer_ndim = read_buffer_ndims[i]; + if (buffer_ndim == -1) { + continue; + } + // Do cache_read + BlockRV cache_read_block = sch->ReadAt(loop_rv, block_rv, i, 
config.scope); + new_state->read_reuse.emplace(i, cache_read_block); + if (state->is_hopper) { + new_state->sch->Annotate(cache_read_block, "permuted_layout", + String(std::string("g2s_") + std::string(i == 0 ? "A" : "B"))); + } + } + results.push_back(std::move(new_state)); + } + return results; +} + +std::pair, Array> MultiLevelTilingTensorCoreHopperNode::WGMMASplitLoop( + const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles, int partition_pos, + int innerpart_factor) const { + Array factors = sch->SamplePartitionedTile( + /*loop=*/loop, + /*n=*/n_tiles, + /*partition_pos=*/partition_pos, + /*innerpart_factor=*/innerpart_factor); + Array splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); + return {factors, splits}; +} + +std::vector MultiLevelTilingTensorCoreHopperNode::WGMMATileLoopNest(TensorCoreState state) const { + Schedule& sch = state->sch; + const BlockRV& block_rv = state->block_rv; + // Step 1. Assuming trivial binding, pair the loops and their iter-var-types + Array loops = sch->GetLoops(block_rv); + if (!(loops.size() == 3 || !state->is_hopper)) { + LOG(DEBUG) << "The WGMMA tensor core only supports SSR loops now"; + return {}; + } + std::vector iter_types = GetBlockVarTypes(sch->GetSRef(state->block_rv)); + ICHECK_EQ(loops.size(), iter_types.size()); + // Step 2. For each loop axis, tile it + int64_t spatial_loop_product = 1; + std::vector> tiles(s_indices_.size() + r_indices_.size()); + state->tile_factors.resize(tiles.size()); + std::vector> tile_factors; + tile_factors.resize(tiles.size()); + for (int i = 0, n = loops.size(); i < n; ++i) { + LoopRV loop = loops[i]; + const std::vector* idx = nullptr; + + if (iter_types[i] == IterVarType::kDataPar) { + idx = &s_indices_; + if (spatial_loop_product != -1) { + if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) { + spatial_loop_product *= *extent; + } else { + spatial_loop_product = -1; + } + } + } else if (iter_types[i] == IterVarType::kCommReduce) { + idx = &r_indices_; + } else { + continue; + } + + const int n_tiles = idx->size(); + + if (n_tiles == 1) { + tiles[idx->at(0)].push_back(loop); + } else { + auto [factors, splits] = + iter_types[i] == IterVarType::kDataPar + ? WGMMASplitLoop( + sch, block_rv, loop, n_tiles, 3, + i == 0 ? 2 // 32 (load A intrin i shape) / 16 (sync intrin i shape) == 2 + : 4 // 32 (load B intrin j shape) / 8 (sync intrin j shape) == 4 + ) + : SplitLoop(sch, block_rv, loop, n_tiles); + + // Put every tile to its slot + for (int j = 0; j < n_tiles; ++j) { + tiles[idx->at(j)].push_back(splits[j]); + tile_factors[idx->at(j)].push_back(factors[j]); + } + } + } + state->tile_factors = std::move(tile_factors); + // Step 3. Reorder to organize the tiles + sch->Reorder(support::ConcatArrayList(tiles.begin(), tiles.end())); + // Step 4. 
Bind the tiles to threads + int n_binds = std::min(tile_binds.size(), tiles.size()); + for (int i = 0; i < n_binds; ++i) { + LoopRV fused = sch->Fuse(tiles[i]); + sch->Bind(fused, tile_binds[i]); + tiles[i] = {fused}; + } + state->tiles = Array>{tiles.begin(), tiles.end()}; + if (this->thread_warp_size_ != -1) { + int64_t low_inclusive = 1; + int64_t high_inclusive = this->max_threads_per_block_; + if (spatial_loop_product > 2 * this->thread_warp_size_) { + low_inclusive = this->thread_warp_size_; + } + sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_low_inclusive, + Integer(low_inclusive)); + sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_high_inclusive, + Integer(high_inclusive)); + } + return {state}; +} + +std::vector MultiLevelTilingTensorCoreHopperNode::TransformIntermediateOutputLayout( + TensorCoreState state) { + if (state->is_hopper) { + return {state}; + } + // Transform the intermediate output to packed layout + // [..., warp_m, warp_n, accum_frag_m, accum_frag_n, accum_elem_m, accum_elem_n] + // where warp_m, warp_n are thread indices bound to the warp id, accum_frag_m, accum_frag_n are + // the index of the fragments in each warp, accum_elem_m, accum_elem_n are the index of the + // elements in each accumulator fragment. + + // Get the shape of the wgmma accumulator + auto [frag_shape_m, frag_shape_n] = [&]() { + tir::Block intrin_block = + Downcast( + tir::TensorIntrin::Get(state->intrin_group.init_intrin).value()->desc->body) + ->block; + tir::For loop_m = Downcast(intrin_block->body); + tir::For loop_n = Downcast(loop_m->body); + return std::make_tuple(loop_m->extent, loop_n->extent); + }(); + + // Get the tile index of the warp id (i.e. threadIdx.y) + auto it = std::find(tile_binds.begin(), tile_binds.end(), "threadIdx.y"); + ICHECK(it != tile_binds.end()); + auto tile_index_warp_id = std::distance(tile_binds.begin(), it); + + // Get the extent of loop indicated by `loop_idx` inside the warp scope. + // For example, after spatial loops i, j are tiled, we will have + // tile_factors = ((i0, j0), (i1, j1), ..., (in, jn)) + // This function computes the product of tile_factors[i][loop_idx] for i > tile_index_warp_id. + // `loop_idx` can be negative, in which case it is counted from the end. + auto f_get_inner_tile_product = [&](int loop_idx) { + Array factors; + for (int i = tile_index_warp_id + 1; i < static_cast(s_indices_.size()); ++i) { + auto s_factors = state->tile_factors[s_indices_[i]]; + if (loop_idx < 0) { + loop_idx += s_factors.size(); + } + factors.push_back(s_factors[loop_idx]); + } + ICHECK(!factors.empty()); + if (factors.size() == 1) { + return factors[0]; + } + auto result = factors[0]; + for (int i = 1; i < static_cast(factors.size()); ++i) { + result = result * factors[i]; + } + return result; + }; + + // Compute the number of output fragment of each warp + auto warp_num_frag_m = f_get_inner_tile_product(-2); + auto warp_num_frag_n = f_get_inner_tile_product(-1); + + Schedule& sch = state->sch; + int buffer_ndim = static_cast(sch->Get(state->block_rv)->writes[0]->buffer->shape.size()); + // The dimension of the buffer should be larger or same as that of the tensor intrin. 
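  // The accumulator fragment is two-dimensional, so the output buffer needs at least two
  // trailing dimensions to transform.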
+ ICHECK_GE(buffer_ndim, 2); + int num_higher_dims = buffer_ndim - 2; + + auto index_map = + tir::IndexMap::FromFunc(buffer_ndim, + // frag_shape_m and frag_shape_n are structural bindings that cannot + // not be automatically captured until c++20 + [&, frag_shape_m = frag_shape_m, + frag_shape_n = frag_shape_n](const Array& indices) { + Array result; + result.reserve(indices.size() + 4); + for (int i = 0; i < num_higher_dims; ++i) { + result.push_back(indices[i]); + } + const auto& m = indices[num_higher_dims]; + const auto& n = indices[num_higher_dims + 1]; + auto accum_m = floormod(m, frag_shape_m); + auto accum_n = floormod(n, frag_shape_n); + auto outer_m = floordiv(m, frag_shape_m); + auto outer_n = floordiv(n, frag_shape_n); + + result.push_back(floordiv(outer_m, warp_num_frag_m)); + result.push_back(floordiv(outer_n, warp_num_frag_n)); + result.push_back(floormod(outer_m, warp_num_frag_m)); + result.push_back(floormod(outer_n, warp_num_frag_n)); + result.push_back(accum_m); + result.push_back(accum_n); + return result; + }); + sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, index_map, + /*pad_value=*/NullOpt, /*assume_injective_transform=*/true); + + return {state}; +} + +std::vector MultiLevelTilingTensorCoreHopperNode::AddWriteReuseTensorCore( + TensorCoreState state) const { + if (state->is_hopper) { + state->sch->WriteAt(state->tiles[2].back(), state->block_rv, 0, "m16n8k8.matrixC"); + state->sch->ReverseComputeInline(state->tensor_core_reindex_store); + return {state}; + } + // Add the cache write stage for Tensor Core + Schedule& sch = state->sch; + auto cache_write = sch->CacheWrite(state->block_rv, 0, "wgmma.accumulator"); + + // The compute block has been tiled by the warp shape and the fragment shape. + // We need to bind the cache write block (from the accumulator to the shared memory) to the warp + // id. The schedule is as follows: + // + // After adding cache write for wgmma.accumulator, we will have + // for i0, j0, i1, j1, accum_m, accum_n: + // shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, accum_m, accum_n] + // for i0', j0', i1', j1', accum_m', accum_n': + // global_mem[i0', j0', i1', j1', accum_m', accum_n'] = + // shared_mem[i0', j0', i1', j1', accum_m', accum_n'] + // where i0' and j0' are already bound to the block id and warp id. + // + // To reduce the shared memory usage and allow efficient data movement, we will apply + // transformations to generate the following schedule: + // + // for i1': + // for i0_j0 (fused and bound to threadIdx.y): + // for j1, accum_m, accum_n: + // shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, accum_m, accum_n] + // for i0', j0', j1', accum_m', accum_n': + // global_mem[i0', j0', i1', j1', accum_m', accum_n'] = + // shared_mem[i0', j0', i1', j1', accum_m', accum_n'] + // + // i1' is reordered to the outermost. This effectively allows only a row (i.e. loop i1') of the + // fragments are moved to the shared memory and then to the global memory each time. + // As a result, shared memory for the output will only have shape of [j1, accum_m, accum_n] + // instead of [i0 * i1 * accum_m, j0 * j1 * accum_n]. + + // Get the loops other than the innermost two loops (accum_m and accum_n). 
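  // The four loops returned here correspond to [i0, j0, i1, j1] in the pseudo code above.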
+ auto f_get_loops = [&](const BlockRV& block_rv) -> std::array { + Array buffer_loops = sch->GetLoops(block_rv); + ICHECK_GT(buffer_loops.size(), 6); + return {buffer_loops[buffer_loops.size() - 6], buffer_loops[buffer_loops.size() - 5], + buffer_loops[buffer_loops.size() - 4], buffer_loops[buffer_loops.size() - 3]}; + }; + { + const auto& [i0, j0, i1, j1] = f_get_loops(state->write_reuse[0]); + sch->Reorder({i1, i0, j0, j1}); + sch->ComputeAt(cache_write, i1, true); + } + { + auto loops = f_get_loops(cache_write); + const auto& i0 = loops[0]; + const auto& j0 = loops[1]; + auto fused = sch->Fuse({i0, j0}); + sch->Bind(fused, "threadIdx.y"); + } + + sch->ReverseComputeInline(state->tensor_core_reindex_store); + auto loops = sch->GetLoops(cache_write); + auto blockized_store = sch->Blockize(loops[loops.size() - 2]); + sch->Annotate(blockized_store, tir::attr::meta_schedule_auto_tensorize, + state->intrin_group.store_intrin); + + Array buffer_loops = sch->GetLoops(state->write_reuse[0]); + ICHECK_GT(buffer_loops.size(), 5); + sch->Fuse(Array{buffer_loops.end() - 5, // The src shmem is always 2D + buffer_loops.end()}); + AnnotateCooperativeFetching(&sch, state->write_reuse[0]); + return {state}; +} + +std::vector MultiLevelTilingTensorCoreHopperNode::AddReadReuseTensorCore( + TensorCoreState state) const { + const Array& r_tiles = state->tiles[r_indices_[1]]; + Schedule& sch = state->sch; + ICHECK(!r_tiles.empty()) << "ValueError: Cannot find the suitable reduction loop in the block"; + + auto f_tensorize_load = [&](int read_index, String scope, String intrin_name) { + auto cache_read = sch->CacheRead(state->block_rv, read_index, scope); + state->sch->ComputeAt(cache_read, r_tiles.back(), true); + String permuted_layout_annotate_value = + state->is_hopper ? std::string("s2l_") + std::string(read_index == 0 ? "A" : "B") : ""; + TileAndAnnotateTensorize(&sch, cache_read, intrin_name, permuted_layout_annotate_value); + }; + f_tensorize_load(0, state->is_hopper ? "m16n8k8.matrixA" : "wgmma.matrix_a", + state->intrin_group.load_a_intrin); + f_tensorize_load(1, state->is_hopper ? "m16n8k8.matrixB" : "wgmma.matrix_b", + state->intrin_group.load_b_intrin); + + for (int i = 0; i < 2; ++i) { + const tir::BlockRV cache_read = state->read_reuse.at(i); + // Inline the reindex / padding block + sch->ComputeInline(sch->GetProducers(cache_read)[0]); + const tir::BlockNode* cache_read_block = sch->GetSRef(cache_read)->StmtAs(); + tir::Buffer cache_read_buffer = tir::GetNthAccessBuffer( + sch->state(), GetRef(cache_read_block), 0, tir::BufferIndexType::kWrite); + const DataType& dtype = cache_read_buffer->dtype; + if (dtype.is_float16()) { + sch->StorageAlign(cache_read, 0, -2, 32, 8); + } else if (dtype.is_int() && dtype.bits() == 8) { + sch->StorageAlign(cache_read, 0, -2, 32, 16); + } else { + TVM_PY_LOG(WARNING, logger) << "StorageAlign is not applied for data type " << dtype + << ", shared memory accesses might be inefficient."; + } + } + return {state}; +} + +std::vector MultiLevelTilingTensorCoreHopperNode::AddSoftwarePipeline( + TensorCoreState state) const { + if (!use_software_pipeline) { + return {state}; + } + // The current config is not suitable for software pipelining. + if (r_indices_.size() < 2) { + return {state}; + } + + Schedule& sch = state->sch; + // Check reduction length after blockize. 
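  // If the blockized reduction extent is trivial (<= 1) there is nothing to overlap, and the
  // pipeline annotations below are skipped.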
+ int64_t reduction_length = 1; + for (int r_index : r_indices_) { + const Array& tiles = state->tiles[r_index]; + for (const LoopRV& tile : tiles) { + const auto* extent = sch->Get(tile)->extent.as(); + ICHECK(extent != nullptr) << "Dynamic extent is not supported."; + reduction_length *= extent->value; + } + } + if (reduction_length <= 1) { + return {state}; + } + + for (int i = 0; i < 2; ++i) { + const tir::BlockRV cache_read = state->read_reuse.at(i); + if (state->is_hopper) { + // Add vector bytes for memhammer + sch->Annotate(cache_read, tir::attr::vector_bytes, Integer(16)); + if (!state->use_async) { + sch->Annotate(cache_read, tir::attr::local_stage, Integer(1)); + sch->Annotate(cache_read, tir::attr::double_buffer_scope, Integer(0)); + } + } else { + // Add local stage and double buffering + sch->Annotate(cache_read, tir::attr::manifest_shared_memory_local_stage, Integer(1)); + sch->Annotate(cache_read, tir::attr::double_buffer_scope, Integer(0)); + } + } + + // Add annotations of software pipeline + // + // Before pipelining, the original loop can be expressed as the pseudo code below: + // + // for k0 in [0, K0): + // load tile k0 to registers + // load tile k0 from registers to shared memory + // + // for k1 in [0, K1): + // load fragment k1 of tile k0 + // compute matmul with fragment k1 + // + + // Inner software pipeline: Prefetch to tensor core fragment by one iteration + // The following annotation for the inner loop is equivalent the pesudo code below: + // + // Pipelined inner loop: + // + // prologue: + // load fragment 0 + // body: + // for k1 in [0, K1 - 1): + // load fragment k1 + 1 + // compute matmul with fragment k1 + // epilogue: + // compute matmul with fragment K1 - 1 + // + sch->Annotate(state->tiles[r_indices_[1]].back(), tir::attr::software_pipeline_stage, + Array{0, 0, 1}); + sch->Annotate(state->tiles[r_indices_[1]].back(), tir::attr::software_pipeline_order, + Array{0, 1, 2}); + if (state->is_hopper && state->use_async) { + sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_async_stages, + Array{0}); + sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_stage, + Array{0, 0, 1, 2, 2}); + sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_order, + Array{0, 1, 3, 2, 4}); + } else { + // Outer software pipeline: Interleave the outer loop with the (pipelined) inner loop. + // The prefetching stage of the inner pipeline is executed by one iteration in the outer loop. 
+ // The following annotation for the outer loop is equivalent the pesudo code below: + // + // Pipelined outer loop with nested inner pipeline: + // + // prologue: + // load tile 0 to registers + // load tile 0 from registers to shared memory + // + // // prologue of the inner pipeline + // load fragment 0 of tile 0 + // + // body: + // for k0 in [0, K0 - 1): + // load tile k0 + 1 to registers + // + // // body of the inner pipeline + // for k1 in [0, K1 - 1): + // load fragment k1 + 1 of tile k0 + // compute matmul with fragment k1 of tile k0 + // + // load tile k0 + 1 from registers to shared memory + // + // // prologue of the inner pipeline + // load fragment 0 of tile k0 + 1 + // + // // epilogue of the inner pipeline + // compute matmul with fragment K1 - 1 of tile k0 + // + // epilogue: + // + // // body of the inner pipeline + // for k1 in [0, K1 - 1): + // load fragment k1 + 1 of tile K0 - 1 + // compute matmul with fragment k1 of tile K0 - 1 + // + // // epilogue of the inner pipeline + // compute matmul with fragment K1 - 1 of tile K0 - 1 + // + sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_stage, + Array{0, 0, 0, 0, 0, 1, 1}); + sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_order, + Array{0, 3, 1, 4, 5, 2, 6}); + } + + return {state}; +} + +Optional MultiLevelTilingTensorCoreHopperNode::TransformWithTensorIntrin( + TensorCoreStateNode* state, const String& intrin_name) const { + BlockRV block_rv = state->block_rv; + const tir::AutoTensorizeMappingInfo& mapping_info = state->mapping_info; + tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv); + + // Add reindex stages + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + // Hold the reference of the block before reindex + const tir::Block block_before_reindex = GetRef(block); + if (block->reads.size() != 2 || block->writes.size() != 1) { + // only matmul-like computation is allowed + return NullOpt; + } + state->tensor_core_reindex_store = + state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kWrite); + state->tensor_core_reindex_A = + state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kRead); + state->tensor_core_reindex_B = + state->sch->ReIndex(state->block_rv, 1, tir::BufferIndexType::kRead); + + // Transform the layout of reindex buffers accordingly. + // The index map defines the mapping for the computation block. We need to extract the sub index + // map to transform the load and store block. + ICHECK_EQ(mapping_info->mappings.size(), 1U); // assume only one mapping is present + const tir::IndexMap& index_map = mapping_info->mappings[0]; + + // Find the correspondence between block iters and the iters in the index map. + std::unordered_map lhs_to_index_map_src; + std::unordered_map rhs_to_index_map_tgt; + std::unordered_set unmapped_index_map_src; + ICHECK_EQ(mapping_info->lhs_iters.size(), index_map->initial_indices.size()); + for (int i = 0; i < static_cast(mapping_info->lhs_iters.size()); ++i) { + lhs_to_index_map_src[mapping_info->lhs_iters[i]->var] = index_map->initial_indices[i]; + } + // The number of result iters in the index map is equal or more than the number of rhs (the + // tensor intrin) iters. When there are extra iters, these iters represent unmapped iters from + // the lhs. They will be skipped during pattern matching for tensorization. 
An example of such + // case is batch matmul, the batch dimension is kept after layout transformations and it will be + // kept as a outer loop after tensorization. + int offset = static_cast(index_map->final_indices.size()) - + static_cast(mapping_info->rhs_iters.size()); + ICHECK_GE(offset, 0); + for (int i = 0; i < offset; ++i) { + const tir::VarNode* var_ptr = index_map->final_indices[i].as(); + ICHECK(var_ptr != nullptr); + unmapped_index_map_src.insert(GetRef(var_ptr)); + } + for (int i = offset; i < static_cast(index_map->final_indices.size()); ++i) { + rhs_to_index_map_tgt[mapping_info->rhs_iters[i - offset]->var] = index_map->final_indices[i]; + } + + auto f_get_sub_index_map = [&](const tir::Buffer& lhs_buffer, const tir::Region& lhs_region) { + std::vector sub_index_map_src; + std::vector sub_index_map_tgt; + const tir::Buffer& rhs_buffer = mapping_info->lhs_buffer_map[lhs_buffer]; + for (const Range& range : lhs_region) { + ICHECK(tir::is_one(range->extent)); + const tir::VarNode* var_ptr = range->min.as(); + ICHECK(var_ptr != nullptr); + const tir::Var& lhs_representer = lhs_to_index_map_src[GetRef(var_ptr)]; + sub_index_map_src.push_back(lhs_representer); + if (unmapped_index_map_src.count(lhs_representer)) { + sub_index_map_tgt.push_back(lhs_representer); + } + } + for (size_t i = 0; i < mapping_info->rhs_buffer_indices[rhs_buffer].size(); ++i) { + const tir::VarNode* var = mapping_info->rhs_buffer_indices[rhs_buffer][i].as(); + ICHECK(var != nullptr); + sub_index_map_tgt.push_back(rhs_to_index_map_tgt[GetRef(var)]); + } + return tir::IndexMap(sub_index_map_src, sub_index_map_tgt); + }; + + std::unordered_set visited_buffers; + + Map buffer_sub_index_map; // cache of the sub index map associated + // with each buffer + + auto f_transform_buffer_layout = [&](tir::BufferIndexType index_type, int buffer_index) { + const tir::Buffer& lhs_buffer = tir::GetNthAccessBuffer( + state->sch->state(), block_before_reindex, buffer_index, index_type); + if (visited_buffers.count(lhs_buffer)) { + return; + } + visited_buffers.insert(lhs_buffer); + // Refresh block pointer (block sref is not invalidated) + block = TVM_SREF_TO_BLOCK(block_sref); + const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion( + state->sch->state(), GetRef(block), buffer_index, index_type); + auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region); + buffer_sub_index_map.Set(lhs_buffer, sub_index_map); + state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, + /*pad_value=*/NullOpt, /*assume_injective_transform=*/true); + }; + + for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) { + f_transform_buffer_layout(tir::BufferIndexType::kRead, i); + } + for (int i = 0, n = block_before_reindex->writes.size(); i < n; ++i) { + f_transform_buffer_layout(tir::BufferIndexType::kWrite, i); + } + + // Transform the layout of current block and reindex blocks + auto f_transform_reindex_block_layout = [&](const BlockRV& block_rv, + tir::BufferIndexType buffer_type) { + tir::Buffer buffer = + tir::GetNthAccessBuffer(state->sch->state(), state->sch->Get(block_rv), 0, buffer_type); + const auto& sub_index_map = buffer_sub_index_map.at(buffer); + state->sch->TransformBlockLayout(block_rv, sub_index_map); + }; + f_transform_reindex_block_layout(state->tensor_core_reindex_store, tir::BufferIndexType::kWrite); + f_transform_reindex_block_layout(state->tensor_core_reindex_A, tir::BufferIndexType::kRead); + 
f_transform_reindex_block_layout(state->tensor_core_reindex_B, tir::BufferIndexType::kRead);
+  state->sch->TransformBlockLayout(state->block_rv, index_map);
+  return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name,
+                                   /*allow_padding=*/true);
+}
+
+inline std::vector<State> MultiLevelTilingTensorCoreHopperNode::TransformForTensorization(
+    TensorCoreState state) const {
+  // Do reindex and layout transformations.
+  Optional<LoopRV> transformed_loop_rv =
+      TransformWithTensorIntrin(state.operator->(), state->intrin_group.compute_intrin);
+  if (!transformed_loop_rv.defined()) {
+    // The workload can't be tensorized.
+    return {};
+  }
+
+  // Do blockize
+  state->block_rv = state->sch->Blockize(transformed_loop_rv.value());
+
+  // Add annotations for post processors.
+  state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize,
+                       state->intrin_group.compute_intrin);
+  state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize_init,
+                       state->intrin_group.init_intrin);
+  state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Integer(1));
+  return {std::move(state)};
+}
+
+ScheduleRule ScheduleRule::MultiLevelTilingTensorCoreHopper(
+    Array<Map<String, String>> intrin_groups, String structure,
+    Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
+    Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
+    Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline) {
+  if (tile_binds.defined()) {
+    for (const String& tile_bind : tile_binds.value()) {
+      CHECK_NE(tile_bind, "threadIdx.x") << "Cannot bind to threadIdx.x when using tensor core.";
+    }
+  }
+  auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreHopperNode>(
+      structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
+
+  bool have_wmma_intrin_group = false;
+  node->intrin_groups.reserve(intrin_groups.size());
+  for (const auto& intrin_group_config : intrin_groups) {
+    TensorCoreIntrinGroup group = TensorCoreIntrinGroup::FromConfig(intrin_group_config);
+    if (support::StartsWith(group.compute_intrin, "wmma")) {
+      have_wmma_intrin_group = true;
+    }
+    node->intrin_groups.emplace_back(group);
+  }
+
+  if (have_wmma_intrin_group) {
+    CHECK(node->reuse_write_.req == ReuseType::kMustReuse &&
+          runtime::StorageScope::Create(node->reuse_write_.scope).rank ==
+              runtime::StorageRank::kShared)
+        << "ValueError: Shared memory write reuse must be enabled for "
+           "MultiLevelTilingTensorCoreHopper.";
+  }
+
+  node->use_software_pipeline = use_software_pipeline;
+  return ScheduleRule(node);
+}
+
+TVM_REGISTER_NODE_TYPE(MultiLevelTilingTensorCoreHopperNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingTensorCoreHopper")
+    .set_body_typed(ScheduleRule::MultiLevelTilingTensorCoreHopper);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/tests/python/meta_schedule/test_meta_schedule_mma_m64n8k16_auto_tensorization.py b/tests/python/meta_schedule/test_meta_schedule_mma_m64n8k16_auto_tensorization.py
new file mode 100644
index 000000000000..9a9471aef2ee
--- /dev/null
+++ b/tests/python/meta_schedule/test_meta_schedule_mma_m64n8k16_auto_tensorization.py
@@ -0,0 +1,1237 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for WGMMA m64n8k16 Auto Tensorization""" + +import tempfile +import numpy as np + +import tvm +from tvm import te +from tvm import meta_schedule as ms +from tvm._ffi import register_func +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, +) +from tvm.meta_schedule.builder import LocalBuilder +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule +from tvm.tir.schedule import Trace + +# get tensor intrin +from tvm.tir.tensor_intrin import cuda # pylint: disable=unused-import + +import tvm.testing + + +@I.ir_module +class MmaModule: + @T.prim_func + def main( + X: T.Buffer((4096, 4096), "float16"), + Y: T.Buffer((4096, 4096), "float16"), + C: T.Buffer((4096, 4096), "float16"), + ): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + # with T.block("root"): + C_reindex_m64n8k16_matrixC = T.alloc_buffer((4096, 4096), "float16", scope="m64n8k16.matrixC") + X_reindex_shared_dyn = T.alloc_buffer((4096, 4096), "float16", scope="shared.dyn") + Y_reindex_shared_dyn = T.alloc_buffer((4096, 4096), "float16", scope="shared.dyn") + X_reindex_shared_dyn_m64n8k16_matrixA = T.alloc_buffer( + (4096, 4096), "float16", scope="m64n8k16.matrixA" + ) + Y_reindex_shared_dyn_m64n8k16_matrixB = T.alloc_buffer( + (4096, 4096), "float16", scope="m64n8k16.matrixB" + ) + for ax0_0_0_ax1_0_0_fused in T.thread_binding(4, thread="blockIdx.x"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(256, thread="blockIdx.y"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(4, thread="threadIdx.y"): + for ax2_0_0 in T.serial( + 128, + annotations={ + "software_pipeline_async_stages": [0], + "software_pipeline_order": [0, 1, 3, 2, 4], + "software_pipeline_stage": [0, 0, 1, 2, 2], + }, + ): + with T.block("X_reindex_shared.dyn"): + v0, v1 = T.axis.remap("SS", [ax0_0_1_ax1_0_1_fused, ax2_0_0]) + T.reads(X[v0 // 8 * 128 : v0 // 8 * 128 + 128, v1 * 32 : v1 * 32 + 32]) + T.writes( + X_reindex_shared_dyn[ + v0 // 8 * 128 : v0 // 8 * 128 + 128, v1 * 32 : v1 * 32 + 32 + ] + ) + T.block_attr( + { + "auto_copy": 1, + "buffer_dim_align": [[0, 0, 32, 8]], + "permuted_layout": "g2s_A", + "vector_bytes": 16, + } + ) + for ax0, ax1 in T.grid(128, 32): + X_reindex_shared_dyn[v0 // 8 * 128 + ax0, v1 * 32 + ax1] = X[ + v0 // 8 * 128 + ax0, v1 * 32 + ax1 + ] + with T.block("Y_reindex_shared.dyn"): + v0, v1, v2 = T.axis.remap( + "SSS", [ax2_0_0, ax0_0_0_ax1_0_0_fused, ax0_0_1_ax1_0_1_fused] + ) + T.reads( + Y[ + v0 * 32 : v0 * 32 + 32, + v1 * 1024 + v2 % 8 * 128 : v1 * 1024 + v2 % 8 * 128 + 128, + ] + ) + T.writes( + Y_reindex_shared_dyn[ + v0 * 32 : v0 * 32 + 32, + v1 * 1024 + v2 % 8 * 128 : v1 * 1024 + v2 % 8 * 128 + 128, + ] + ) + T.block_attr( + { + "auto_copy": 1, + "buffer_dim_align": [[0, 0, 32, 8]], + "permuted_layout": "g2s_B", + "vector_bytes": 16, + } + ) + for ax0, ax1 in T.grid(32, 128): + Y_reindex_shared_dyn[ + v0 * 32 + ax0, v1 * 1024 + v2 % 8 * 128 + ax1 + ] = Y[v0 * 32 + ax0, v1 * 1024 + v2 % 8 * 128 + ax1] + for ax2_0_1 in T.serial( + 4, + annotations={ + 
"software_pipeline_order": [0, 1, 2], + "software_pipeline_stage": [0, 0, 1], + }, + ): + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("X_reindex_shared.dyn_m64n8k16.matrixA_o"): + v0_o = T.axis.spatial( + 128, + ax0_0_1_ax1_0_1_fused // 8 * 4 + + ax0_0_2_ax1_0_2_fused // 2 * 2 + + ax0_0, + ) + v1_o = T.axis.spatial(512, ax2_0_0 * 4 + ax2_0_1 + ax1_0) + T.reads( + X_reindex_shared_dyn[ + v0_o * 32 : v0_o * 32 + 32, v1_o * 8 : v1_o * 8 + 8 + ] + ) + T.writes( + X_reindex_shared_dyn_m64n8k16_matrixA[ + v0_o * 32 : v0_o * 32 + 32, v1_o * 8 : v1_o * 8 + 8 + ] + ) + T.block_attr( + { + "meta_schedule.auto_tensorize": "mma_load_m64n8k16_f16_A_shared_dyn", + "permuted_layout": "s2l_A", + } + ) + for ax0_1, ax1_1 in T.grid(32, 8): + with T.block("X_reindex_shared.dyn_m64n8k16.matrixA"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads( + X_reindex_shared_dyn[ + v0_o * 32 + v0_i, v1_o * 8 + v1_i + ] + ) + T.writes( + X_reindex_shared_dyn_m64n8k16_matrixA[ + v0_o * 32 + v0_i, v1_o * 8 + v1_i + ] + ) + X_reindex_shared_dyn_m64n8k16_matrixA[ + v0_o * 32 + v0_i, v1_o * 8 + v1_i + ] = X_reindex_shared_dyn[ + v0_o * 32 + v0_i, v1_o * 8 + v1_i + ] + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("Y_reindex_shared.dyn_m64n8k16.matrixB_o"): + v0_o = T.axis.spatial(512, ax2_0_0 * 4 + ax2_0_1 + ax0_0) + v1_o = T.axis.spatial( + 128, + ax0_0_0_ax1_0_0_fused * 32 + + ax0_0_1_ax1_0_1_fused % 8 * 4 + + ax0_0_2_ax1_0_2_fused % 2 * 2 + + ax1_0, + ) + T.reads( + Y_reindex_shared_dyn[ + v0_o * 8 : v0_o * 8 + 8, v1_o * 32 : v1_o * 32 + 32 + ] + ) + T.writes( + Y_reindex_shared_dyn_m64n8k16_matrixB[ + v0_o * 8 : v0_o * 8 + 8, v1_o * 32 : v1_o * 32 + 32 + ] + ) + T.block_attr( + { + "meta_schedule.auto_tensorize": "mma_load_m64n8k16_f16_B_shared_dyn", + "permuted_layout": "s2l_B", + } + ) + for ax0_1, ax1_1 in T.grid(8, 32): + with T.block("Y_reindex_shared.dyn_m64n8k16.matrixB"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads( + Y_reindex_shared_dyn[ + v0_o * 8 + v0_i, v1_o * 32 + v1_i + ] + ) + T.writes( + Y_reindex_shared_dyn_m64n8k16_matrixB[ + v0_o * 8 + v0_i, v1_o * 32 + v1_i + ] + ) + Y_reindex_shared_dyn_m64n8k16_matrixB[ + v0_o * 8 + v0_i, v1_o * 32 + v1_i + ] = Y_reindex_shared_dyn[ + v0_o * 8 + v0_i, v1_o * 32 + v1_i + ] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid( + 1, 1, 1, 4, 8 + ): + with T.block("C_o"): + v0_o = T.axis.spatial( + 256, + ax0_0_1_ax1_0_1_fused // 8 * 8 + + ax0_0_2_ax1_0_2_fused // 2 * 4 + + ax0_0_3 * 4 + + ax0_0_4, + ) + v1_o = T.axis.spatial( + 512, + ax0_0_0_ax1_0_0_fused * 128 + + ax0_0_1_ax1_0_1_fused % 8 * 16 + + ax0_0_2_ax1_0_2_fused % 2 * 8 + + ax1_0_3 * 8 + + ax1_0_4, + ) + v2_o = T.axis.reduce(512, ax2_0_0 * 4 + ax2_0_1 + ax2_0_2) + T.reads( + X_reindex_shared_dyn_m64n8k16_matrixA[ + v0_o * 16 : v0_o * 16 + 16, v2_o * 8 : v2_o * 8 + 8 + ], + Y_reindex_shared_dyn_m64n8k16_matrixB[ + v2_o * 8 : v2_o * 8 + 8, v1_o * 8 : v1_o * 8 + 8 + ], + ) + T.writes( + C_reindex_m64n8k16_matrixC[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 8 : v1_o * 8 + 8 + ] + ) + T.block_attr( + { + "meta_schedule.auto_tensorize": "mma_sync_m64n8k16_f16f16f16", + "meta_schedule.auto_tensorize_init": "mma_init_m64n8k16_f16", + "meta_schedule.thread_extent_high_inclusive": 1024, + "meta_schedule.thread_extent_low_inclusive": 32, + "warp_execution": 1, + } + ) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 8): + with T.block("C_init"): + v0_i_init, v1_i_init = T.axis.remap( + "SS", [ax0_1, ax1_1] + ) + T.reads() + T.writes( + C_reindex_m64n8k16_matrixC[ + v0_o * 16 + 
v0_i_init, v1_o * 8 + v1_i_init + ] + ) + C_reindex_m64n8k16_matrixC[ + v0_o * 16 + v0_i_init, v1_o * 8 + v1_i_init + ] = T.float16(0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 8, 8): + with T.block("C"): + v0_i, v1_i, v2_i = T.axis.remap( + "SSR", [ax0_1, ax1_1, ax2_1] + ) + T.reads( + C_reindex_m64n8k16_matrixC[ + v0_o * 16 + v0_i, v1_o * 8 + v1_i + ], + X_reindex_shared_dyn_m64n8k16_matrixA[ + v0_o * 16 + v0_i, v2_o * 8 + v2_i + ], + Y_reindex_shared_dyn_m64n8k16_matrixB[ + v2_o * 8 + v2_i, v1_o * 8 + v1_i + ], + ) + T.writes( + C_reindex_m64n8k16_matrixC[ + v0_o * 16 + v0_i, v1_o * 8 + v1_i + ] + ) + T.block_attr( + {"meta_schedule.tiling_structure": "SSSRRSRS"} + ) + C_reindex_m64n8k16_matrixC[ + v0_o * 16 + v0_i, v1_o * 8 + v1_i + ] = ( + C_reindex_m64n8k16_matrixC[ + v0_o * 16 + v0_i, v1_o * 8 + v1_i + ] + + X_reindex_shared_dyn_m64n8k16_matrixA[ + v0_o * 16 + v0_i, v2_o * 8 + v2_i + ] + * Y_reindex_shared_dyn_m64n8k16_matrixB[ + v2_o * 8 + v2_i, v1_o * 8 + v1_i + ] + ) + with T.block("C_reindex_m64n8k16.matrixC"): + v0, v1, v2 = T.axis.remap( + "SSS", + [ax0_0_1_ax1_0_1_fused, ax0_0_2_ax1_0_2_fused, ax0_0_0_ax1_0_0_fused], + ) + T.reads( + C_reindex_m64n8k16_matrixC[ + v0 // 8 * 128 + v1 // 2 * 64 : v0 // 8 * 128 + v1 // 2 * 64 + 64, + v2 * 1024 + + v0 % 8 * 128 + + v1 % 2 * 64 : v2 * 1024 + + v0 % 8 * 128 + + v1 % 2 * 64 + + 64, + ] + ) + T.writes( + C[ + v0 // 8 * 128 + v1 // 2 * 64 : v0 // 8 * 128 + v1 // 2 * 64 + 64, + v2 * 1024 + + v0 % 8 * 128 + + v1 % 2 * 64 : v2 * 1024 + + v0 % 8 * 128 + + v1 % 2 * 64 + + 64, + ] + ) + T.block_attr({"auto_copy": 1}) + for ax0, ax1 in T.grid(64, 64): + C[ + v0 // 8 * 128 + v1 // 2 * 64 + ax0, + v2 * 1024 + v0 % 8 * 128 + v1 % 2 * 64 + ax1, + ] = C_reindex_m64n8k16_matrixC[ + v0 // 8 * 128 + v1 // 2 * 64 + ax0, + v2 * 1024 + v0 % 8 * 128 + v1 % 2 * 64 + ax1, + ] + + +def matmul_fp16(N: int, M: int, K: int, out_dtype: str): + x = te.placeholder((N, K), name="X", dtype="float16") + y = te.placeholder((K, M), name="Y", dtype="float16") + k = te.reduce_axis((0, K), name="k") + c = te.compute( # pylint: disable=invalid-name + (N, M), + lambda i, j: te.sum(x[i][k].astype(out_dtype) * y[k][j].astype(out_dtype), axis=[k]), + name="C", + ) + return (x, y, c) + + +def multi_level_tiling_mma(out_dtype): + simplify_dict = {"float32": "f32", "float16": "f16"} + out_dtype = simplify_dict[out_dtype] + return ms.schedule_rule.MultiLevelTilingTensorCore( + intrin_groups=[ + { + "init": f"mma_init_m64n8k16_{out_dtype}", + "load_a": "mma_load_m64n8k16_f16_A_shared_dyn", + "load_b": "mma_load_m64n8k16_f16_B_shared_dyn", + "compute": f"mma_sync_m64n8k16_f16f16{out_dtype}", + "store": f"mma_store_m64n8k16_{out_dtype}_global", + }, + ], + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"], + max_innermost_factor=4, # 64 // tensor intrin size + vector_load_lens=[1, 2, 3, 4, 8, 16], + reuse_read=ms.schedule_rule.ReuseType( + req="must", + levels=[4], + scope="shared.dyn", + ), + reuse_write=ms.schedule_rule.ReuseType( + req="no", + levels=[2], + scope="shared.dyn", + ), + use_software_pipeline=True, + ) + + +def _design_space(mod, out_dtype): + return generate_design_space( + kind="cuda-tensorcore", + mod=mod, + target=Target("nvidia/h100"), + types=None, + sch_rules=[multi_level_tiling_mma(out_dtype)], + ) + + +gemm_decision = [ + ("SamplePartitionedTile", [1, 32, 2, 1, 4]), + ("SamplePartitionedTile", [4, 8, 2, 1, 8]), + ("SamplePerfectTile", [128, 4, 1]), +] + + +def test_mma_auto_tensorization(): + mod = 
te.create_prim_func(matmul_fp16(M=4096, N=4096, K=4096, out_dtype="float16")) + actual = _design_space(mod, "float16") + check_sketches( + mod, + sketches=actual, + expected_mods=[MmaModule], + expected_decisions=[gemm_decision], + ) + + +expected_cuda_script = r"""#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) +#include +__device__ half max(half a, half b) +{ + return __hgt(__half(a), __half(b)) ? a : b; +} +__device__ half min(half a, half b) +{ + return __hlt(__half(a), __half(b)) ? a : b; +} +#else + +typedef unsigned short uint16_t; +typedef unsigned char uint8_t; +typedef signed char int8_t; +typedef int int32_t; +typedef unsigned long long uint64_t; +typedef unsigned int uint32_t; + +#define TVM_FORCE_INLINE inline __attribute__((always_inline)) +#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__ +#define TVM_ALIGNED(x) __attribute__ ((aligned(x))) +#define TVM_HALF_OPERATOR(RTYPE, OP) \ + TVM_XINLINE RTYPE operator OP (half a, half b) { \ + return RTYPE(float(a) OP float(b)); \ + } \ + template \ + TVM_XINLINE RTYPE operator OP (half a, T b) { \ + return RTYPE(float(a) OP float(b)); \ + } \ + template \ + TVM_XINLINE RTYPE operator OP (T a, half b) { \ + return RTYPE(float(a) OP float(b)); \ + } + +#define TVM_HALF_ASSIGNOP(AOP, OP) \ + template \ + TVM_XINLINE half operator AOP (const T& a) { \ + return *this = half(float(*this) OP float(a)); \ + } \ + template \ + TVM_XINLINE half operator AOP (const volatile T& a) volatile { \ + return *this = half(float(*this) OP float(a)); \ + } + +class TVM_ALIGNED(2) half { + public: + uint16_t half_; + + static TVM_XINLINE half Binary(uint16_t value) { + half res; + res.half_ = value; + return res; + } + + TVM_XINLINE half() {} + + TVM_XINLINE half(const float& value) { constructor(value); } + TVM_XINLINE explicit half(const double& value) { constructor(value); } + TVM_XINLINE explicit half(const int8_t& value) { constructor(value); } + TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); } + TVM_XINLINE explicit half(const int32_t& value) { constructor(value); } + TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); } + TVM_XINLINE explicit half(const long long& value) { constructor(value); } + TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); } + + TVM_XINLINE operator float() const { \ + return float(half2float(half_)); \ + } \ + TVM_XINLINE operator float() const volatile { \ + return float(half2float(half_)); \ + } + + + TVM_HALF_ASSIGNOP(+=, +) + TVM_HALF_ASSIGNOP(-=, -) + TVM_HALF_ASSIGNOP(*=, *) + TVM_HALF_ASSIGNOP(/=, /) + + TVM_XINLINE half operator+() { + return *this; + } + + TVM_XINLINE half operator-() { + return half(-float(*this)); + } + + TVM_XINLINE half operator=(const half& a) { + half_ = a.half_; + return a; + } + + template + TVM_XINLINE half operator=(const T& a) { + return *this = half(a); + } + + TVM_XINLINE half operator=(const half& a) volatile { + half_ = a.half_; + return a; + } + + template + TVM_XINLINE half operator=(const T& a) volatile { + return *this = half(a); + } + + private: + union Bits { + float f; + int32_t si; + uint32_t ui; + }; + + static int const fp16FractionBits = 10; + static int const fp32FractionBits = 23; + static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff + static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000 + static int const shift = fp32FractionBits - fp16FractionBits; // == 13 + static int const shiftSign = 16; + static int32_t const expAdjust = 127 - 15; // 
exp32-127 = exp16-15, so exp16 = exp32 - (127-15) + + static int32_t const infN = 0x7F800000; // flt32 infinity + static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift + static int32_t const minN = 0x38800000; // min flt16 normal as a flt32 + static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16 + static int32_t const signN = 0x80000000; // flt32 sign bit + + static int32_t const infC = infN >> shift; + static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32 + static int32_t const maxC = maxN >> shift; + static int32_t const minC = minN >> shift; + static int32_t const signC = signN >> shiftSign; // flt16 sign bit + + static int32_t const mulN = 0x52000000; // (1 << 23) / minN + static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift)) + + static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted + static int32_t const norC = 0x00400; // min flt32 normal down shifted + + static int32_t const maxD = infC - maxC - 1; + static int32_t const minD = minC - subC - 1; + + TVM_XINLINE uint16_t float2half(const float& value) const { + Bits v; + v.f = value; + uint32_t sign = v.si & signN; // grab sign bit + v.si ^= sign; // clear sign bit from v + sign >>= shiftSign; // logical shift sign to fp16 position + + if (v.si <= maxZ) { + // Handle eventual zeros here to ensure + // vshift will not exceed 32 below. + v.ui = 0; + } else if (v.si < minN) { + // Handle denorms + uint32_t exp32 = v.ui >> fp32FractionBits; + int32_t exp16 = exp32 - expAdjust; + // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1. + // Smaller (so negative) exp16 values should result in greater right shifts. + uint32_t vshift = 1 - exp16; + uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask); + v.ui = significand >> vshift; + v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0; + } else if (v.si <= maxN) { + // Handle norms + v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0; + v.ui -= expAdjust << fp32FractionBits; + } else if (v.si <= infN) { + v.si = infN; + } else if (v.si < nanN) { + v.si = nanN; + } + + v.ui >>= shift; + return sign | (v.ui & 0x7fff); + } + + // Same as above routine, except for addition of volatile keyword + TVM_XINLINE uint16_t float2half( + const volatile float& value) const volatile { + Bits v; + v.f = value; + uint32_t sign = v.si & signN; // grab sign bit + v.si ^= sign; // clear sign bit from v + sign >>= shiftSign; // logical shift sign to fp16 position + + if (v.si <= maxZ) { + // Handle eventual zeros here to ensure + // vshift will not exceed 32 below. + v.ui = 0; + } else if (v.si < minN) { + // Handle denorms + uint32_t exp32 = v.ui >> fp32FractionBits; + int32_t exp16 = exp32 - expAdjust; + // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1. + // Smaller (so negative) exp16 values should result in greater right shifts. + uint32_t vshift = 1 - exp16; + uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask); + v.ui = significand >> vshift; + v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0; + } else if (v.si <= maxN) { + // Handle norms + v.ui += (v.ui & 0x3fff) != 0x1000 ? 
0x1000 : 0; + v.ui -= expAdjust << fp32FractionBits; + } else if (v.si <= infN) { + v.si = infN; + } else if (v.si < nanN) { + v.si = nanN; + } + + v.ui >>= shift; + return sign | (v.ui & 0x7fff); + } + + TVM_XINLINE float half2float(const uint16_t& value) const { + Bits v; + v.ui = value; + int32_t sign = v.si & signC; + v.si ^= sign; + sign <<= shiftSign; + v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); + v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); + Bits s; + s.si = mulC; + s.f *= v.si; + int32_t mask = -(norC > v.si); + v.si <<= shift; + v.si ^= (s.si ^ v.si) & mask; + v.si |= sign; + return v.f; + } + + TVM_XINLINE float half2float( + const volatile uint16_t& value) const volatile { + Bits v; + v.ui = value; + int32_t sign = v.si & signC; + v.si ^= sign; + sign <<= shiftSign; + v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); + v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); + Bits s; + s.si = mulC; + s.f *= v.si; + int32_t mask = -(norC > v.si); + v.si <<= shift; + v.si ^= (s.si ^ v.si) & mask; + v.si |= sign; + return v.f; + } + + template <typename T> + TVM_XINLINE void constructor(const T& value) { + half_ = float2half(float(value)); + } +}; + +TVM_HALF_OPERATOR(half, +) +TVM_HALF_OPERATOR(half, -) +TVM_HALF_OPERATOR(half, *) +TVM_HALF_OPERATOR(half, /) +TVM_HALF_OPERATOR(bool, >) +TVM_HALF_OPERATOR(bool, <) +TVM_HALF_OPERATOR(bool, >=) +TVM_HALF_OPERATOR(bool, <=) + +TVM_XINLINE half __float2half_rn(const float a) { + return half(a); +} +#endif + + +// Pack two half values. +static inline __device__ __host__ unsigned +__pack_half2(const half x, const half y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \ +static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) { \ + float tmp_x = __half2float(x); \ + float tmp_y = __half2float(y); \ + float result = FP32_MATH_NAME(tmp_x, tmp_y); \ + return __float2half(result); \ +} + +#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \ +static inline __device__ __host__ half HALF_MATH_NAME(half x) { \ + float tmp_x = __half2float(x); \ + float result = FP32_MATH_NAME(tmp_x); \ + return __float2half(result); \ +} + +// Some fp16 math functions are not supported in cuda_fp16.h, +// so we define them here to make sure the generated CUDA code +// is valid. 
+#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 530) +CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf) +CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf) +CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf) +CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf) +CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf) +#else +CUDA_UNSUPPORTED_HALF_MATH_UNARY(hexp, exp) +#endif +#endif + +#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY +#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY +__forceinline__ __device__ unsigned int +cast_smem_ptr_to_int(const void* const smem_ptr) +{ + unsigned int smem_int; + asm volatile ("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; cvt.u32.u64 %0, smem_int; }" + : "=r"(smem_int) : "l"(smem_ptr)); + return smem_int; +} + +#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ + (__CUDACC_VER_MAJOR__ > 11)) +#define TVM_ENABLE_L2_PREFETCH 1 +#else +#define TVM_ENABLE_L2_PREFETCH 0 +#endif + +#ifdef _WIN32 + using uint = unsigned int; + using uchar = unsigned char; + using ushort = unsigned short; + using int64_t = long long; + using uint64_t = unsigned long long; +#else + #define uint unsigned int + #define uchar unsigned char + #define ushort unsigned short + #define int64_t long long + #define uint64_t unsigned long long +#endif +extern "C" __global__ void __launch_bounds__(128) main_kernel(half* __restrict__ C, half* __restrict__ X, half* __restrict__ Y); +extern "C" __global__ void __launch_bounds__(128) main_kernel(half* __restrict__ C, half* __restrict__ X, half* __restrict__ Y) { + extern __shared__ uchar buf_dyn_shmem[]; + uint1 C_reindex_m64n8k16_matrixC[64]; + half X_reindex_shared_dyn_m64n8k16_matrixA[32]; + half Y_reindex_shared_dyn_m64n8k16_matrixB[32]; + for (int ax0_0_4_init = 0; ax0_0_4_init < 4; ++ax0_0_4_init) { + for (int ax1_0_4_init = 0; ax1_0_4_init < 8; ++ax1_0_4_init) { + for (int b = 0; b < 2; ++b) { + C_reindex_m64n8k16_matrixC[(((ax0_0_4_init * 16) + (ax1_0_4_init * 2)) + b)] = make_uint1(__pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f))); + } + } + } + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) { + + { + unsigned int addr = cast_smem_ptr_to_int(buf_dyn_shmem + ((((ax0_ax1_fused_0 * 2048) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 16))); + __asm__ __volatile__( + #if TVM_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.cg.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(X + ((((((((int)blockIdx.y) >> 3) * 524288) + (ax0_ax1_fused_0 * 131072)) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 2) * 4096)) + ((((int)threadIdx.x) & 3) * 8)))), "n"(16) + ); + } + } + for (int ax0_ax1_fused_0_1 = 0; ax0_ax1_fused_0_1 < 4; ++ax0_ax1_fused_0_1) { + + { + unsigned int addr = cast_smem_ptr_to_int(buf_dyn_shmem + (((((ax0_ax1_fused_0_1 * 2048) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 4) * 256)) + (((((int)threadIdx.x) & 15) ^ ((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4))) * 16)) + 24576)); + __asm__ __volatile__( + #if TVM_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.cg.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(Y + ((((((ax0_ax1_fused_0_1 * 32768) + (((int)threadIdx.y) * 8192)) + ((((int)threadIdx.x) >> 4) * 4096)) + (((int)blockIdx.x) * 1024)) + ((((int)blockIdx.y) & 7) * 128)) + ((((int)threadIdx.x) & 15) * 8)))), "n"(16) + ); + } + } +__asm__ 
__volatile__("cp.async.commit_group;"); + + for (int ax0_ax1_fused_0_2 = 0; ax0_ax1_fused_0_2 < 4; ++ax0_ax1_fused_0_2) { + + { + unsigned int addr = cast_smem_ptr_to_int(buf_dyn_shmem + (((((ax0_ax1_fused_0_2 * 2048) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 16)) + 8192)); + __asm__ __volatile__( + #if TVM_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.cg.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(X + (((((((((int)blockIdx.y) >> 3) * 524288) + (ax0_ax1_fused_0_2 * 131072)) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 2) * 4096)) + ((((int)threadIdx.x) & 3) * 8)) + 32))), "n"(16) + ); + } + } + for (int ax0_ax1_fused_0_3 = 0; ax0_ax1_fused_0_3 < 4; ++ax0_ax1_fused_0_3) { + + { + unsigned int addr = cast_smem_ptr_to_int(buf_dyn_shmem + (((((ax0_ax1_fused_0_3 * 2048) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 4) * 256)) + (((((int)threadIdx.x) & 15) ^ ((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4))) * 16)) + 32768)); + __asm__ __volatile__( + #if TVM_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.cg.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(Y + (((((((ax0_ax1_fused_0_3 * 32768) + (((int)threadIdx.y) * 8192)) + ((((int)threadIdx.x) >> 4) * 4096)) + (((int)blockIdx.x) * 1024)) + ((((int)blockIdx.y) & 7) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 131072))), "n"(16) + ); + } + } +__asm__ __volatile__("cp.async.commit_group;"); + +__asm__ __volatile__("cp.async.wait_group 1;"); + + __syncthreads(); + for (int ax0_0 = 0; ax0_0 < 2; ++ax0_0) { + + { + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[(((((((int)threadIdx.y) >> 1) * 2048) + (ax0_0 * 1024)) + (((int)threadIdx.x) * 32)) + ((0 ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (ax0_0 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) { + + { + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((((int)threadIdx.x) & 7) * 128) + ((((((((int)threadIdx.y) & 1) * 8) + (ax1_0 * 4)) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 12288)])) + 0); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + (ax1_0 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int ax2_0_0 = 0; ax2_0_0 < 126; ++ax2_0_0) { + __syncthreads(); + for (int ax0_ax1_fused_0_4 = 0; ax0_ax1_fused_0_4 < 4; ++ax0_ax1_fused_0_4) { + + { + unsigned int addr = cast_smem_ptr_to_int(buf_dyn_shmem + (((((((ax2_0_0 + 2) % 3) * 8192) + (ax0_ax1_fused_0_4 * 2048)) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 16))); + __asm__ __volatile__( + 
#if TVM_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.cg.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(X + ((((((((((int)blockIdx.y) >> 3) * 524288) + (ax0_ax1_fused_0_4 * 131072)) + (((int)threadIdx.y) * 32768)) + ((((int)threadIdx.x) >> 2) * 4096)) + (ax2_0_0 * 32)) + ((((int)threadIdx.x) & 3) * 8)) + 64))), "n"(16) + ); + } + } + for (int ax0_ax1_fused_0_5 = 0; ax0_ax1_fused_0_5 < 4; ++ax0_ax1_fused_0_5) { + + { + unsigned int addr = cast_smem_ptr_to_int(buf_dyn_shmem + ((((((((ax2_0_0 + 2) % 3) * 8192) + (ax0_ax1_fused_0_5 * 2048)) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 4) * 256)) + (((((int)threadIdx.x) & 15) ^ ((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4))) * 16)) + 24576)); + __asm__ __volatile__( + #if TVM_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.cg.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(Y + ((((((((ax2_0_0 * 131072) + (ax0_ax1_fused_0_5 * 32768)) + (((int)threadIdx.y) * 8192)) + ((((int)threadIdx.x) >> 4) * 4096)) + (((int)blockIdx.x) * 1024)) + ((((int)blockIdx.y) & 7) * 128)) + ((((int)threadIdx.x) & 15) * 8)) + 262144))), "n"(16) + ); + } + } +__asm__ __volatile__("cp.async.commit_group;"); + +__asm__ __volatile__("cp.async.wait_group 1;"); + + __syncthreads(); + for (int ax2_0_1 = 0; ax2_0_1 < 3; ++ax2_0_1) { + for (int ax0_0_1 = 0; ax0_0_1 < 2; ++ax0_0_1) { + + { + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((((ax2_0_0 % 3) * 4096) + ((((int)threadIdx.y) >> 1) * 2048)) + (ax0_0_1 * 1024)) + (((int)threadIdx.x) * 32)) + (((ax2_0_1 + 1) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((((ax2_0_1 + 1) & 1) * 16) + (ax0_0_1 * 8))))[0]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((((ax2_0_1 + 1) & 1) * 16) + (ax0_0_1 * 8))))[1]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((((ax2_0_1 + 1) & 1) * 16) + (ax0_0_1 * 8))))[2]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((((ax2_0_1 + 1) & 1) * 16) + (ax0_0_1 * 8))))[3]) + : "r"(addr) + ); + } + } + for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) { + + { + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((((ax2_0_0 % 3) * 4096) + (ax2_0_1 * 1024)) + ((((int)threadIdx.x) & 7) * 128)) + ((((((((int)threadIdx.y) & 1) * 8) + (ax1_0_1 * 4)) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 13312)])) + 0); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((((ax2_0_1 + 1) & 1) * 16) + (ax1_0_1 * 8))))[0]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((((ax2_0_1 + 1) & 1) * 16) + (ax1_0_1 * 8))))[1]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((((ax2_0_1 + 1) & 1) * 16) + (ax1_0_1 * 8))))[2]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((((ax2_0_1 + 1) & 1) * 16) + (ax1_0_1 * 8))))[3]) + : "r"(addr) + ); + } + } + for (int ax0_0_4 = 0; ax0_0_4 < 4; ++ax0_0_4) { + for (int ax1_0_4 = 0; ax1_0_4 < 8; ++ax1_0_4) { + + { + __asm__ __volatile__( + "mma.sync.aligned.m64n8k16.row.col.f16.f16.f16.f16" + "{%0, %1}, {%2, %3}, {%4}, {%5, %6};\n" + : "=r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4 * 16) + (ax1_0_4 * 2))))[0]), "=r"(((unsigned 
*)(C_reindex_m64n8k16_matrixC + ((ax0_0_4 * 16) + (ax1_0_4 * 2))))[1]) + : "r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (((ax2_0_1 & 1) * 16) + (ax0_0_4 * 4))))[0]), "r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (((ax2_0_1 & 1) * 16) + (ax0_0_4 * 4))))[1]), "r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + (((ax2_0_1 & 1) * 16) + (ax1_0_4 * 2))))[0]), "r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4 * 16) + (ax1_0_4 * 2))))[0]), "r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4 * 16) + (ax1_0_4 * 2))))[1])); + } + } + } + } + for (int ax0_0_2 = 0; ax0_0_2 < 2; ++ax0_0_2) { + + { + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[(((((((ax2_0_0 + 1) % 3) * 4096) + ((((int)threadIdx.y) >> 1) * 2048)) + (ax0_0_2 * 1024)) + (((int)threadIdx.x) * 32)) + ((0 ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (ax0_0_2 * 8)))[0]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (ax0_0_2 * 8)))[1]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (ax0_0_2 * 8)))[2]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (ax0_0_2 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int ax1_0_2 = 0; ax1_0_2 < 2; ++ax1_0_2) { + + { + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((((ax2_0_0 + 1) % 3) * 4096) + ((((int)threadIdx.x) & 7) * 128)) + ((((((((int)threadIdx.y) & 1) * 8) + (ax1_0_2 * 4)) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 12288)])) + 0); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + (ax1_0_2 * 8)))[0]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + (ax1_0_2 * 8)))[1]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + (ax1_0_2 * 8)))[2]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + (ax1_0_2 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int ax0_0_4_1 = 0; ax0_0_4_1 < 4; ++ax0_0_4_1) { + for (int ax1_0_4_1 = 0; ax1_0_4_1 < 8; ++ax1_0_4_1) { + + { + __asm__ __volatile__( + "mma.sync.aligned.m64n8k16.row.col.f16.f16.f16.f16" + "{%0, %1}, {%2, %3}, {%4}, {%5, %6};\n" + : "=r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_1 * 16) + (ax1_0_4_1 * 2))))[0]), "=r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_1 * 16) + (ax1_0_4_1 * 2))))[1]) + : "r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((ax0_0_4_1 * 4) + 16)))[0]), "r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((ax0_0_4_1 * 4) + 16)))[1]), "r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((ax1_0_4_1 * 2) + 16)))[0]), "r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_1 * 16) + (ax1_0_4_1 * 2))))[0]), "r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_1 * 16) + (ax1_0_4_1 * 2))))[1])); + } + } + } + } +__asm__ __volatile__("cp.async.wait_group 0;"); + + __syncthreads(); + for (int ax2_0_1_1 = 0; ax2_0_1_1 < 3; ++ax2_0_1_1) { + for (int ax0_0_3 = 0; ax0_0_3 < 2; ++ax0_0_3) { + + { + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[(((((((int)threadIdx.y) >> 1) * 2048) + (ax0_0_3 * 1024)) + (((int)threadIdx.x) * 32)) + (((ax2_0_1_1 + 1) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned 
*)(X_reindex_shared_dyn_m64n8k16_matrixA + ((((ax2_0_1_1 + 1) & 1) * 16) + (ax0_0_3 * 8))))[0]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((((ax2_0_1_1 + 1) & 1) * 16) + (ax0_0_3 * 8))))[1]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((((ax2_0_1_1 + 1) & 1) * 16) + (ax0_0_3 * 8))))[2]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((((ax2_0_1_1 + 1) & 1) * 16) + (ax0_0_3 * 8))))[3]) + : "r"(addr) + ); + } + } + for (int ax1_0_3 = 0; ax1_0_3 < 2; ++ax1_0_3) { + + { + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((ax2_0_1_1 * 1024) + ((((int)threadIdx.x) & 7) * 128)) + ((((((((int)threadIdx.y) & 1) * 8) + (ax1_0_3 * 4)) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 13312)])) + 0); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((((ax2_0_1_1 + 1) & 1) * 16) + (ax1_0_3 * 8))))[0]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((((ax2_0_1_1 + 1) & 1) * 16) + (ax1_0_3 * 8))))[1]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((((ax2_0_1_1 + 1) & 1) * 16) + (ax1_0_3 * 8))))[2]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((((ax2_0_1_1 + 1) & 1) * 16) + (ax1_0_3 * 8))))[3]) + : "r"(addr) + ); + } + } + for (int ax0_0_4_2 = 0; ax0_0_4_2 < 4; ++ax0_0_4_2) { + for (int ax1_0_4_2 = 0; ax1_0_4_2 < 8; ++ax1_0_4_2) { + + { + __asm__ __volatile__( + "mma.sync.aligned.m64n8k16.row.col.f16.f16.f16.f16" + "{%0, %1}, {%2, %3}, {%4}, {%5, %6};\n" + : "=r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_2 * 16) + (ax1_0_4_2 * 2))))[0]), "=r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_2 * 16) + (ax1_0_4_2 * 2))))[1]) + : "r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (((ax2_0_1_1 & 1) * 16) + (ax0_0_4_2 * 4))))[0]), "r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (((ax2_0_1_1 & 1) * 16) + (ax0_0_4_2 * 4))))[1]), "r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + (((ax2_0_1_1 & 1) * 16) + (ax1_0_4_2 * 2))))[0]), "r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_2 * 16) + (ax1_0_4_2 * 2))))[0]), "r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_2 * 16) + (ax1_0_4_2 * 2))))[1])); + } + } + } + } + for (int ax0_0_5 = 0; ax0_0_5 < 2; ++ax0_0_5) { + + { + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((((((int)threadIdx.y) >> 1) * 2048) + (ax0_0_5 * 1024)) + (((int)threadIdx.x) * 32)) + ((0 ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)) + 4096)])) + 0); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (ax0_0_5 * 8)))[0]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (ax0_0_5 * 8)))[1]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (ax0_0_5 * 8)))[2]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (ax0_0_5 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int ax1_0_5 = 0; ax1_0_5 < 2; ++ax1_0_5) { + + { + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((((int)threadIdx.x) & 7) * 128) + ((((((((int)threadIdx.y) & 1) * 8) + (ax1_0_5 * 4)) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 16384)])) + 0); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + (ax1_0_5 * 8)))[0]), "=r"(((unsigned 
*)(Y_reindex_shared_dyn_m64n8k16_matrixB + (ax1_0_5 * 8)))[1]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + (ax1_0_5 * 8)))[2]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + (ax1_0_5 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int ax0_0_4_3 = 0; ax0_0_4_3 < 4; ++ax0_0_4_3) { + for (int ax1_0_4_3 = 0; ax1_0_4_3 < 8; ++ax1_0_4_3) { + + { + __asm__ __volatile__( + "mma.sync.aligned.m64n8k16.row.col.f16.f16.f16.f16" + "{%0, %1}, {%2, %3}, {%4}, {%5, %6};\n" + : "=r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_3 * 16) + (ax1_0_4_3 * 2))))[0]), "=r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_3 * 16) + (ax1_0_4_3 * 2))))[1]) + : "r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((ax0_0_4_3 * 4) + 16)))[0]), "r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((ax0_0_4_3 * 4) + 16)))[1]), "r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((ax1_0_4_3 * 2) + 16)))[0]), "r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_3 * 16) + (ax1_0_4_3 * 2))))[0]), "r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_3 * 16) + (ax1_0_4_3 * 2))))[1])); + } + } + } + for (int ax2_0_1_2 = 0; ax2_0_1_2 < 3; ++ax2_0_1_2) { + for (int ax0_0_6 = 0; ax0_0_6 < 2; ++ax0_0_6) { + + { + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((((((int)threadIdx.y) >> 1) * 2048) + (ax0_0_6 * 1024)) + (((int)threadIdx.x) * 32)) + (((ax2_0_1_2 + 1) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)) + 4096)])) + 0); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((((ax2_0_1_2 + 1) & 1) * 16) + (ax0_0_6 * 8))))[0]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((((ax2_0_1_2 + 1) & 1) * 16) + (ax0_0_6 * 8))))[1]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((((ax2_0_1_2 + 1) & 1) * 16) + (ax0_0_6 * 8))))[2]), "=r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((((ax2_0_1_2 + 1) & 1) * 16) + (ax0_0_6 * 8))))[3]) + : "r"(addr) + ); + } + } + for (int ax1_0_6 = 0; ax1_0_6 < 2; ++ax1_0_6) { + + { + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((ax2_0_1_2 * 1024) + ((((int)threadIdx.x) & 7) * 128)) + ((((((((int)threadIdx.y) & 1) * 8) + (ax1_0_6 * 4)) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 17408)])) + 0); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((((ax2_0_1_2 + 1) & 1) * 16) + (ax1_0_6 * 8))))[0]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((((ax2_0_1_2 + 1) & 1) * 16) + (ax1_0_6 * 8))))[1]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((((ax2_0_1_2 + 1) & 1) * 16) + (ax1_0_6 * 8))))[2]), "=r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((((ax2_0_1_2 + 1) & 1) * 16) + (ax1_0_6 * 8))))[3]) + : "r"(addr) + ); + } + } + for (int ax0_0_4_4 = 0; ax0_0_4_4 < 4; ++ax0_0_4_4) { + for (int ax1_0_4_4 = 0; ax1_0_4_4 < 8; ++ax1_0_4_4) { + + { + __asm__ __volatile__( + "mma.sync.aligned.m64n8k16.row.col.f16.f16.f16.f16" + "{%0, %1}, {%2, %3}, {%4}, {%5, %6};\n" + : "=r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_4 * 16) + (ax1_0_4_4 * 2))))[0]), "=r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_4 * 16) + (ax1_0_4_4 * 2))))[1]) + : "r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + (((ax2_0_1_2 & 1) * 16) + (ax0_0_4_4 * 4))))[0]), "r"(((unsigned 
*)(X_reindex_shared_dyn_m64n8k16_matrixA + (((ax2_0_1_2 & 1) * 16) + (ax0_0_4_4 * 4))))[1]), "r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + (((ax2_0_1_2 & 1) * 16) + (ax1_0_4_4 * 2))))[0]), "r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_4 * 16) + (ax1_0_4_4 * 2))))[0]), "r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_4 * 16) + (ax1_0_4_4 * 2))))[1])); + } + } + } + } + for (int ax0_0_4_5 = 0; ax0_0_4_5 < 4; ++ax0_0_4_5) { + for (int ax1_0_4_5 = 0; ax1_0_4_5 < 8; ++ax1_0_4_5) { + + { + __asm__ __volatile__( + "mma.sync.aligned.m64n8k16.row.col.f16.f16.f16.f16" + "{%0, %1}, {%2, %3}, {%4}, {%5, %6};\n" + : "=r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_5 * 16) + (ax1_0_4_5 * 2))))[0]), "=r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_5 * 16) + (ax1_0_4_5 * 2))))[1]) + : "r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((ax0_0_4_5 * 4) + 16)))[0]), "r"(((unsigned *)(X_reindex_shared_dyn_m64n8k16_matrixA + ((ax0_0_4_5 * 4) + 16)))[1]), "r"(((unsigned *)(Y_reindex_shared_dyn_m64n8k16_matrixB + ((ax1_0_4_5 * 2) + 16)))[0]), "r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_5 * 16) + (ax1_0_4_5 * 2))))[0]), "r"(((unsigned *)(C_reindex_m64n8k16_matrixC + ((ax0_0_4_5 * 16) + (ax1_0_4_5 * 2))))[1])); + } + } + } + for (int ax0_0_7 = 0; ax0_0_7 < 8; ++ax0_0_7) { + __syncthreads(); + for (int ax1_0_7 = 0; ax1_0_7 < 8; ++ax1_0_7) { + *(uint1*)(((half*)buf_dyn_shmem) + ((((((int)threadIdx.x) * 2050) + (((int)threadIdx.y) * 512)) + (ax1_0_7 * 64)) + 12288)) = C_reindex_m64n8k16_matrixC[((((ax0_0_7 >> 1) * 16) + (ax1_0_7 * 2)) + (ax0_0_7 & 1))]; + } + __syncthreads(); + for (int threadIdx_x_cache_threadIdx_y_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 = 0; threadIdx_x_cache_threadIdx_y_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 < 512; ++threadIdx_x_cache_threadIdx_y_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0) { + C[(((((((((((((int)blockIdx.y) >> 3) * 524288) + (((threadIdx_x_cache_threadIdx_y_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 & 15) >> 3) * 262144)) + (ax0_0_7 * 32768)) + ((((int)threadIdx.y) & 1) * 16384)) + ((((int)threadIdx.x) >> 3) * 4096)) + (((int)blockIdx.x) * 1024)) + ((((int)blockIdx.y) & 7) * 128)) + ((threadIdx_x_cache_threadIdx_y_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 & 7) * 16)) + ((((int)threadIdx.y) >> 1) * 8)) + (((int)threadIdx.x) & 7))] = ((half*)buf_dyn_shmem)[((((threadIdx_x_cache_threadIdx_y_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 128) + (((int)threadIdx.y) * 32)) + ((int)threadIdx.x)) + 12288)]; + } + } +} + +""" + + +@tvm.testing.requires_tensorcore +def test_mma_script_after_build(): + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 9: + # At least sm90 is required + return + + mod = te.create_prim_func(matmul_fp16(M=4096, N=4096, K=4096, out_dtype="float16")) + actual = _design_space(mod, "float16") + assert len(actual) == 1 + sketch = actual[0] + + i = 0 + new_decisions = {} + for inst in sketch.trace.insts: + if not inst.kind.name.startswith("Sample"): + continue + assert i < len(gemm_decision) + if inst.kind.name == gemm_decision[i][0]: + new_decisions[inst] = gemm_decision[i][1] + i += 1 + assert len(new_decisions) == len(gemm_decision) + sch = Schedule(mod) + Trace( + insts=sketch.trace.insts, + decisions=new_decisions, + ).apply_to_schedule(sch, remove_postproc=True) + + sch.enter_postproc() + # DefaultCUDATensorCore + ms.postproc.DisallowDynamicLoop().apply(sch) + 
ms.postproc.RewriteCooperativeFetch().apply(sch) + # Disable RewriteUnboundBlock here since max_threads_per_block_ is not set + # ms.postproc.RewriteUnboundBlock(256).apply(sch) + ms.postproc.RewriteParallelVectorizeUnroll().apply(sch) + ms.postproc.RewriteReductionBlock().apply(sch) + ms.postproc.VerifyGPUCode().apply(sch) + ms.postproc.RewriteTensorize(False).apply(sch) + + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + rt_mod = tvm.build(sch.mod, target="cuda") + print(rt_mod.imported_modules[0].get_source()) + assert rt_mod.imported_modules[0].get_source() == expected_cuda_script + + +def initializer(): + @register_func("meta_schedule.builder.async_build") + def async_build(mod, target, _params): # pylint: disable=unused-variable, unused-argument + # pylint: disable=import-outside-toplevel + from tvm.driver import build as tvm_build + from tvm.tir.transform import RemoveWeightLayoutRewriteBlock + + # re-import here for local builder to register index_map_m64n8k16_matrixC + # pylint: disable=import-outside-toplevel, unused-import + from tvm.tir.tensor_intrin import cuda + + mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=True)(mod) + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + rt_mod = tvm_build(mod, target=target) + return rt_mod + + +@tvm.testing.requires_tensorcore +@tvm.testing.requires_cublas +def test_mma_tune(): + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # At least sm80 is required + return + + # pylint: disable=import-outside-toplevel + from tvm.contrib import cublas + + def tune(out_dtype): + M, N, K = 1024, 1024, 1024 + target = Target("nvidia/h100") + func = te.create_prim_func(matmul_fp16(N=N, M=M, K=K, out_dtype=out_dtype)).with_attr( + {"global_symbol": "main"} + ) + mod = tvm.IRModule({"main": func}) + + with tempfile.TemporaryDirectory() as work_dir: + db = ms.tir_integration.tune_tir( + mod=mod, + target=target, + work_dir=work_dir, + max_trials_global=8, + builder=LocalBuilder( + f_build="meta_schedule.builder.async_build", initializer=initializer + ), + space=ms.space_generator.PostOrderApply( + sch_rules=[multi_level_tiling_mma(out_dtype=out_dtype)], + ), + ) + sch = db.query_schedule(mod, target=target, workload_name="main") + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + rt_mod = tvm.build(sch.mod, target=target) + + a_np = np.random.uniform(0, 1, size=(M, K)).astype("float16") + b_np = np.random.uniform(0, 1, size=(K, N)).astype("float16") + A_cublas = te.placeholder((M, K), name="A", dtype="float16") + B_cublas = te.placeholder((K, N), name="B", dtype="float16") + C_cublas = cublas.matmul(A_cublas, B_cublas, dtype=out_dtype) + s = te.create_schedule(C_cublas.op) + dev = tvm.cuda(0) + f_cublas = tvm.build(s, [A_cublas, B_cublas, C_cublas], target) + a_cublas = tvm.nd.array(a_np.astype("float16"), dev) + b_cublas = tvm.nd.array(b_np.astype("float16"), dev) + c_cublas = tvm.nd.array(np.zeros((M, N), dtype=C_cublas.dtype), dev) + f_cublas(a_cublas, b_cublas, c_cublas) + a_tvm = tvm.nd.array(a_np, device=tvm.cuda(0)) + b_tvm = tvm.nd.array(b_np, device=tvm.cuda(0)) + c_tvm = tvm.nd.array(np.empty((M, N)).astype(out_dtype), device=tvm.cuda(0)) + rt_mod(a_tvm, b_tvm, c_tvm) + assert np.allclose(c_tvm.numpy(), c_cublas.numpy(), rtol=1e-2) + + tune("float16") + tune("float32") + + +if __name__ == "__main__": + test_mma_auto_tensorization() + test_mma_script_after_build() + test_mma_tune() diff --git 
a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc_hopper.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc_hopper.py new file mode 100644 index 000000000000..b54c06313025 --- /dev/null +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc_hopper.py @@ -0,0 +1,1506 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring,line-too-long,invalid-name,too-many-locals,too-many-statements,too-many-nested-blocks,too-many-branches,too-many-lines,chained-comparison + +import pytest + +import tvm +import tvm.testing +from tvm import meta_schedule as ms +from tvm import te +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, + get_rules, + print_sketches, +) +from tvm.script import tir as T +from tvm.tir.tensor_intrin.cuda import get_wgmma_intrin_group + + +def multi_level_tiling_tensor_core( + *, + read_reuse_scope="shared", + write_reuse_scope="shared", + in_dtype="float16", + out_dtype="float32", + trans_b=False, + use_software_pipeline=False, +) -> ms.schedule_rule.ScheduleRule: + assert read_reuse_scope in ["shared", "shared.dyn"] + assert write_reuse_scope in ["shared", "shared.dyn", "global"] + if not isinstance(in_dtype, list): + in_dtype = [in_dtype] + if not isinstance(out_dtype, list): + out_dtype = [out_dtype] + if not isinstance(trans_b, list): + trans_b = [trans_b] + return ms.schedule_rule.MultiLevelTilingTensorCoreHopper( + intrin_groups=[ + get_wgmma_intrin_group( + read_reuse_scope, write_reuse_scope, _in_dtype, _out_dtype, _trans_b + ) + for _in_dtype in in_dtype + for _out_dtype in out_dtype + for _trans_b in trans_b + ], + structure="SSSRRSRS", + tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"], + max_innermost_factor=4, # 64 // tensor intrin size + vector_load_lens=[1, 2, 3, 4, 8, 16], + reuse_read=ms.schedule_rule.ReuseType( + req="must", + levels=[4], + scope=read_reuse_scope, + ), + reuse_write=ms.schedule_rule.ReuseType( + req="must" if write_reuse_scope.startswith("shared") else "no", + levels=[2], + scope=write_reuse_scope, + ), + use_software_pipeline=use_software_pipeline, + ) + + +@pytest.mark.parametrize("shared_scope", ["shared", "shared.dyn"]) +def test_matmul_relu(shared_scope): + intrin_suffix = shared_scope.replace(".", "_") + # fmt: off + @T.prim_func + def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + # with T.block("root"): + C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope=shared_scope) + 
C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_0_0 in range(1): + for ax0_ax1_fused in range(4096): + with T.block("A_reindex_shared"): + v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128) + v1 = T.axis.spatial(128, ax0_ax1_fused % 128) + T.reads(A[v0, v1]) + T.writes(A_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) + A_reindex_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused in range(4096): + with T.block("B_reindex_shared"): + v0 = T.axis.spatial(128, ax0_ax1_fused // 32) + v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) + T.reads(B[v0, v1]) + T.writes(B_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) + B_reindex_shared[v0, v1] = B[v0, v1] + for ax2_0_1 in range(4): + for ax0_0, ax1_0 in T.grid(2, 2): + with T.block("A_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) + v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("A_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("B_reindex_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("B_reindex_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 2, 2, 1): + with T.block("C_o"): + v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4) + v1_o = T.axis.spatial(8, 
ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3 + ax1_0_4) + v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init] = T.float32(0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("C"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_fused) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(1, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused // 256) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], 
T.float32(0)) + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [4, 1, 1, 1, 2]), + ("SamplePerfectTile", [2, 2, 2, 1, 1]), + ("SamplePerfectTile", [1, 4, 2]), + ("SampleCategorical", 3), + ("SampleCategorical", 3), + ("SampleCategorical", 0), + ] + + mod = te.create_prim_func( + te_workload.matmul_relu( + n=128, + m=128, + k=128, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_90a"), + types=None, + sch_rules=[ + multi_level_tiling_tensor_core( + read_reuse_scope=shared_scope, write_reuse_scope=shared_scope + ), + ] + + get_rules(kind="cuda", types=ms.schedule_rule.AutoInline), + ) + check_sketches( + mod, + sketches=actual, + expected_mods=[matmul_relu_0], + expected_decisions=[decision_0], + ) + + +def test_matmul_relu_with_fallback(): + # fmt: off + @T.prim_func + def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # with T.block("root"): + C_reindex_shared = T.alloc_buffer((4, 2, 2, 4, 16, 16), scope="shared") + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 2, 2, 4, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_0_0 in range(2): + for ax0_ax1_fused in range(2048): + with T.block("A_reindex_shared"): + v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 64) + v1 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused % 64) + T.reads(A[v0, v1]) + T.writes(A_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4}) + A_reindex_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused in range(8192): + with T.block("B_reindex_shared"): + v0 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused // 128) + v1 = T.axis.spatial(128, ax0_ax1_fused % 128) + T.reads(B[v0, v1]) + T.writes(B_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) + B_reindex_shared[v0, v1] = B[v0, v1] + for ax2_0_1 in range(1): + for ax0_0, ax1_0 in T.grid(2, 4): + with T.block("A_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0) + v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax1_0) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("A_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + 
v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(4, 4): + with T.block("B_reindex_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(8, ax2_0_0 * 4 + ax0_0) + v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("B_reindex_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 4, 2, 4): + with T.block("C_o"): + v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_3 * 2 + ax0_0_4) + v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0_3 * 4 + ax1_0_4) + v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 4 + ax2_0_2) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init] = T.float32(0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("C"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 4): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused * 2 + ax0_0_1_ax1_0_1_fused) + v1 = T.axis.spatial(2, ax0_ax1_fused) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(4, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + 
T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(2048): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused * 2 + ax0_0_1_ax1_0_1_fused) + v1 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 1024) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 1024 // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + compute[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [2, 2, 1, 1, 2]), + ("SamplePerfectTile", [1, 1, 2, 1, 4]), + ("SamplePerfectTile", [2, 1, 4]), + ("SampleCategorical", 3), + ("SampleCategorical", 2), + ("SampleCategorical", 1), + ] + + mod = te.create_prim_func( + te_workload.matmul_relu( + n=128, + m=128, + k=128, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_90a"), + types=None, + sch_rules=[ + multi_level_tiling_tensor_core(), + ] + + get_rules( + "cuda", + ( + ms.schedule_rule.MultiLevelTiling, + ms.schedule_rule.AutoInline, + ), + ), + ) + check_sketches( + mod, + sketches=actual, + expected_mods=[matmul_relu_fallback_0], + expected_decisions=[decision_0], + ) + + +@pytest.mark.parametrize("shared_scope", ["shared", "shared.dyn"]) +def test_conv2d(shared_scope): + intrin_suffix = shared_scope.replace(".", "_") + # fmt: off + @T.prim_func + def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, 3, 32, 32), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 32), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + # with T.block("root"): + PadInput = T.alloc_buffer((1, 18, 18, 32), "float16") + conv2d_nhwc_reindex_shared_dyn = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope=shared_scope) + conv2d_nhwc_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope="wmma.accumulator") + PadInput_reindex_shared_dyn = T.alloc_buffer((256, 288), "float16", scope=shared_scope) + weight_reindex_shared_dyn = T.alloc_buffer((288, 32), "float16", scope=shared_scope) + PadInput_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((256, 288), "float16", scope="wmma.matrix_a") + weight_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((288, 32), "float16", scope="wmma.matrix_b") + for i0, i1, i2, i3 in T.grid(1, 18, 18, 32): + with T.block("PadInput"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 17 and 1 <= v_i2 and v_i2 < 17, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float16(0)) + for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"): + for 
ax0_0_1_ax1_0_1_fused in T.thread_binding(16, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax2_0_0 in range(1): + for ax0_ax1_fused in range(4608): + with T.block("PadInput_reindex_shared.dyn"): + v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 288) + v1 = T.axis.spatial(288, ax0_ax1_fused % 288) + T.reads(PadInput[0, v0 // 16 + v1 // 96, v0 % 16 + v1 % 96 // 32, v1 % 32]) + T.writes(PadInput_reindex_shared_dyn[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) + PadInput_reindex_shared_dyn[v0, v1] = PadInput[0, v0 // 16 + v1 // 96, v0 % 16 + v1 % 96 // 32, v1 % 32] + for ax0_ax1_fused in range(4608): + with T.block("weight_reindex_shared.dyn"): + v0 = T.axis.spatial(288, ax0_ax1_fused // 16) + v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax0_ax1_fused % 16) + T.reads(weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1]) + T.writes(weight_reindex_shared_dyn[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) + weight_reindex_shared_dyn[v0, v1] = weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1] + for ax2_0_1 in range(18): + for ax0_0, ax1_0 in T.grid(1, 1): + with T.block("PadInput_reindex_shared.dyn_wmma.matrix_a_o"): + v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0) + v1_o = T.axis.spatial(18, ax2_0_1 + ax1_0) + T.reads(PadInput_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("PadInput_reindex_shared.dyn_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(PadInput_reindex_shared_dyn[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared_dyn[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(1, 1): + with T.block("weight_reindex_shared.dyn_wmma.matrix_b_o"): + v0_o = T.axis.spatial(18, ax2_0_1 + ax0_0) + v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0) + T.reads(weight_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(weight_reindex_shared_dyn_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("weight_reindex_shared.dyn_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(weight_reindex_shared_dyn[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(weight_reindex_shared_dyn_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + weight_reindex_shared_dyn_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_shared_dyn[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 1, 1, 1): + with T.block("conv2d_nhwc_o"): + v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0_3 + ax0_0_4) + v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0_3 + ax1_0_4) + v2_o = T.axis.reduce(18, ax2_0_0 * 18 + ax2_0_1 + ax2_0_2) + T.reads(PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], 
weight_reindex_shared_dyn_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init]) + conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init] = T.float32(0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("conv2d_nhwc"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i], PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_shared_dyn_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] = conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_shared_dyn_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_shared_dyn_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(1): + for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("conv2d_nhwc_reindex_shared.dyn_wmma.accumulator_o"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2_1, ax3]) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared.dyn_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4_i, v5_i]) + conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_dyn_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(256): + with T.block("conv2d_nhwc_reindex_shared.dyn"): + v0, v1, v2 = T.axis.remap("SSS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2]) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4, v5]) + T.writes(conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4, v5] + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [1, 16, 1, 1, 1]), + ("SamplePerfectTile", [2, 1, 1, 1, 1]), + ("SamplePerfectTile", [1, 18, 1]), + ("SampleCategorical", 2), + ("SampleCategorical", 1), + 
("SampleCategorical", 3), + ] + mod = te.create_prim_func( + te_workload.conv2d_nhwc( + N=1, + H=16, + W=16, + CI=32, + CO=32, + kernel_size=3, + stride=1, + padding=1, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_90a"), + types=None, + sch_rules=[ + multi_level_tiling_tensor_core( + read_reuse_scope=shared_scope, write_reuse_scope=shared_scope + ), + ], + ) + check_sketches( + mod, + sketches=actual, + expected_mods=[conv2d_0], + expected_decisions=[decision_0], + ) + + # Test adding inapplicable tensor intrinsics doesn't change the search space + # This test case uses the same workload, decision and the expected sketch as above + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_90a"), + types=None, + sch_rules=[ + multi_level_tiling_tensor_core( + read_reuse_scope=shared_scope, + write_reuse_scope=shared_scope, + in_dtype="float16", + out_dtype=["float16", "float32"], + ), + ], + ) + check_sketches( + mod, + sketches=actual, + expected_mods=[conv2d_0], + expected_decisions=[decision_0], + ) + + +@pytest.mark.parametrize("shared_scope", ["shared", "shared.dyn"]) +def test_matmul_relu_pipeline(shared_scope): + intrin_suffix = shared_scope.replace(".", "_") + # fmt: off + @T.prim_func + def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + C = T.alloc_buffer((128, 128)) + C_reindex_shared = T.alloc_buffer((4, 4, 2, 2, 16, 16), scope=shared_scope) + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 4, 2, 2, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax2_0_0 in T.serial(4, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}): + for ax0_ax1_fused in range(1024): + with T.block("A_reindex_shared"): + v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0_ax1_fused // 32) + v1 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused % 32) + T.reads(A[v0, v1]) + T.writes(A_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 4, "tir.manifest_shared_memory_local_stage": 1}) + A_reindex_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused in range(1024): + with T.block("B_reindex_shared"): + v0 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused // 32) + v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax0_ax1_fused % 32) + T.reads(B[v0, v1]) + T.writes(B_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 2, "tir.manifest_shared_memory_local_stage": 1}) + B_reindex_shared[v0, v1] = B[v0, v1] + for ax2_0_1 in T.serial(2, 
annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("A_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0) + v1_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax1_0) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("A_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("B_reindex_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax0_0) + v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("B_reindex_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 1, 2, 2): + with T.block("C_o"): + v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0_3 * 2 + ax0_0_4) + v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0_3 * 2 + ax1_0_4) + v2_o = T.axis.reduce(8, ax2_0_0 * 2 + ax2_0_1 + ax2_0_2) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("C"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + 
C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 2): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused // 4) + v1 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused % 4) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(2, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused // 4) + v1 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused % 4) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32] = C_reindex_shared[v0, v1, v2, v3, v4, v5] + for i0, i1 in T.grid(128, 128): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(C[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.max(C[v_i0, v_i1], T.float32(0)) + + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [1, 4, 1, 1, 2]), + ("SamplePerfectTile", [1, 4, 1, 1, 2]), + ("SamplePerfectTile", [4, 2, 1]), + ("SampleCategorical", 2), + ("SampleCategorical", 2), + ("SampleCategorical", 1), + ] + mod = te.create_prim_func( + te_workload.matmul_relu( + n=128, + m=128, + k=128, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_90a"), + types=None, + sch_rules=[ + multi_level_tiling_tensor_core( + read_reuse_scope=shared_scope, + write_reuse_scope=shared_scope, + use_software_pipeline=True, + ), + ], + ) + check_sketches( + mod, + sketches=actual, + expected_mods=[matmul_relu_pipeline_0], + expected_decisions=[decision_0], + ) + + +def test_matmul_relu_non_tensorizable(): + # expected to do nothing on non-tensorizable workloads + mod = te.create_prim_func( + te_workload.matmul_relu( # dtype doesn't match tensor intrin + n=128, + m=128, + k=128, + ) + ) + (sch,) = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_90a"), + types=None, + sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ) + tvm.ir.assert_structural_equal(mod, 
sch.mod["main"]) + + +def test_padded_matmul_relu(): + # fmt: off + @T.prim_func + def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 127), "float16"), compute: T.Buffer((127, 127), "float32")) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="shared") + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_0_0 in range(1): + for ax0_ax1_fused in range(4096): + with T.block("A_reindex_shared"): + v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128) + v1 = T.axis.spatial(128, ax0_ax1_fused % 128) + T.reads(A[v0, v1]) + T.writes(A_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) + A_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, A[v0, v1], T.float16(0)) + for ax0_ax1_fused in range(4096): + with T.block("B_reindex_shared"): + v0 = T.axis.spatial(128, ax0_ax1_fused // 32) + v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) + T.reads(B[v0, v1]) + T.writes(B_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) + B_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, B[v0, v1], T.float16(0)) + for ax2_0_1 in range(4): + for ax0_0, ax1_0 in T.grid(2, 2): + with T.block("A_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) + v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("A_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("B_reindex_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("B_reindex_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + 
T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 2, 2, 1): + with T.block("C_o"): + v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4) + v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3 + ax1_0_4) + v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init] = T.float32(0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("C"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_fused) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(1, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + 
ax0_ax1_ax3_ax4_ax5_fused // 256) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_ax3_ax4_ax5_fused // 256 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 16 < 127) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) + # fmt: on + + decision_0 = [ + ("SamplePerfectTile", [4, 1, 1, 1, 2]), + ("SamplePerfectTile", [2, 2, 2, 1, 1]), + ("SamplePerfectTile", [1, 4, 2]), + ("SampleCategorical", 3), + ("SampleCategorical", 3), + ("SampleCategorical", 0), + ] + + mod = te.create_prim_func( + te_workload.matmul_relu( + n=127, + m=127, + k=127, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_90a"), + types=None, + sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ) + check_sketches( + mod, + sketches=actual, + expected_mods=[padded_matmul_relu_0], + expected_decisions=[decision_0], + ) + + +def test_conv_1x1(): + # fmt: off + @T.prim_func + def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + # with T.block("root"): + conv2d_nhwc_reindex_shared = T.alloc_buffer((2, 1, 8, 4, 16, 16), scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((2, 1, 8, 4, 16, 16), scope="wmma.accumulator") + PadInput_reindex_shared = T.alloc_buffer((256, 64), "float16", scope="shared") + weight_reindex_shared = T.alloc_buffer((1, 1, 64, 64), "float16", scope="shared") + PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 64), "float16", scope="wmma.matrix_a") + weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 64, 64), "float16", scope="wmma.matrix_b") + for ax0_ax1_ax2_0_0_ax3_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): + for ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="blockIdx.x"): + for ax2_0_2_ax3_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax4_0_0 in range(2): + for ax0_ax1_fused in range(8192): + with T.block("PadInput_reindex_shared"): + v0 = T.axis.spatial(256, ax0_ax1_fused // 32) + v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32) + T.reads(inputs[0, v0 // 16, v0 % 16, v1]) + T.writes(PadInput_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) + PadInput_reindex_shared[v0, v1] = inputs[0, v0 // 16, v0 % 16, v1] + for ax0_ax1_ax2_ax3_fused in range(2048): + with T.block("weight_reindex_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1, 0) + v2 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused // 64) + v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) + T.reads(weight[v0, v1, v2, v3]) + T.writes(weight_reindex_shared[v0, v1, v2, v3]) + T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 4}) + weight_reindex_shared[v0, v1, v2, v3] = weight[v0, 
v1, v2, v3] + for ax4_0_1 in range(1): + for ax0_0, ax1_0 in T.grid(8, 2): + with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(16, ax2_0_2_ax3_0_2_fused * 8 + ax0_0) + v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax1_0) + T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("PadInput_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 4): + with T.block("weight_reindex_shared_wmma.matrix_b_o"): + v0_o, v1_o = T.axis.remap("SS", [ax0, ax1]) + v2_o = T.axis.spatial(4, ax4_0_0 * 2 + ax2_0) + v3_o = T.axis.spatial(4, ax3_0) + T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + for ax2_1, ax3_1 in T.grid(16, 16): + with T.block("weight_reindex_shared_wmma.matrix_b"): + v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) + T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i] + for ax2_0_3, ax3_0_3, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(8, 1, 2, 1, 4): + with T.block("conv2d_nhwc_o"): + v0_o = T.axis.spatial(1, 0) + v1_o = T.axis.spatial(1, 0) + v2_o = T.axis.spatial(16, ax2_0_2_ax3_0_2_fused * 8 + ax2_0_3 + ax2_0_4) + v3_o = T.axis.spatial(4, ax3_0_3 * 4 + ax3_0_4) + v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2) + T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16:v4_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + with T.init(): + for ax2_1, ax3_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_init"): + v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) + T.reads() + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i_init, v3_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i_init, v3_i_init] = T.float32(0) + for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): + with T.block("conv2d_nhwc"): + v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, 
v3_o * 16 + v3_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) + for ax2 in range(8): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 4): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(2, ax0_ax1_fused) + v1_o = T.axis.spatial(1, 0) + v2_o = T.axis.spatial(8, ax2 + ax2_1) + v3_o = T.axis.spatial(4, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(2048): + with T.block("conv2d_nhwc_reindex_shared"): + v0 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 1024) + v1 = T.axis.spatial(1, 0) + v2 = T.axis.spatial(8, ax2) + v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 1024 // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 1}) + conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] + # fmt: on + + decision_0 = [ + ("SamplePerfectTile", [1, 1, 2, 8, 1]), + ("SamplePerfectTile", [1, 1, 1, 1, 4]), + ("SamplePerfectTile", [2, 1, 2]), + ("SampleCategorical", 0), + ("SampleCategorical", 3), + ("SampleCategorical", 2), + ] + + mod = te.create_prim_func( + te_workload.conv2d_nhwc( + 1, + 16, + 16, + 64, + 64, + 1, + 1, + 0, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_90a"), + types=None, + sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ) + check_sketches( + mod, + sketches=actual, + expected_mods=[conv2d_1x1_0], + expected_decisions=[decision_0], + ) + + +def test_padded_conv(): + # fmt: off + @T.prim_func + def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: T.Buffer((7, 7, 3, 64), "float16"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + conv2d_nhwc_reindex_shared = T.alloc_buffer((56, 2, 14, 2, 16, 16), 
scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="wmma.accumulator") + PadInput_reindex_pad_shared = T.alloc_buffer((12544, 160), "float16", scope="shared") + weight_reindex_pad_shared = T.alloc_buffer((160, 64), "float16", scope="shared") + PadInput_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((12544, 160), "float16", scope="wmma.matrix_a") + weight_reindex_pad_shared_wmma_matrix_b = T.alloc_buffer((160, 64), "float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(14, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(8, thread="threadIdx.y"): + for ax2_0_0 in range(10): + for ax0_ax1_fused in range(28672): + with T.block("PadInput_reindex_pad_shared"): + v0 = T.axis.spatial(12544, ax0_0_0_ax1_0_0_fused // 2 * 1792 + ax0_ax1_fused // 16) + v1 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused % 16) + T.reads(inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3]) + T.writes(PadInput_reindex_pad_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4}) + PadInput_reindex_pad_shared[v0, v1] = T.if_then_else(v1 < 147, T.if_then_else(3 <= v0 // 112 * 2 + v1 // 21 and v0 // 112 * 2 + v1 // 21 < 227 and 3 <= v0 % 112 * 2 + v1 % 21 // 3 and v0 % 112 * 2 + v1 % 21 // 3 < 227, inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3], T.float16(0)), T.float16(0)) + for ax0_ax1_fused in range(512): + with T.block("weight_reindex_pad_shared"): + v0 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused // 32) + v1 = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 2 * 32 + ax0_ax1_fused % 32) + T.reads(weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1]) + T.writes(weight_reindex_pad_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) + weight_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 147, weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1], T.float16(0)) + for ax2_0_1 in range(1): + for ax0_0, ax1_0 in T.grid(14, 1): + with T.block("PadInput_reindex_pad_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0) + v1_o = T.axis.spatial(10, ax2_0_0 + ax1_0) + T.reads(PadInput_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("PadInput_reindex_pad_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("weight_reindex_pad_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(10, ax2_0_0 + ax0_0) + v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0) + T.reads(weight_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_f16_b_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("weight_reindex_pad_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(7, 2, 1, 2, 1): + with T.block("conv2d_nhwc_o"): + v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0_3 * 2 + ax0_0_4) + v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0_3 + ax1_0_4) + v2_o = T.axis.reduce(10, ax2_0_0 + ax2_0_1 + ax2_0_2) + T.reads(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("conv2d_nhwc"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i], PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(14): + for ax0_ax1_fused in T.thread_binding(8, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 2): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_fused) + v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2) + v2_o = T.axis.spatial(14, ax2 + ax2_1) + v3_o = T.axis.spatial(2, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", 
[ax4, ax5]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(4096): + with T.block("conv2d_nhwc_reindex_shared"): + v0 = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_ax3_ax4_ax5_fused // 512) + v1 = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2) + v2 = T.axis.spatial(14, ax2) + v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused % 512 // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] + # fmt: on + + decision_0 = [ + ("SamplePerfectTile", [7, 1, 8, 7, 2]), + ("SamplePerfectTile", [2, 1, 1, 2, 1]), + ("SamplePerfectTile", [10, 1, 1]), + ("SampleCategorical", 2), + ("SampleCategorical", 2), + ("SampleCategorical", 1), + ] + mod = te.create_prim_func( + te_workload.conv2d_nhwc( + 1, + 224, + 224, + 3, + 64, + 7, + 2, + 3, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_90a"), + types=None, + sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ) + check_sketches( + mod, + sketches=actual, + expected_mods=[padded_conv2d_0], + expected_decisions=[decision_0], + ) + + +def test_padded_matmul_single_padded_input(): + # fmt: off + @T.prim_func + def padded_matmul_single_padded_input_0(A: T.Buffer((1023, 4096), "float16"), B: T.Buffer((4096, 1024), "float16"), C: T.Buffer((1023, 1024), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + C_reindex_pad_shared = T.alloc_buffer((8, 32, 8, 2, 16, 16), scope="shared") + C_reindex_pad_shared_wmma_accumulator = T.alloc_buffer((8, 32, 8, 2, 16, 16), scope="wmma.accumulator") + A_reindex_pad_shared = T.alloc_buffer((1024, 4096), "float16", scope="shared") + B_reindex_shared = T.alloc_buffer((4096, 1024), "float16", scope="shared") + A_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((1024, 4096), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((4096, 1024), "float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(32, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(8, thread="threadIdx.y"): + for ax2_0_0 in range(32): + for ax0_ax1_fused in range(65536): + with T.block("A_reindex_pad_shared"): + v0 = T.axis.spatial(1024, ax0_0_1_ax1_0_1_fused // 16 * 512 + ax0_ax1_fused // 128) + v1 = T.axis.spatial(4096, ax2_0_0 * 128 + ax0_ax1_fused % 128) + T.reads(A[v0, v1]) + T.writes(A_reindex_pad_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) + A_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 1023, A[v0, v1], T.float16(0.0)) + for ax0_ax1_fused in range(8192): + with 
T.block("B_reindex_shared"): + v0 = T.axis.spatial(4096, ax2_0_0 * 128 + ax0_ax1_fused // 64) + v1 = T.axis.spatial(1024, ax0_0_1_ax1_0_1_fused % 16 * 64 + ax0_ax1_fused % 64) + T.reads(B[v0, v1]) + T.writes(B_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) + B_reindex_shared[v0, v1] = B[v0, v1] + for ax2_0_1 in range(8): + for ax0_0, ax1_0 in T.grid(8, 1): + with T.block("A_reindex_pad_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused // 16 * 32 + ax0_0_2_ax1_0_2_fused // 2 * 8 + ax0_0) + v1_o = T.axis.spatial(256, ax2_0_0 * 8 + ax2_0_1 + ax1_0) + T.reads(A_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("A_reindex_pad_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("B_reindex_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(256, ax2_0_0 * 8 + ax2_0_1 + ax0_0) + v1_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused % 16 * 4 + ax0_0_2_ax1_0_2_fused % 2 * 2 + ax1_0) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("B_reindex_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(2, 1, 1, 4, 2): + with T.block("C_o"): + v0_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused // 16 * 32 + ax0_0_2_ax1_0_2_fused // 2 * 8 + ax0_0_3 * 4 + ax0_0_4) + v1_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused % 16 * 4 + ax0_0_2_ax1_0_2_fused % 2 * 2 + ax1_0_3 * 2 + ax1_0_4) + v2_o = T.axis.reduce(256, ax2_0_0 * 8 + ax2_0_1 + ax2_0_2) + T.reads(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i_init, v1_i_init]) + C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0.0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("C"): 
+ v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i], A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i] = C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i] + T.Cast("float32", A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(8): + for ax0_ax1_fused in T.thread_binding(8, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 2): + with T.block("C_reindex_pad_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 16 * 4 + ax0_ax1_fused // 2) + v1_o = T.axis.spatial(32, ax0_0_1_ax1_0_1_fused % 16 * 2 + ax0_ax1_fused % 2) + v2_o = T.axis.spatial(8, ax2 + ax2_1) + v3_o = T.axis.spatial(2, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.writes(C_reindex_pad_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_pad_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + T.writes(C_reindex_pad_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + C_reindex_pad_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(4096): + with T.block("C_reindex_pad_shared"): + v0 = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 16 * 4 + ax0_ax1_ax3_ax4_ax5_fused // 1024) + v1 = T.axis.spatial(32, ax0_0_1_ax1_0_1_fused % 16 * 2 + ax0_ax1_ax3_ax4_ax5_fused % 1024 // 512) + v2 = T.axis.spatial(8, ax2) + v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused % 512 // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.where(ax0_0_1_ax1_0_1_fused // 16 * 512 + ax0_ax1_ax3_ax4_ax5_fused // 1024 * 128 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 1023) + T.reads(C_reindex_pad_shared[v0, v1, v2, v3, v4, v5]) + T.writes(C[v4 + v2 * 16 + v0 * 128, v5 + v3 * 16 + v1 * 32]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + C[v4 + v2 * 16 + v0 * 128, v5 + v3 * 16 + v1 * 32] = C_reindex_pad_shared[v0, v1, v2, v3, v4, v5] + # fmt: on + + decision_0 = [ + ("SamplePerfectTile", [1, 2, 4, 2, 4]), + ("SamplePerfectTile", [1, 16, 2, 1, 2]), + ("SamplePerfectTile", [32, 8, 1]), + ("SampleCategorical", 3), + ("SampleCategorical", 1), + ("SampleCategorical", 0), + ] + mod = te.create_prim_func( + te_workload.matmul( + n=1023, + m=1024, + k=4096, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_90a"), + types=None, + sch_rules=[multi_level_tiling_tensor_core()] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ) + check_sketches( + mod, + sketches=actual, + 
expected_mods=[padded_matmul_single_padded_input_0], + expected_decisions=[decision_0], + ) + + +def test_padded_matmul_no_padded_output(): + # fmt: off + @T.prim_func + def padded_matmul_no_padded_output_0(A: T.Buffer((1024, 4095), "float16"), B: T.Buffer((4095, 1024), "float16"), C: T.Buffer((1024, 1024), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + C_reindex_shared = T.alloc_buffer((32, 16, 2, 4, 16, 16), scope="shared") + C_reindex_shared_wmma_accumulator = T.alloc_buffer((32, 16, 2, 4, 16, 16), scope="wmma.accumulator") + A_reindex_pad_shared = T.alloc_buffer((1024, 4096), "float16", scope="shared") + B_reindex_pad_shared = T.alloc_buffer((4096, 1024), "float16", scope="shared") + A_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((1024, 4096), "float16", scope="wmma.matrix_a") + B_reindex_pad_shared_wmma_matrix_b = T.alloc_buffer((4096, 1024), "float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(64, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(4, thread="threadIdx.y"): + for ax2_0_0 in range(128): + for ax0_ax1_fused in range(4096): + with T.block("A_reindex_pad_shared"): + v0 = T.axis.spatial(1024, ax0_0_0_ax1_0_0_fused // 16 * 256 + ax0_0_1_ax1_0_1_fused * 128 + ax0_ax1_fused // 32) + v1 = T.axis.spatial(4096, ax2_0_0 * 32 + ax0_ax1_fused % 32) + T.reads(A[v0, v1]) + T.writes(A_reindex_pad_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) + A_reindex_pad_shared[v0, v1] = T.if_then_else(v1 < 4095, A[v0, v1], T.float16(0.0)) + for ax0_ax1_fused in range(2048): + with T.block("B_reindex_pad_shared"): + v0 = T.axis.spatial(4096, ax2_0_0 * 32 + ax0_ax1_fused // 64) + v1 = T.axis.spatial(1024, ax0_0_0_ax1_0_0_fused % 16 * 64 + ax0_ax1_fused % 64) + T.reads(B[v0, v1]) + T.writes(B_reindex_pad_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) + B_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 4095, B[v0, v1], T.float16(0.0)) + for ax2_0_1 in range(2): + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("A_reindex_pad_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused // 16 * 16 + ax0_0_1_ax1_0_1_fused * 8 + ax0_0_2_ax1_0_2_fused * 2 + ax0_0) + v1_o = T.axis.spatial(256, ax2_0_0 * 2 + ax2_0_1 + ax1_0) + T.reads(A_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("A_reindex_pad_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(1, 4): + with T.block("B_reindex_pad_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(256, ax2_0_0 * 2 + ax2_0_1 + ax0_0) + v1_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 16 * 4 + ax1_0) + T.reads(B_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + 
T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("B_reindex_pad_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(B_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(B_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + B_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(2, 1, 1, 1, 4): + with T.block("C_o"): + v0_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused // 16 * 16 + ax0_0_1_ax1_0_1_fused * 8 + ax0_0_2_ax1_0_2_fused * 2 + ax0_0_3 + ax0_0_4) + v1_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 16 * 4 + ax1_0_3 * 4 + ax1_0_4) + v2_o = T.axis.reduce(256, ax2_0_0 * 2 + ax2_0_1 + ax2_0_2) + T.reads(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_pad_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init] = T.float32(0.0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("C"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i], A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] + T.Cast("float32", A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(4, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 4): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused // 16 * 8 + ax0_0_1_ax1_0_1_fused * 4 + ax0_ax1_fused) + v1_o = T.axis.spatial(16, ax0_0_0_ax1_0_0_fused % 16) + v2_o = T.axis.spatial(2, ax2 + ax2_1) + v3_o = T.axis.spatial(4, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.writes(C_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + 
T.writes(C_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + C_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(4096): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused // 16 * 8 + ax0_0_1_ax1_0_1_fused * 4 + ax0_ax1_ax3_ax4_ax5_fused // 1024) + v1 = T.axis.spatial(16, ax0_0_0_ax1_0_0_fused % 16) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 1024 // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64] = C_reindex_shared[v0, v1, v2, v3, v4, v5] + # fmt: on + + decision_0 = [ + ("SamplePerfectTile", [4, 2, 4, 2, 1]), + ("SamplePerfectTile", [16, 1, 1, 1, 4]), + ("SamplePerfectTile", [128, 2, 1]), + ("SampleCategorical", 2), + ("SampleCategorical", 3), + ("SampleCategorical", 0), + ] + mod = te.create_prim_func( + te_workload.matmul( + n=1024, + m=1024, + k=4095, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_90a"), + types=None, + sch_rules=[multi_level_tiling_tensor_core()] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ) + check_sketches( + mod, + sketches=actual, + expected_mods=[padded_matmul_no_padded_output_0], + expected_decisions=[decision_0], + ) + + +if __name__ == "__main__": + tvm.testing.main()
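
For readers adapting these tests to the Hopper-specific rule added in this patch, the short sketch below shows one way the pieces could be wired together: a wgmma intrinsic group is fetched and handed to the new schedule rule, which is then used to generate a design space against an sm_90a target, mirroring the generate_design_space calls in the tests above. This is a minimal sketch only; the exported constructor name and every parameter value in it are illustrative assumptions, not values prescribed by the patch.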
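
# Minimal sketch (not part of the patch). Assumed: the Python export name
# MultiLevelTilingTensorCoreHopper, the chosen wgmma group parameters, and the
# tiling/reuse settings below are all illustrative and may need adjustment.
import tvm
from tvm import meta_schedule as ms
from tvm import te
from tvm.meta_schedule.testing import te_workload
from tvm.meta_schedule.testing.space_generation import generate_design_space
from tvm.tir.tensor_intrin.cuda import get_wgmma_intrin_group

# One candidate group of wgmma intrinsics; the rule accepts a list of such groups.
wgmma_group = get_wgmma_intrin_group(
    load_scope="shared.dyn",
    store_scope="shared.dyn",
    in_dtype="float16",
    out_dtype="float32",
    trans_b=True,
)

# Assumed export name; adjust to whatever the package actually exposes.
hopper_rule = ms.schedule_rule.MultiLevelTilingTensorCoreHopper(
    intrin_groups=[wgmma_group],
    structure="SSSRRSRS",
    tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
    max_innermost_factor=4,
    vector_load_lens=[1, 2, 3, 4, 8, 16],
    reuse_read=ms.schedule_rule.ReuseType(req="must", levels=[4], scope="shared.dyn"),
    reuse_write=ms.schedule_rule.ReuseType(req="must", levels=[2], scope="shared.dyn"),
    use_software_pipeline=False,
)

# Same workload and target style as the tests above.
mod = te.create_prim_func(
    te_workload.matmul_relu(n=128, m=128, k=128, in_dtype="float16", out_dtype="float32")
)
sketches = generate_design_space(
    kind="cuda",
    mod=mod,
    target=tvm.target.Target("cuda --arch=sm_90a"),
    types=None,
    sch_rules=[hopper_rule],
)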
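As in the tests above, the resulting sketches could then be compared against expected modules and sampling decisions with check_sketches; that comparison is omitted from the sketch because the expected Hopper sketches are not part of this patch.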