diff --git a/python/tvm/tir/ast_dumper.py b/python/tvm/tir/ast_dumper.py new file mode 100644 index 000000000000..83febea7bc09 --- /dev/null +++ b/python/tvm/tir/ast_dumper.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""IR Ast Dumper""" +import tvm +from . import _ffi_api +from tvm.ir import PrimExpr +from tvm.tir import Stmt + +def get_valid_fields(stmt_or_expr): + result = {} + for key in dir(stmt_or_expr): + attr = getattr(stmt_or_expr, key) + if not key.startswith("__") and (isinstance(attr, PrimExpr) or isinstance(attr, Stmt)): + result[key] = attr + return result + +def match(fields, child): + for key, value in fields.items(): + if str(value) == str(child): + return key + return "None" + +def dump(stmt, filename="graph.txt"): + stack = [] + ast_node = [] + ast_edge = [] + count = [0] + #idx2obj = {} + + def pre_func(stmt): + node_idx = count[0] + count[0] += 1 + #idx2obj[node_idx] = (stmt, get_valid_fields(stmt)) + + ast_node.append([node_idx, stmt]) + if len(stack): + ast_edge.append([stack[-1], node_idx]) + + stack.append(node_idx) + + def post_func(stmt): + del stack[-1] + + _ffi_api.PrePostOrderVisit(stmt, pre_func, post_func) + + with open(filename, "w") as f: + f.write("digraph {\n") + f.write(" node [shape=matrix]\n") + for node in ast_node: + ast_type = type(node[1]) + ast_str = str(node[1]).replace("\n", "\\l").replace("\\n", "\\l").replace('"', '\\"') + f.write(" node%d" % (node[0])) + f.write("[label=\"%s\n%s\"]" % (ast_type, ast_str)) + f.write(";\n") + for edge in ast_edge: + #f.write(" node%d -> node%d [label=\"%s\"];\n" % (edge[0], edge[1], match(idx2obj[edge[0]][1], idx2obj[edge[1]][0]))) + f.write(" node%d -> node%d;\n" % (edge[0], edge[1])) + f.write("}\n") diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index c0abf953eec2..fd011cf9ee2f 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -585,6 +585,45 @@ void PostOrderVisit(const ObjectRef& node, std::function } } +// Implementations of Dump ast +class IRAstDumper : public StmtExprVisitor { + public: + explicit IRAstDumper(std::function f, + std::function b) : f_(f), b_(b) {} + + void VisitExpr(const PrimExpr& node) final { + if (visited_.count(node.get()) != 0) return; + visited_.insert(node.get()); + f_(node); + ExprVisitor::VisitExpr(node); + b_(node); + } + + void VisitStmt(const Stmt& node) final { + if (visited_.count(node.get()) != 0) return; + visited_.insert(node.get()); + f_(node); + StmtVisitor::VisitStmt(node); + b_(node); + } + + private: + std::function f_, b_; + std::unordered_set visited_; +}; + +void PrePostOrderVisit(const ObjectRef& node, + std::function fvisit, + std::function bvisit) { + if (node.as()) { + IRAstDumper visitor(fvisit, bvisit); + visitor(Downcast(node)); + } else { + IRAstDumper visitor(fvisit, bvisit); + visitor(Downcast(node)); + } +} + class IRTransformer final : public StmtExprMutator { public: IRTransformer(const runtime::PackedFunc& f_preorder, const runtime::PackedFunc& f_postorder, @@ -802,6 +841,10 @@ TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, Pack tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); }); +TVM_REGISTER_GLOBAL("tir.PrePostOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f, PackedFunc b) { + tir::PrePostOrderVisit(node, [f](const ObjectRef& n) { f(n); }, [b](const ObjectRef& n) { b(n); }); +}); + TVM_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) { tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n); }); }); diff --git a/workspace/tir_ast/demo.py b/workspace/tir_ast/demo.py new file mode 100644 index 000000000000..8510dcad3789 --- /dev/null +++ b/workspace/tir_ast/demo.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np + +import tvm +from tvm import te +from tvm.tir import ast_dumper +from tvm.ir.module import IRModule +from tvm.script import tir as T +import argparse +import os + +################################################################# +# Parse arguments + +def parse_args(): + parser = argparse.ArgumentParser("Evaluate tuned result") + parser.add_argument( + '-b', + '--batch_size', + type=int, + default=16, + help='batch size') + parser.add_argument( + '-d', + '--device_id', + type=int, + default=7, + help='device id to be used' + ) + parser.add_argument( + '--tuned_dir', + default='./result', + help='dirname of tuned result stored' + ) + args = parser.parse_args() + return args + +args = parse_args() +print("Arguments: %s" % args) + +#n = te.var("m") +#A = te.placeholder((n,), name='A') +#B = te.placeholder((n,), name='B') +#C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") +# +#s = te.create_schedule([C.op]) +##bx, tx = s[C].split(C.op.axis[0], factor=64) +# +#res = tvm.lower(s, [A, B, C], simple_mode=True) +#print("--->Module") +#print(res) +#print("--->PrimFunc.body") +#print(res["main"].body) +#print("--->Dump") +##print(ast_dumper.dump(res["main"].body)) +#ast_dumper.dump(res["main"].body, os.path.join("./log", "task_%s.dot" % (0))) + +@tvm.script.ir_module +class MyModule: + @T.prim_func + def main(a: T.handle, b: T.handle): + # We exchange data between function by handles, which are similar to pointer. + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # Create buffer from handles. + A = T.match_buffer(a, (8,), dtype="float32") + B = T.match_buffer(b, (8,), dtype="float32") + for i in range(8): + # A block is an abstraction for computation. + with T.block("B"): + # Define a spatial block iterator and bind it to value i. + vi = T.axis.spatial(8, i) + B[vi] = A[vi] + 1.0 + + +ir_module = MyModule +print("--->Module") +print(ir_module) +print("--->PrimFunc.body") +print(ir_module["main"].body) +ast_dumper.dump(ir_module["main"].body, os.path.join("./ast", "task_%s.dot" % (1)))