From 55a173a8b449040783bbdfa5074a13091d1185ef Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 22 Apr 2020 13:25:36 -0700 Subject: [PATCH] [TIR] Enhance Substitute, python bindings for Substitute/PostOrderVisit/IRTransform. (#5400) Substitute now takes a std::function to customize more replacing behaviors. Co-authored-by: Siyuan Feng Co-authored-by: Siyuan Feng --- docs/api/python/tir.rst | 7 + include/tvm/runtime/container.h | 7 + include/tvm/tir/ir_pass.h | 34 --- include/tvm/tir/stmt_functor.h | 72 ++++- python/tvm/te/hybrid/util.py | 4 +- python/tvm/tir/__init__.py | 1 + python/tvm/tir/stmt_functor.py | 77 +++++ src/arith/solve_linear_equation.cc | 31 +- src/te/autodiff/ad_util.cc | 2 +- src/te/operation/hybrid_op.cc | 1 - src/te/operation/op_util.h | 2 +- src/te/operation/tensor_compute_op.cc | 3 +- src/tir/ir/data_layout.cc | 2 +- src/tir/ir/expr.cc | 6 +- src/tir/ir/stmt_functor.cc | 283 +++++++++++------- src/tir/pass/ffi_api.cc | 18 +- src/tir/pass/hoist_if_then_else.cc | 8 +- src/tir/pass/simple_passes.cc | 72 ----- .../test_arith_solve_linear_system.py | 8 +- .../unittest/test_target_codegen_cuda.py | 2 +- .../unittest/test_target_codegen_llvm.py | 2 +- .../python/unittest/test_te_hybrid_script.py | 2 +- .../test_te_schedule_bound_inference.py | 12 +- tests/python/unittest/test_te_schedule_ops.py | 6 +- tests/python/unittest/test_te_tensor.py | 2 +- .../python/unittest/test_tir_pass_hoist_if.py | 4 +- .../unittest/test_tir_pass_ir_transform.py | 2 +- ...test_tir_transform_inject_double_buffer.py | 2 +- ...tir_transform_instrument_bound_checkers.py | 2 +- .../test_tir_transform_loop_partition.py | 2 +- .../test_tir_transform_storage_flatten.py | 2 +- .../test_tir_transform_storage_rewrite.py | 24 +- tutorials/dev/low_level_custom_pass.py | 12 +- vta/python/vta/transform.py | 22 +- 34 files changed, 419 insertions(+), 317 deletions(-) create mode 100644 python/tvm/tir/stmt_functor.py diff --git a/docs/api/python/tir.rst b/docs/api/python/tir.rst index 8ef247aff2f7..9f2581b8c0a8 100644 --- a/docs/api/python/tir.rst +++ b/docs/api/python/tir.rst @@ -38,3 +38,10 @@ tvm.tir.analysis :members: :imported-members: :autosummary: + + +tvm.tir.stmt_functor +-------------------- +.. automodule:: tvm.tir.stmt_functor + :members: + :autosummary: diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 8f426415ffee..7d08613af215 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -611,6 +611,10 @@ struct PackedFuncValueConverter<::tvm::runtime::String> { } }; +/*! \brief Helper to represent nullptr for optional. */ +struct NullOptType { +}; + /*! * \brief Optional container that to represent to a Nullable variant of T. * \tparam T The original ObjectRef. @@ -642,6 +646,8 @@ class Optional : public ObjectRef { * \param ptr */ explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} + /*! \brief Nullopt handling */ + Optional(NullOptType) {} // NOLINT(*) // nullptr handling. // disallow implicit conversion as 0 can be implicitly converted to nullptr_t explicit Optional(std::nullptr_t) {} @@ -751,6 +757,7 @@ struct PackedFuncValueConverter> { // expose the functions to the root namespace. using runtime::String; using runtime::Optional; +constexpr runtime::NullOptType NullOpt{}; } // namespace tvm namespace std { diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index 8bfe4f4b3032..5dd080b0904f 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -81,40 +81,6 @@ bool ExprUseVar(const PrimExpr& e, const std::unordered_set& vse */ TVM_DLL Stmt ConvertSSA(Stmt stmt); -/*! - * \brief Substitute the var specified in key->var to be value. - * \param stmt The source statement to be substituted - * \param value_map The map of new values. - * \return The converted form. - */ -Stmt Substitute(Stmt stmt, - const std::unordered_map& value_map); - -/*! - * \brief Substitute the var specified in key->var to be value. - * \param expr The source expression to be substituted - * \param value_map The map of new values. - * \return The converted expression. - */ -PrimExpr Substitute(PrimExpr expr, - const std::unordered_map& value_map); - -/*! - * \brief Substitute the var specified in key->var to be value. - * \param stmt The source statement to be substituted - * \param value_map The map of new values. - * \return The converted form. - */ -Stmt Substitute(Stmt stmt, const Map& value_map); - -/*! - * \brief Substitute the var specified in key->var to be value. - * \param expr The source expression to be substituted - * \param value_map The map of new values. - * \return The converted expression. - */ -PrimExpr Substitute(PrimExpr expr, const Map& value_map); - /*! * \brief Verify if there is any argument bound to compact buffer. * diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index a87ff9737d0c..0f8038e13ca6 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -20,17 +20,20 @@ /*! * \file tvm/tir/stmt_functor.h * - * \brief Functors for tir stmts. + * \brief Functors for tir stmts + * utility functions to call common functors. */ #ifndef TVM_TIR_STMT_FUNCTOR_H_ #define TVM_TIR_STMT_FUNCTOR_H_ #include +#include #include #include #include #include +#include namespace tvm { namespace tir { @@ -318,9 +321,9 @@ class StmtExprMutator : }; /*! - * \brief recursively visit the ir in post DFS order node, and transform it + * \brief recursively visit the ir nodes in post DFS order, and transform it * - * \param node The ir to be transformed. + * \param stmt The ir to be transformed. * \param preorder The function called in before recursive mutation * If preorder returns None, then the transform will proceed to recursive call. * If preorder returns a not None Stmt/Expr, the transformer will simply return it and @@ -328,23 +331,76 @@ class StmtExprMutator : * \param postorder The function called after recursive mutation. * The recursive mutation result is passed to postorder for further mutation. * \param only_enable List of runtime::String. - * If it is empty, all IRNode will call preorder/postorder - * If it is not empty, preorder/postorder will only be called + * If it is null, all IRNode will call preorder/postorder + * If it is not null, preorder/postorder will only be called * when the IRNode's type key is in the list. */ -TVM_DLL Stmt IRTransform(Stmt node, +TVM_DLL Stmt IRTransform(Stmt stmt, const runtime::PackedFunc& preorder, const runtime::PackedFunc& postorder, - const Array& only_enable = {}); + Optional> only_enable = NullOpt); /*! - * \brief recursively visit the ir in post DFS order node, apply fvisit + * \brief Recursively visit the ir in post DFS order node, apply fvisit * Each node is guaranteed to be visited only once. * \param node The ir to be visited. * \param fvisit The visitor function to be applied. */ TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function fvisit); +/*! + * \brief Substitute the var specified by vmap. + * \param stmt The source statement to be substituted + * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. + * \return The converted form. + */ +TVM_DLL Stmt Substitute(Stmt stmt, + std::function(const Var& var)> vmap); + +/*! + * \brief Substitute the var specified by vmap. + * \param expr The source statement to be substituted + * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. + * \return The result. + */ +TVM_DLL PrimExpr Substitute(PrimExpr expr, + std::function(const Var& var)> vmap); + +/*! + * \brief Sugar for substitute via a given map. + * \param input The input to be updated. + * \param value_map The map of new values. + * \return The result. + * \tparam T the input type, can be PrimExpr or Stmt. + */ +template +inline T Substitute(T input, const Map& value_map) { + auto vmap = [&](const Var& var) -> Optional { + auto it = value_map.find(var); + if (it != value_map.end()) return (*it).second; + return Optional(nullptr); + }; + return Substitute(std::move(input), vmap); +} + +/*! + * \brief Sugar for substitute via a given map. + * \param input The input to be updated. + * \param value_map The map of new values. + * \return The result. + * \tparam T the input type, can be PrimExpr or Stmt. + */ +template +inline T Substitute(T input, + const std::unordered_map& value_map) { + auto vmap = [&](const Var& var) -> Optional { + auto it = value_map.find(var.get()); + if (it != value_map.end()) return (*it).second; + return Optional(nullptr); + }; + return Substitute(std::move(input), vmap); +} + } // namespace tir } // namespace tvm diff --git a/python/tvm/te/hybrid/util.py b/python/tvm/te/hybrid/util.py index 6c019893bf20..01eeeec16142 100644 --- a/python/tvm/te/hybrid/util.py +++ b/python/tvm/te/hybrid/util.py @@ -72,7 +72,7 @@ def _pruned_source(func): def replace_io(body, rmap): """Replacing tensors usage according to the dict given""" # pylint: disable=import-outside-toplevel - from tvm.tir import ir_pass + from tvm.tir import stmt_functor def replace(op): if isinstance(op, _stmt.Provide) and op.func in rmap.keys(): @@ -84,7 +84,7 @@ def replace(op): _expr.Call.Halide, buf.op, buf.value_index) return None - return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call']) + return stmt_functor.ir_transform(body, None, replace, ['Provide', 'Call']) def _is_tvm_arg_types(args): diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index ddfb6a5f69c1..6a62505a3034 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -48,3 +48,4 @@ from . import ir_pass from . import transform from . import analysis +from . import stmt_functor diff --git a/python/tvm/tir/stmt_functor.py b/python/tvm/tir/stmt_functor.py new file mode 100644 index 000000000000..cea8d1474621 --- /dev/null +++ b/python/tvm/tir/stmt_functor.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Statement functor utilities for IR transformations""" +from . import _ffi_api + + +def ir_transform(stmt, preorder, postorder, only_enable=None): + """Recursively visit and transform ir nodes in post DFS order. + + Parameters + ---------- + stmt : Stmt + The input to be transformed. + + preorder: function + The function called in before recursive mutation + If preorder returns None, then the transform will proceed to recursive call. + If preorder returns a not None Stmt/Expr, the transformer will simply return it and + won't do further recursion. + + postorder : function + The function called after recursive mutation. + + only_enable : Optional[List[str]] + List of types that we only enable. + + Returns + ------- + result : Stmt + The result. + """ + return _ffi_api.IRTransform(stmt, preorder, postorder, only_enable) + + +def post_order_visit(stmt, fvisit): + """Recursively visit the ir in post DFS order node, apply fvisit + Each node is guaranteed to be visited only once. + + Parameters + ---------- + fvisit: function + The visitor function. + """ + return _ffi_api.PostOrderVisit(stmt, fvisit) + + +def substitute(node, vmap): + """ Substitute the var specified by vmap. + + Parameters + ---------- + node: ObjectRef + The input. + + vmap : Dict[Var, PrimExpr] + The variable mapping. + + Returns + ------- + result : Stmt + The result. + """ + return _ffi_api.Substitute(node, vmap) diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 8142a03155c8..a89cebe0bf04 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -26,9 +26,10 @@ #include #include #include -#include #include -#include + +#include +#include #include namespace tvm { @@ -130,10 +131,10 @@ void SmithNormalFormDiag(std::vector >* S, (*S)[i][j] = new_i_j; } // We have to do the same with rhs - PrimExpr ea = te::make_const((*y)[index].dtype(), a); - PrimExpr eb = te::make_const((*y)[i].dtype(), b); - PrimExpr e_m_g = te::make_const((*y)[i].dtype(), m_g); - PrimExpr e_n_g = te::make_const((*y)[index].dtype(), n_g); + PrimExpr ea = tir::make_const((*y)[index].dtype(), a); + PrimExpr eb = tir::make_const((*y)[i].dtype(), b); + PrimExpr e_m_g = tir::make_const((*y)[i].dtype(), m_g); + PrimExpr e_n_g = tir::make_const((*y)[index].dtype(), n_g); PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i]; PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i]; (*y)[index] = new_index_rhs; @@ -190,10 +191,10 @@ void SmithNormalFormDiag(std::vector >* S, (*V)[i][j] = new_i_j; } // And apply reverse transformations to new_to_old. - PrimExpr ea = te::make_const((*x)[j].dtype(), a); - PrimExpr eb = te::make_const((*x)[index].dtype(), b); - PrimExpr e_m_g = te::make_const((*x)[index].dtype(), m_g); - PrimExpr e_n_g = te::make_const((*x)[j].dtype(), n_g); + PrimExpr ea = tir::make_const((*x)[j].dtype(), a); + PrimExpr eb = tir::make_const((*x)[index].dtype(), b); + PrimExpr e_m_g = tir::make_const((*x)[index].dtype(), m_g); + PrimExpr e_n_g = tir::make_const((*x)[j].dtype(), n_g); PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j]; PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j]; (*x)[index] = new_index; @@ -369,7 +370,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol IntConstraints( /*variables=*/{}, /*ranges=*/{}, - /*relations=*/{te::make_zero(DataType::Bool())}), + /*relations=*/{tir::make_zero(DataType::Bool())}), {}, {}); } else if (!tir::is_const_int(new_relation, 1)) { new_relations.push_back(new_relation); @@ -403,13 +404,13 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol // The j-th variable is just a single value, don't create a tvm variable // S^{-1}_{nxm} Uy_{mxn} if (S[j][j] >= 0) { - PrimExpr a = te::make_const(Uy[j].dtype(), S[j][j]); + PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]); solution_for_V_inv_x.push_back( analyzer_problem.Simplify(floordiv(Uy[j], a))); } else { // This is required because some simplifiers // have problems with dividing by negative numbers - PrimExpr a = te::make_const(Uy[j].dtype(), -S[j][j]); + PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]); solution_for_V_inv_x.push_back( analyzer_problem.Simplify(floordiv(-Uy[j], a))); } @@ -418,9 +419,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol // V V^{-1} x = x for (size_t i = 0; i < num_vars; ++i) { - PrimExpr e = te::make_zero(system_to_solve->variables[i].dtype()); + PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype()); for (size_t j = 0; j < num_vars; ++j) { - e = e + te::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j]; + e = e + tir::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j]; } e = analyzer_problem.Simplify(e); old_to_new_map.Set(system_to_solve->variables[i], e); diff --git a/src/te/autodiff/ad_util.cc b/src/te/autodiff/ad_util.cc index 3a90beff4822..b1c97e3cecb8 100644 --- a/src/te/autodiff/ad_util.cc +++ b/src/te/autodiff/ad_util.cc @@ -22,7 +22,7 @@ * \brief Utility for tensor-level auto-differentiation. */ #include -#include +#include #include #include "ad_util.h" diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 4da127ea0a85..7bb5d6153d8d 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -26,7 +26,6 @@ #include #include #include -#include #include #include #include diff --git a/src/te/operation/op_util.h b/src/te/operation/op_util.h index 5e16b8e4a879..fbe2e0a95d79 100644 --- a/src/te/operation/op_util.h +++ b/src/te/operation/op_util.h @@ -79,7 +79,7 @@ Stmt ReplaceTensor(Stmt stmt, * \param replace The replacement rule. */ PrimExpr ReplaceTensor(PrimExpr expr, - const std::unordered_map& replace); + const std::unordered_map& replace); /*! * \brief Substitute the variables of stmt by value map. diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index 4cdc9e1f8d32..f714691f4171 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -25,8 +25,9 @@ #include #include #include -#include +#include #include + #include "./op_util.h" #include "./compute_op.h" #include "../../arith/compute_expr.h" diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index fb63fed623cf..77de9f4aacbc 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 03925ec3783a..e8c850a4831a 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include #include "../pass/ir_util.h" @@ -363,8 +363,8 @@ Array CommReducerNode::operator()(Array a, Array b value_map.Set(rhs[i], b[i]); } return UpdateArray(result, [&value_map] (const PrimExpr& e) { - return Substitute(e, value_map); - }); + return Substitute(e, value_map); + }); } TVM_REGISTER_GLOBAL("tir.CommReducer") diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 5e584ebf55f4..ec97b03c88c4 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -19,116 +19,14 @@ /*! * \file stmt_functor.cc */ +#include #include +#include #include "functor_common.h" namespace tvm { namespace tir { -// visitor to implement apply -class IRApplyVisit : - public StmtExprVisitor { - public: - explicit IRApplyVisit(std::function f) : f_(f) {} - - void VisitExpr(const PrimExpr& node) final { - if (visited_.count(node.get()) != 0) return; - visited_.insert(node.get()); - ExprVisitor::VisitExpr(node); - f_(node); - } - - void VisitStmt(const Stmt& node) final { - if (visited_.count(node.get()) != 0) return; - visited_.insert(node.get()); - StmtVisitor::VisitStmt(node); - f_(node); - } - - private: - std::function f_; - std::unordered_set visited_; -}; - -void PostOrderVisit(const ObjectRef& node, - std::function fvisit) { - if (node.as()) { - IRApplyVisit visitor(fvisit); - visitor(Downcast(node)); - } else { - IRApplyVisit visitor(fvisit); - visitor(Downcast(node)); - } -} - -class IRTransformer final : - public StmtExprMutator { - public: - IRTransformer(const runtime::PackedFunc& f_preorder, - const runtime::PackedFunc& f_postorder, - const std::unordered_set& only_enable) - : f_preorder_(f_preorder), - f_postorder_(f_postorder), - only_enable_(only_enable) { - } - - Stmt VisitStmt(const Stmt& stmt) final { - return MutateInternal(stmt, [this](const Stmt& s) { - return this->BaseVisitStmt(s); - }); - } - PrimExpr VisitExpr(const PrimExpr& expr) final { - return MutateInternal(expr, [this](const PrimExpr& e) { - return this->BaseVisitExpr(e); - }); - } - - private: - // NOTE: redirect to parent's call - // This is used to get around limitation of gcc-4.8 - Stmt BaseVisitStmt(const Stmt& s) { - return StmtMutator::VisitStmt(s); - } - PrimExpr BaseVisitExpr(const PrimExpr& e) { - return ExprMutator::VisitExpr(e); - } - - template - T MutateInternal(const T& node, F fmutate) { - if (only_enable_.size() && - !only_enable_.count(node->type_index())) { - return fmutate(node); - } - if (f_preorder_ != nullptr) { - T pre = f_preorder_(node); - if (pre.defined()) return pre; - } - T new_node = fmutate(node); - if (f_postorder_ != nullptr) { - T post = f_postorder_(new_node); - if (post.defined()) return post; - } - return new_node; - } - // The functions - const runtime::PackedFunc& f_preorder_; - const runtime::PackedFunc& f_postorder_; - // type indices enabled. - const std::unordered_set& only_enable_; -}; - -Stmt IRTransform(Stmt ir_node, - const runtime::PackedFunc& f_preorder, - const runtime::PackedFunc& f_postorder, - const Array& only_enable) { - std::unordered_set only_type_index; - for (auto s : only_enable) { - only_type_index.insert(Object::TypeKey2Index(s.c_str())); - } - IRTransformer transform(f_preorder, f_postorder, only_type_index); - return transform(std::move(ir_node)); -} - void StmtVisitor::VisitStmt_(const LetStmtNode* op) { this->VisitExpr(op->value); this->VisitStmt(op->body); @@ -511,6 +409,183 @@ Stmt StmtMutator::VisitStmt_(const FreeNode* op) { } +// Implementations of IRTransform, PostOrderVisit and Substitute +class IRApplyVisit : + public StmtExprVisitor { + public: + explicit IRApplyVisit(std::function f) : f_(f) {} + + void VisitExpr(const PrimExpr& node) final { + if (visited_.count(node.get()) != 0) return; + visited_.insert(node.get()); + ExprVisitor::VisitExpr(node); + f_(node); + } + + void VisitStmt(const Stmt& node) final { + if (visited_.count(node.get()) != 0) return; + visited_.insert(node.get()); + StmtVisitor::VisitStmt(node); + f_(node); + } + + private: + std::function f_; + std::unordered_set visited_; +}; + +void PostOrderVisit(const ObjectRef& node, + std::function fvisit) { + if (node.as()) { + IRApplyVisit visitor(fvisit); + visitor(Downcast(node)); + } else { + IRApplyVisit visitor(fvisit); + visitor(Downcast(node)); + } +} + +class IRTransformer final : + public StmtExprMutator { + public: + IRTransformer(const runtime::PackedFunc& f_preorder, + const runtime::PackedFunc& f_postorder, + const std::unordered_set& only_enable) + : f_preorder_(f_preorder), + f_postorder_(f_postorder), + only_enable_(only_enable) { + } + + Stmt VisitStmt(const Stmt& stmt) final { + return MutateInternal(stmt, [this](const Stmt& s) { + return this->BaseVisitStmt(s); + }); + } + PrimExpr VisitExpr(const PrimExpr& expr) final { + return MutateInternal(expr, [this](const PrimExpr& e) { + return this->BaseVisitExpr(e); + }); + } + + private: + // NOTE: redirect to parent's call + // This is used to get around limitation of gcc-4.8 + Stmt BaseVisitStmt(const Stmt& s) { + return StmtMutator::VisitStmt(s); + } + PrimExpr BaseVisitExpr(const PrimExpr& e) { + return ExprMutator::VisitExpr(e); + } + + template + T MutateInternal(const T& node, F fmutate) { + if (only_enable_.size() && + !only_enable_.count(node->type_index())) { + return fmutate(node); + } + if (f_preorder_ != nullptr) { + T pre = f_preorder_(node); + if (pre.defined()) return pre; + } + T new_node = fmutate(node); + if (f_postorder_ != nullptr) { + T post = f_postorder_(new_node); + if (post.defined()) return post; + } + return new_node; + } + // The functions + const runtime::PackedFunc& f_preorder_; + const runtime::PackedFunc& f_postorder_; + // type indices enabled. + const std::unordered_set& only_enable_; +}; + +Stmt IRTransform(Stmt ir_node, + const runtime::PackedFunc& f_preorder, + const runtime::PackedFunc& f_postorder, + Optional> only_enable) { + std::unordered_set only_type_index; + if (only_enable.defined()) { + for (auto s : only_enable.value()) { + only_type_index.insert(Object::TypeKey2Index(s.c_str())); + } + } + IRTransformer transform(f_preorder, f_postorder, only_type_index); + return transform(std::move(ir_node)); +} + +class IRSubstitue : public StmtExprMutator { + public: + explicit IRSubstitue(std::function(const Var&)> vmap) + : vmap_(vmap) { + } + + PrimExpr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + auto ret = vmap_(var); + if (ret.defined()) return ret.value(); + return std::move(var); + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + // NOTE: we do not explicit recursivly mutate op->buffer_var + PrimExpr ret = StmtExprMutator::VisitExpr_(op); + op = ret.as(); + if (auto mapped_var = vmap_(op->buffer_var)) { + return LoadNode::make( + op->dtype, Downcast(mapped_var.value()), op->index, op->predicate); + } else { + return ret; + } + } + + Stmt VisitStmt_(const StoreNode* op) final { + // NOTE: we do not explicit recursivly mutate op->buffer_var + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + if (auto mapped_var = vmap_(op->buffer_var)) { + return StoreNode::make( + Downcast(mapped_var.value()), op->value, op->index, op->predicate); + } else { + return ret; + } + } + + private: + std::function(const Var&)> vmap_; +}; + +Stmt Substitute(Stmt stmt, + std::function(const Var&)> vmap) { + return IRSubstitue(vmap)(std::move(stmt)); +} + +PrimExpr Substitute(PrimExpr expr, + std::function(const Var&)> vmap) { + return IRSubstitue(vmap)(std::move(expr)); +} + + +TVM_REGISTER_GLOBAL("tir.IRTransform") +.set_body_typed(IRTransform); + + +TVM_REGISTER_GLOBAL("tir.PostOrderVisit") +.set_body_typed([](ObjectRef node, PackedFunc f) { + tir::PostOrderVisit(node, [f](const ObjectRef& n) { + f(n); + }); +}); + +TVM_REGISTER_GLOBAL("tir.Substitute") +.set_body_typed([](ObjectRef node, Map vmap) -> ObjectRef{ + if (node->IsInstance()) { + return Substitute(Downcast(node), vmap); + } else { + return Substitute(Downcast(node), vmap); + } +}); } // namespace tir } // namespace tvm diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index 5de38588cc38..95da62a502b0 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -32,28 +32,13 @@ namespace tvm { namespace tir { -TVM_REGISTER_GLOBAL("ir_pass.Substitute") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsObjectRef()) { - *ret = Substitute(args[0].operator Stmt(), args[1].operator Map()); - } else { - *ret = Substitute(args[0].operator PrimExpr(), args[1].operator Map()); - } - }); + TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var()); }); -TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit") -.set_body([](TVMArgs args, TVMRetValue *ret) { - PackedFunc f = args[1]; - tir::PostOrderVisit(args[0], [f](const ObjectRef& n) { - f(n); - }); - }); - // make from two arguments #define REGISTER_PASS(PassName) \ @@ -63,7 +48,6 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit") REGISTER_PASS(ConvertSSA); REGISTER_PASS(VerifySSA); -REGISTER_PASS(IRTransform); REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(VerifyCompactBuffer); diff --git a/src/tir/pass/hoist_if_then_else.cc b/src/tir/pass/hoist_if_then_else.cc index 8bc462079eae..3fa8687e4e7f 100644 --- a/src/tir/pass/hoist_if_then_else.cc +++ b/src/tir/pass/hoist_if_then_else.cc @@ -159,7 +159,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { } }); - return IRTransform(parent_for_stmt, nullptr, replace_target_for, {"For"}); + return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array{"For"}); } // Remove IfThenElse node from a For node. @@ -185,9 +185,9 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { } }); - then_for = IRTransform(for_stmt, nullptr, replace_then_case, {"IfThenElse"}); + then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array{"IfThenElse"}); if (if_stmt.as()->else_case.defined()) { - else_for = IRTransform(for_stmt, nullptr, replace_else_case, {"IfThenElse"}); + else_for = IRTransform(for_stmt, nullptr, replace_else_case, Array{"IfThenElse"}); } return std::make_pair(then_for, else_for); @@ -408,7 +408,7 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { *ret = new_for; } }); - return IRTransform(stmt, nullptr, replace_top_for, {runtime::String("For")}); + return IRTransform(stmt, nullptr, replace_top_for, Array{"For"}); } Stmt HoistIfThenElse(Stmt stmt) { diff --git a/src/tir/pass/simple_passes.cc b/src/tir/pass/simple_passes.cc index 93d17ba347fc..a7dc12f8a87f 100644 --- a/src/tir/pass/simple_passes.cc +++ b/src/tir/pass/simple_passes.cc @@ -52,79 +52,7 @@ bool HasSideEffect(const PrimExpr& e) { return v.has_side_effect_; } -class IRSubstitue : public StmtExprMutator { - public: - explicit IRSubstitue( - const std::unordered_map& smap) - : smap_(smap) { - } - - PrimExpr VisitExpr_(const VarNode* op) final { - auto it = smap_.find(op); - if (it != smap_.end()) { - return it->second; - } else { - return GetRef(op); - } - } - - PrimExpr VisitExpr_(const LoadNode* op) final { - // NOTE: we do not explicit recursivly mutate op->buffer_var - PrimExpr ret = StmtExprMutator::VisitExpr_(op); - op = ret.as(); - auto it = smap_.find(op->buffer_var.get()); - if (it != smap_.end()) { - return LoadNode::make( - op->dtype, Downcast(it->second), op->index, op->predicate); - } else { - return ret; - } - } - - Stmt VisitStmt_(const StoreNode* op) final { - // NOTE: we do not explicit recursivly mutate op->buffer_var - Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); - auto it = smap_.find(op->buffer_var.get()); - if (it != smap_.end()) { - return StoreNode::make( - Downcast(it->second), op->value, op->index, op->predicate); - } else { - return ret; - } - } - - private: - const std::unordered_map& smap_; -}; -Stmt Substitute(Stmt stmt, - const std::unordered_map& value_map) { - if (value_map.size() == 0) return stmt; - return IRSubstitue(value_map)(std::move(stmt)); -} - -PrimExpr Substitute(PrimExpr expr, - const std::unordered_map& value_map) { - if (value_map.size() == 0) return expr; - return IRSubstitue(value_map)(std::move(expr)); -} - -Stmt Substitute(Stmt stmt, const Map& value_map) { - std::unordered_map vmap; - for (const auto& kv : value_map) { - vmap[kv.first.get()] = kv.second; - } - return Substitute(stmt, vmap); -} - -PrimExpr Substitute(PrimExpr expr, const Map& value_map) { - std::unordered_map vmap; - for (const auto& kv : value_map) { - vmap[kv.first.get()] = kv.second; - } - return Substitute(expr, vmap); -} class VarTouchVisitor : public ExprVisitor { public: diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_system.py index 645b1a2b537c..4f4c5ee97944 100644 --- a/tests/python/unittest/test_arith_solve_linear_system.py +++ b/tests/python/unittest/test_arith_solve_linear_system.py @@ -29,7 +29,7 @@ def run_expr(expr, vranges): """ def _compute_body(*us): vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} - return tir.ir_pass.Substitute(expr, vmap) + return tir.stmt_functor.substitute(expr, vmap) A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body) args = [tvm.nd.empty(A.shape, A.dtype)] @@ -69,17 +69,17 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): cond_on_vars = tir.const(1, 'bool') for v in constraints1.variables: # variable mapping is consistent - v_back = ana.simplify(tir.ir_pass.Substitute(varmap[v], backvarmap)) + v_back = ana.simplify(tir.stmt_functor.substitute(varmap[v], backvarmap)) cond_on_vars = te.all(cond_on_vars, v == v_back) # Also we have to check that the new relations are true when old relations are true - cond_subst = tir.ir_pass.Substitute( + cond_subst = tir.stmt_functor.substitute( te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap) # We have to include relations from vranges too for v in constraints2.variables: if v in constraints2.ranges: r = constraints2.ranges[v] range_cond = te.all(v >= r.min, v < r.min + r.extent) - range_cond = tir.ir_pass.Substitute(range_cond, backvarmap) + range_cond = tir.stmt_functor.substitute(range_cond, backvarmap) cond_subst = te.all(cond_subst, range_cond) cond_subst = ana.simplify(cond_subst) check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges, diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 4c2ec2e884bb..49a7933a0cac 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -201,7 +201,7 @@ def vectorizer(op): def _transform(f, *_): return f.with_body( - tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For'])) + tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['For'])) return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize") with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]): diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 26f93478b4fc..a7e1e57481a7 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -685,7 +685,7 @@ def vectorizer(op): def _transform(f, *_): return f.with_body( - tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For'])) + tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['For'])) return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize") diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index 8afd65330f98..ea4179d7ca3f 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -24,7 +24,7 @@ @pytest.mark.skip def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): def tvm_val_2_py_val(val): - val = tvm.tir.ir_pass.Substitute(val, var_dict) + val = tvm.tir.stmt_functor.substitute(val, var_dict) val = tvm.arith.Analyzer().simplify(val) assert isinstance(val, (tvm.tir.IntImm,)) return val.value diff --git a/tests/python/unittest/test_te_schedule_bound_inference.py b/tests/python/unittest/test_te_schedule_bound_inference.py index 6b6c519c8fa3..1fcc88d0d9f4 100644 --- a/tests/python/unittest/test_te_schedule_bound_inference.py +++ b/tests/python/unittest/test_te_schedule_bound_inference.py @@ -148,8 +148,8 @@ def test_bound_fusesplit1(): for k in range(1, 6): vars = tvm.runtime.convert({split1: tvm.tir.const(i, "int32"), l: tvm.tir.const(j, "int32"), xo.var: tvm.tir.const(k, "int32")}) tvm.testing.assert_prim_expr_equal( - tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars), - tvm.tir.ir_pass.Substitute(expected_extent, vars) + tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].extent, vars), + tvm.tir.stmt_functor.substitute(expected_extent, vars) ) tvm.testing.assert_prim_expr_equal(bounds[A1.op.axis[1]].extent, l) @@ -170,10 +170,10 @@ def test_bound_fusesplit2(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) vars = tvm.runtime.convert({xo.var: tvm.tir.const(5, "int32")}) - tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].min, vars), 2) - tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].min, vars), 3) - tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars), 1) - tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars), 3) + tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].min, vars), 2) + tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[1]].min, vars), 3) + tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].extent, vars), 1) + tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[1]].extent, vars), 3) def test_bound_warp(): diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index 2a0c6c1f40af..7cbf20eccf12 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -155,7 +155,7 @@ def test_inline_mixed(): def check(x): if isinstance(x, tvm.tir.Call): assert x.func != A2 - tvm.tir.ir_pass.PostOrderVisit(s[C].op.body[0], check) + tvm.tir.stmt_functor.post_order_visit(s[C].op.body[0], check) def test_scan_inline1(): @@ -517,7 +517,7 @@ def schedule(thread_tag, mem_scope) : def collect_visit(stmt, f): ret = [] - tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x))) + tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x))) return ret # local vs. threadIdx s = schedule(tx, "local") @@ -563,7 +563,7 @@ def test_local_stage_predicate2(): def collect_visit(stmt, f): ret = [] - tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x))) + tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x))) return ret def visit_stmt(op): diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index 45280866af38..5d3cbadce165 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -264,7 +264,7 @@ def get_B1_realize(x): x.func == B1.op and x.value_index == 1: ret.append(x) ret = [] - tvm.tir.ir_pass.PostOrderVisit(stmt, get_B1_realize) + tvm.tir.stmt_functor.post_order_visit(stmt, get_B1_realize) assert stmt.node == C.op and len(ret) == 1 diff --git a/tests/python/unittest/test_tir_pass_hoist_if.py b/tests/python/unittest/test_tir_pass_hoist_if.py index f6bdbd6130f4..e7e3a250a770 100644 --- a/tests/python/unittest/test_tir_pass_hoist_if.py +++ b/tests/python/unittest/test_tir_pass_hoist_if.py @@ -32,7 +32,7 @@ def _visit(op): key = op if isinstance(op, tvm.tir.IfThenElse): global var_list - tvm.tir.ir_pass.PostOrderVisit(op.condition, _extract_vars) + tvm.tir.stmt_functor.post_order_visit(op.condition, _extract_vars) val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))] var_list.clear() elif isinstance(op, tvm.tir.For): @@ -43,7 +43,7 @@ def _visit(op): return node_dict[key] = val - tvm.tir.ir_pass.PostOrderVisit(stmt, _visit) + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) for key, val in node_dict.items(): struct[val[1]] = tuple(node_dict[child][1] if child in node_dict else None for child in val[0]) diff --git a/tests/python/unittest/test_tir_pass_ir_transform.py b/tests/python/unittest/test_tir_pass_ir_transform.py index cb7417a7a54f..7bf70119e4aa 100644 --- a/tests/python/unittest/test_tir_pass_ir_transform.py +++ b/tests/python/unittest/test_tir_pass_ir_transform.py @@ -37,7 +37,7 @@ def postorder(op): if op.name == "TestA": return tvm.tir.call_extern("int32", "TestB", op.args[0] + 1) return op - body = tvm.tir.ir_pass.IRTransform(body, preorder, postorder, ["Call"]) + body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["Call"]) stmt_list = tvm.tir.stmt_list(body.body.body) assert stmt_list[0].value.args[0].name == "TestB" assert stmt_list[1].value.value == 0 diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py index 4c0573da7616..dd1f6a31c087 100644 --- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ -54,7 +54,7 @@ def test_double_buffer(): def count_sync(op): if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": count[0] += 1 - tvm.tir.ir_pass.PostOrderVisit(f.body, count_sync) + tvm.tir.stmt_functor.post_order_visit(f.body, count_sync) assert count[0] == 4 diff --git a/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py index 47c1f7bf1159..dcedca9bf311 100644 --- a/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py +++ b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py @@ -21,7 +21,7 @@ def collect_visit(stmt, f): ret = [] - tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x))) + tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x))) return ret diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 6ca3f596f196..59b8796f743a 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -20,7 +20,7 @@ def collect_visit(stmt, f): ret = [] - tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x))) + tvm.tir.stmt_functor.post_order_visit(stmt, lambda x : ret.append(f(x))) return ret diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 57eb349f18df..8400915c5fbb 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -123,7 +123,7 @@ def test_flatten_double_buffer(): def count_sync(op): if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": count[0] += 1 - tvm.tir.ir_pass.PostOrderVisit(f.body, count_sync) + tvm.tir.stmt_functor.post_order_visit(f.body, count_sync) assert count[0] == 4 if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 85f856db366b..46ba687aebda 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -45,7 +45,7 @@ def test_storage_share(): def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 - tvm.tir.ir_pass.PostOrderVisit(stmt, verify) + tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert num_alloc[0] == 1 def register_mem(scope_tb, max_bits): @@ -84,7 +84,7 @@ def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert n.extents[0].value == 200 - tvm.tir.ir_pass.PostOrderVisit(body, verify) + tvm.tir.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 def test_alloc_different_dtypes(): @@ -139,7 +139,7 @@ def verify(n): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], body)) body = tvm.tir.transform.StorageRewrite()(mod)["main"].body - tvm.tir.ir_pass.PostOrderVisit(body, verify) + tvm.tir.stmt_functor.post_order_visit(body, verify) length = 1024 dtype_list = ["float16", "int32", "uint16", "int8"] @@ -181,7 +181,7 @@ def test_inplace_rule(): def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 - tvm.tir.ir_pass.PostOrderVisit(stmt, verify) + tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert num_alloc[0] == 2 @@ -214,7 +214,7 @@ def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert (n.extents[0].value == 16) - tvm.tir.ir_pass.PostOrderVisit(stmt, verify) + tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert num_alloc[0] == 1 @@ -250,7 +250,7 @@ def verify(n): if isinstance(n, tvm.tir.AttrStmt): if n.attr_key == "storage_scope": alloc_stats[n.value.value] += 1 - tvm.tir.ir_pass.PostOrderVisit(stmt, verify) + tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert alloc_stats["global"] == 2 assert alloc_stats["shared"] == num_stage @@ -318,7 +318,7 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024): def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 - tvm.tir.ir_pass.PostOrderVisit(stmt, verify) + tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert num_alloc[0] == 2 def test_exceed_mem(): @@ -407,7 +407,7 @@ def test_inplace_rule3(): def verify(n): if isinstance(n, tvm.tir.Allocate): assert n.extents[0].value == 70 - tvm.tir.ir_pass.PostOrderVisit(stmt, verify) + tvm.tir.stmt_functor.post_order_visit(stmt, verify) def test_alloc_seq_type(): ib = tvm.tir.ir_builder.create() @@ -437,7 +437,7 @@ def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert n.extents[0].value == 500 - tvm.tir.ir_pass.PostOrderVisit(body, verify) + tvm.tir.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 def test_alloc_seq_type2(): @@ -469,7 +469,7 @@ def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert n.extents[0].value == 200 - tvm.tir.ir_pass.PostOrderVisit(body, verify) + tvm.tir.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 @@ -502,7 +502,7 @@ def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert n.extents[0].value == 800 - tvm.tir.ir_pass.PostOrderVisit(body, verify) + tvm.tir.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 def test_replace_dataflow(): @@ -540,7 +540,7 @@ def compute(a, b): def verify(n): if isinstance(n, tvm.tir.Allocate): assert n.extents[0].value == 268435456 - tvm.tir.ir_pass.PostOrderVisit(stmt, verify) + tvm.tir.stmt_functor.post_order_visit(stmt, verify) if __name__ == "__main__": diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py index 49e86fdb8e9b..03eae1cf6414 100644 --- a/tutorials/dev/low_level_custom_pass.py +++ b/tutorials/dev/low_level_custom_pass.py @@ -70,7 +70,7 @@ # # IR Visitor # ~~~~~~~~~~ -# We can use ``tvm.tir.ir_pass.PostOrderVisit(stmt, func)`` to gather information from the Halide IR. +# We can use ``tvm.tir.stmt_functor.post_order_visit(stmt, func)`` to gather information from the Halide IR. # ``func`` is a function callback. This function will be called before exiting the current IR node, # i.e. post-order visit. Then we leverage side effects to store the result of IR visit, because the # return value of ``func`` will be ignored. @@ -111,7 +111,7 @@ def vectorize8(op): extent = op.extent.value name = op.loop_var.name lo, li = te.var(name + '.outer'), te.var(name + '.inner') - body = tvm.tir.ir_pass.Substitute(op.body, {op.loop_var: lo * 8 + li}) + body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li}) body = tvm.tir.For(li, 0, 8, tvm.tir.For.Vectorized, 0, body) body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.For.Serial, 0, body) return body @@ -121,7 +121,7 @@ def vectorize8(op): def vectorize(f, mod, ctx): global loops - tvm.tir.ir_pass.PostOrderVisit(f.body, find_width8) + tvm.tir.stmt_functor.post_order_visit(f.body, find_width8) if not loops: return sf @@ -129,7 +129,7 @@ def vectorize(f, mod, ctx): # The last list arugment indicates what kinds of nodes will be transformed. # Thus, in this case only `For` nodes will call `vectorize8` return f.with_body( - tvm.tir.ir_pass.IRTransform(f.body, None, vectorize8, ['For'])) + tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ['For'])) ##################################################################### @@ -161,8 +161,8 @@ def vectorize(f, mod, ctx): # Quick View # ---------- # This tutorial gives a quick view of writing a customized IR transformation pass: -# - Use ``tvm.tir.ir_pass.PostOrderVisit`` to gather information on each IR nodes. -# - Use ``tvm.tir.ir_pass.IRTransform`` to transform IR nodes. +# - Use ``tvm.tir.stmt_functor.post_order_visit`` to gather information on each IR nodes. +# - Use ``tvm.tir.stmt_functor.ir_transform`` to transform IR nodes. # - Wrap up two above to write an IR-transformation function. # - Use ``tvm.target.build_config`` to put this function to TVM lowering pass # diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index f930b3f1e59c..1d54bb01bb49 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -86,14 +86,14 @@ def _post_order(op): raise RuntimeError("unexpected op %s" % op) return op - ret = tvm.tir.ir_pass.IRTransform( + ret = tvm.tir.stmt_functor.ir_transform( stmt.body, None, _post_order, ["Call"]) if not fail[0] and all(x is not None for x in gemm_offsets): def _visit(op): if op.same_as(loop_var): fail[0] = True - tvm.tir.ir_pass.PostOrderVisit(ret, _visit) + tvm.tir.stmt_functor.post_order_visit(ret, _visit) if not fail[0]: begin = tvm.tir.call_extern( "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets) @@ -131,7 +131,7 @@ def _do_fold(stmt): return None def _ftransform(f, mod, ctx): - return f.with_body(tvm.tir.ir_pass.IRTransform( + return f.with_body(tvm.tir.stmt_functor.ir_transform( f.body, _do_fold, None, ["AttrStmt"])) return tvm.tir.transform.prim_func_pass( @@ -187,7 +187,7 @@ def _post_order(op): raise RuntimeError("not reached") stmt_in = f.body - stmt = tvm.tir.ir_pass.IRTransform( + stmt = tvm.tir.stmt_functor.ir_transform( stmt_in, None, _post_order, ["Allocate", "Load", "Store"]) for buffer_var, new_var in rw_info.items(): @@ -253,7 +253,7 @@ def _post_order(op): return _merge_block(lift_stmt.pop() + [op], op.body) raise RuntimeError("not reached") stmt_in = f.body - stmt = tvm.tir.ir_pass.IRTransform( + stmt = tvm.tir.stmt_functor.ir_transform( stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"]) assert len(lift_stmt) == 1 return f.with_body(_merge_block(lift_stmt[0], stmt)) @@ -276,7 +276,7 @@ def _do_fold(stmt): return None def _ftransform(f, mod, ctx): - return f.with_body(tvm.tir.ir_pass.IRTransform( + return f.with_body(tvm.tir.stmt_functor.ir_transform( f.body, _do_fold, None, ["AttrStmt"])) return tvm.tir.transform.prim_func_pass( @@ -306,7 +306,7 @@ def _do_fold(stmt): op.loop_var, op.min, 2, op.for_type, op.device_api, op.body) return None - return f.with_body(tvm.tir.ir_pass.IRTransform( + return f.with_body(tvm.tir.stmt_functor.ir_transform( f.body, None, _do_fold, ["AttrStmt"])) return tvm.transform.Sequential( [tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"), @@ -635,7 +635,7 @@ def _find_basics(op): def _do_fold(op): if _match_pragma(op, "conv2d_transpose_gemm"): is_init = ".init" in str(op) - tvm.tir.ir_pass.PostOrderVisit(op, _find_basics) + tvm.tir.stmt_functor.post_order_visit(op, _find_basics) if is_init: # create inner most block @@ -707,7 +707,7 @@ def _do_fold(op): return inner return None - return func.with_body(tvm.tir.ir_pass.IRTransform( + return func.with_body(tvm.tir.stmt_functor.ir_transform( func.body, _do_fold, None, ["AttrStmt"])) return tvm.tir.transform.prim_func_pass( _ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip") @@ -736,7 +736,7 @@ def _do_fold(stmt): return tvm.tir.Evaluate(0) return stmt - return func.with_body(tvm.tir.ir_pass.IRTransform( + return func.with_body(tvm.tir.stmt_functor.ir_transform( func.body, None, _do_fold, ["AttrStmt"])) return tvm.tir.transform.prim_func_pass( _ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope") @@ -955,7 +955,7 @@ def _flatten_loop(src_coeff, dst_coeff, extents): return irb.get() return stmt - return func.with_body(tvm.tir.ir_pass.IRTransform( + return func.with_body(tvm.tir.stmt_functor.ir_transform( func.body, None, _do_fold, ["AttrStmt"])) return tvm.tir.transform.prim_func_pass(