diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 67eb7471d22d..c0e14279bda0 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -245,6 +245,40 @@ def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
     """
     return _ffi_api.detect_buffer_access_lca(func)  # type: ignore # pylint: disable=no-member
 
+
+def extract_gpu_resource_usage(func: PrimFunc) -> Dict[str, Any]:
+    """Extract actual GPU resource usage from a TIR GPU kernel.
+
+    This function analyzes the given PrimFunc and returns a dictionary containing
+    statistics such as thread block dimensions, shared/local memory consumption,
+    and vector access patterns.
+
+    Parameters
+    ----------
+    func : tvm.tir.PrimFunc
+        The GPU kernel to analyze. Must be lowered to TIR with explicit thread bindings.
+
+    Returns
+    -------
+    result : Dict[str, Any]
+        A dictionary with the following keys (all values are integers or lists of integers):
+
+        - "thread_x", "thread_y", "thread_z": Block dimension sizes.
+        - "threads_per_block": Total number of threads in a block (tx * ty * tz).
+        - "shared_memory_bytes": Total bytes allocated in shared memory.
+        - "local_memory_bytes": Total bytes allocated in local memory.
+        - "num_kernels": Number of GPU kernels launched (usually 1 for a single PrimFunc).
+        - "vector_load_bytes": List of byte widths for vectorized loads (e.g., [16] for float4).
+        - "vector_store_bytes": List of byte widths for vectorized stores.
+
+    Examples
+    --------
+    >>> res = extract_gpu_resource_usage(my_gpu_func)
+    >>> print(res["shared_memory_bytes"])
+    8192
+    >>> print(res["threads_per_block"])
+    256
+    """
+    return _ffi_api.extract_gpu_resource_usage(func)  # type: ignore # pylint: disable=no-member
+
 
 def estimate_tir_flops(stmt_or_mod: Union[Stmt, IRModule]) -> float:
     """Estimate the FLOPs of a TIR fragment.
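A quick smoke test for reviewers (hypothetical, not part of this diff). It assumes the new function is re-exported from `tvm.tir.analysis` like the other analyses in this module, and uses the classic TE schedule flow, where `tvm.lower` emits the `thread_extent` annotations this analysis looks for:

```python
import tvm
from tvm import te
from tvm.tir.analysis import extract_gpu_resource_usage  # assumes re-export

n = 1024
A = te.placeholder((n,), dtype="float32", name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")

s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=256)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))

# tvm.lower produces a PrimFunc whose bound axes appear as
# AttrStmt thread_extent nodes, which the extractor visits.
func = tvm.lower(s, [A, B])["main"]
usage = extract_gpu_resource_usage(func)
assert int(usage["threads_per_block"]) == 256
assert int(usage["num_kernels"]) == 1
```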
diff --git a/src/tir/analysis/extract_gpu_resource_usage.cc b/src/tir/analysis/extract_gpu_resource_usage.cc
new file mode 100644
index 000000000000..53c03334d477
--- /dev/null
+++ b/src/tir/analysis/extract_gpu_resource_usage.cc
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file extract_gpu_resource_usage.cc
+ * \brief Analyze and extract actual GPU resource usage from a TIR GPU kernel.
+ *
+ * It collects statistics such as:
+ *   - thread block dimensions (threadIdx.x/y/z)
+ *   - total threads per block
+ *   - shared memory and local memory consumption (in bytes)
+ *   - vector access widths
+ *
+ * This information can be used for hardware-aware scheduling,
+ * cost modeling, or diagnostic reporting.
+ */
+
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "../../runtime/thread_storage_scope.h"
+#include "../transforms/ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+class GPUResourceExtractor : public StmtExprVisitor {
+ public:
+  Map<String, ObjectRef> Extract(const Stmt& stmt) {
+    Reset();
+    this->VisitStmt(stmt);
+    return BuildResult();
+  }
+
+  void VisitStmt_(const AllocateNode* op) final {
+    StmtVisitor::VisitStmt_(op);
+    auto scope = GetPtrStorageScope(op->buffer_var);
+    runtime::StorageScope storage_scope = runtime::StorageScope::Create(scope);
+
+    size_t size_bytes = static_cast<size_t>(op->ConstantAllocationSize()) *
+                        op->dtype.bytes() * op->dtype.lanes();
+
+    if (storage_scope.rank == runtime::StorageRank::kLocal) {
+      local_memory_bytes_ += size_bytes;
+    } else if (storage_scope.rank == runtime::StorageRank::kShared) {
+      shared_memory_bytes_ += size_bytes;
+    }
+
+    // Record allocations with a vector element type.
+    if (op->dtype.is_vector()) {
+      vector_alloc_sizes_.push_back(size_bytes);
+    }
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
+      if (nest_level_ == 0) {
+        // Entering a new kernel.
+        // TODO: support multi-kernel reporting; for now the stats reflect the
+        // last kernel visited (usually there is only one per PrimFunc).
+        kernels_launched_++;
+        ResetKernelStats();
+      }
+
+      const auto* iter_var = op->node.as<IterVarNode>();
+      ICHECK(iter_var) << "thread_extent annotation must be bound to an IterVar";
+      const auto* extent = op->value.as<IntImmNode>();
+      ICHECK(extent) << "Thread extent must be constant for analysis";
+
+      std::string name = iter_var->var->name_hint;
+      int64_t length = extent->value;
+
+      if (name == "threadIdx.x") {
+        thread_x_ = length;
+        visited_threads_.insert(name);
+      } else if (name == "threadIdx.y") {
+        thread_y_ = length;
+        visited_threads_.insert(name);
+      } else if (name == "threadIdx.z") {
+        thread_z_ = length;
+        visited_threads_.insert(name);
+      }
+      // vthread is ignored for resource counting (it is virtual).
+
+      nest_level_++;
+      StmtVisitor::VisitStmt_(op);
+      nest_level_--;
+
+      if (nest_level_ == 0) {
+        threads_per_block_ = thread_x_ * thread_y_ * thread_z_;
+      }
+    } else {
+      StmtVisitor::VisitStmt_(op);
+    }
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    if (op->dtype.is_vector()) {
+      int64_t vec_bytes = op->dtype.bytes() * op->dtype.lanes();
+      vector_load_sizes_.push_back(vec_bytes);
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) final {
+    if (op->value->dtype.is_vector()) {
+      int64_t vec_bytes = op->value->dtype.bytes() * op->value->dtype.lanes();
+      vector_store_sizes_.push_back(vec_bytes);
+    }
+    StmtVisitor::VisitStmt_(op);
+  }
+
+ private:
+  int nest_level_ = 0;
+  int64_t thread_x_ = 1, thread_y_ = 1, thread_z_ = 1;
+  int64_t threads_per_block_ = 1;
+  int64_t shared_memory_bytes_ = 0;
+  int64_t local_memory_bytes_ = 0;
+  int64_t kernels_launched_ = 0;
+  std::unordered_set<std::string> visited_threads_;
+
+  std::vector<int64_t> vector_alloc_sizes_;
+  std::vector<int64_t> vector_load_sizes_;
+  std::vector<int64_t> vector_store_sizes_;
+
+  void Reset() {
+    ResetKernelStats();
+    kernels_launched_ = 0;
+    shared_memory_bytes_ = 0;
+    local_memory_bytes_ = 0;
+    vector_alloc_sizes_.clear();
+    vector_load_sizes_.clear();
+    vector_store_sizes_.clear();
+  }
+
+  void ResetKernelStats() {
+    thread_x_ = 1;
+    thread_y_ = 1;
+    thread_z_ = 1;
+    threads_per_block_ = 1;
+    visited_threads_.clear();
+  }
+
+  Map<String, ObjectRef> BuildResult() {
+    Map<String, ObjectRef> result;
+
+    result.Set("thread_x", Integer(thread_x_));
+    result.Set("thread_y", Integer(thread_y_));
+    result.Set("thread_z", Integer(thread_z_));
+    result.Set("threads_per_block", Integer(threads_per_block_));
+    result.Set("shared_memory_bytes", Integer(static_cast<int>(shared_memory_bytes_)));
+    result.Set("local_memory_bytes", Integer(static_cast<int>(local_memory_bytes_)));
+    result.Set("num_kernels", Integer(kernels_launched_));
+
+    // Vector access widths, reported as arrays of byte counts.
+    Array<Integer> load_vecs;
+    for (auto sz : vector_load_sizes_) load_vecs.push_back(Integer(sz));
+    result.Set("vector_load_bytes", load_vecs);
+
+    Array<Integer> store_vecs;
+    for (auto sz : vector_store_sizes_) store_vecs.push_back(Integer(sz));
+    result.Set("vector_store_bytes", store_vecs);
+
+    return result;
+  }
+};
+
+Map<String, ObjectRef> ExtractGPUResourceUsage(const PrimFunc& func) {
+  GPUResourceExtractor extractor;
+  return extractor.Extract(func->body);
+}
+
+TVM_REGISTER_GLOBAL("tir.analysis.extract_gpu_resource_usage")
+    .set_body_typed(ExtractGPUResourceUsage);
+
+}  // namespace tir
+}  // namespace tvm