diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h deleted file mode 100644 index 14769586a959..000000000000 --- a/include/tvm/ir_mutator.h +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ir_mutator.h - * \brief Defines general IRMutation pass - */ -#ifndef TVM_IR_MUTATOR_H_ -#define TVM_IR_MUTATOR_H_ - -#include -#include -#include "expr.h" -#include "ir.h" -#include "tvm/node/functor.h" - -namespace tvm { -namespace ir { -/*! - * \brief a base class for mutator to iterative mutate the IR - * - * This IRMutator is implemented via Visitor Pattern. - * Also you can implement via NodeFunctor. - * This enables easy extensions of possible new Node. - * It also makes changing return types easier. - * - * \note If you want to return a different type other than Expr and Stmt, - * Simply following the same pattern as IRMutator and create a seperate class. - * \sa NodeFunctor - */ -class TVM_DLL IRMutator { - public: - /*! - * \brief mutate expression - * \return the mutated expr - */ - virtual Expr Mutate(Expr expr) { - static const FMutateExpr& f = vtable_expr(); - return f(expr, expr, this); - } - /*! - * \brief mutate expression - * \return the mutated stmt - */ - virtual Stmt Mutate(Stmt stmt) { - static const FMutateStmt& f = vtable_stmt(); - return f(stmt, stmt, this); - } - /*! \brief destructor */ - virtual ~IRMutator() {} - /*! \brief functor type of expr mutation */ - using FMutateExpr = NodeFunctor; - /*! \brief functor type of stmt mutation */ - using FMutateStmt = NodeFunctor; - /*! \return internal vtable of expr */ - static FMutateExpr& vtable_expr(); // NOLINT(*) - /*! \return internal stmt of expr */ - static FMutateStmt& vtable_stmt(); // NOLINT(*) - // Set of overloadable functions - // The underscore allows Mutate not to be shadowed by inheritance - virtual Stmt Mutate_(const LetStmt* op, const Stmt& s); - virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s); - virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s); - virtual Stmt Mutate_(const For* op, const Stmt& s); - virtual Stmt Mutate_(const Allocate* op, const Stmt& s); - virtual Stmt Mutate_(const Store* op, const Stmt& s); - virtual Stmt Mutate_(const Free* op, const Stmt& s); - virtual Stmt Mutate_(const AssertStmt* op, const Stmt& s); - virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& s); - virtual Stmt Mutate_(const Provide* op, const Stmt& s); - virtual Stmt Mutate_(const Realize* op, const Stmt& s); - virtual Stmt Mutate_(const Prefetch* op, const Stmt& s); - virtual Stmt Mutate_(const Block* op, const Stmt& s); - virtual Stmt Mutate_(const Evaluate* op, const Stmt& s); - - virtual Expr Mutate_(const Variable* op, const Expr& e); - virtual Expr Mutate_(const Load* op, const Expr& e); - virtual Expr Mutate_(const Let* op, const Expr& e); - virtual Expr Mutate_(const Call* op, const Expr& e); - virtual Expr Mutate_(const Add* op, const Expr& e); - virtual Expr Mutate_(const Sub* op, const Expr& e); - virtual Expr Mutate_(const Mul* op, const Expr& e); - virtual Expr Mutate_(const Div* op, const Expr& e); - virtual Expr Mutate_(const Mod* op, const Expr& e); - virtual Expr Mutate_(const FloorDiv* op, const Expr& e); - virtual Expr Mutate_(const FloorMod* op, const Expr& e); - virtual Expr Mutate_(const Min* op, const Expr& e); - virtual Expr Mutate_(const Max* op, const Expr& e); - virtual Expr Mutate_(const EQ* op, const Expr& e); - virtual Expr Mutate_(const NE* op, const Expr& e); - virtual Expr Mutate_(const LT* op, const Expr& e); - virtual Expr Mutate_(const LE* op, const Expr& e); - virtual Expr Mutate_(const GT* op, const Expr& e); - virtual Expr Mutate_(const GE* op, const Expr& e); - virtual Expr Mutate_(const And* op, const Expr& e); - virtual Expr Mutate_(const Or* op, const Expr& e); - virtual Expr Mutate_(const Reduce* op, const Expr& e); - virtual Expr Mutate_(const Cast* op, const Expr& e); - virtual Expr Mutate_(const Not* op, const Expr& e); - virtual Expr Mutate_(const Select* op, const Expr& e); - virtual Expr Mutate_(const Ramp* op, const Expr& e); - virtual Expr Mutate_(const Broadcast* op, const Expr& e); - virtual Expr Mutate_(const IntImm* op, const Expr& e); - virtual Expr Mutate_(const UIntImm* op, const Expr& e); - virtual Expr Mutate_(const FloatImm* op, const Expr& e); - virtual Expr Mutate_(const StringImm* op, const Expr& e); - virtual Expr Mutate_(const Shuffle* op, const Expr& e); -}; -} // namespace ir -} // namespace tvm -#endif // TVM_IR_MUTATOR_H_ diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h deleted file mode 100644 index e6bd3c6f344d..000000000000 --- a/include/tvm/ir_visitor.h +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ir_visitor.h - * \brief Visitor to quickly visit IR trees - */ -#ifndef TVM_IR_VISITOR_H_ -#define TVM_IR_VISITOR_H_ - -#include "ir.h" -#include "tvm/node/functor.h" - -namespace tvm { -namespace ir { - -/*! - * \brief a base class for visitor to iterative traverse the IR - * - * This IRVisitor is implemented via NodeFunctor - * This enables extensions of possible new Node. - * - * \sa ExprFunctor, StmtFunctor, PostOrderVisit - * - * \note If you need to return values during Visit: - * - If it is mutation of the IR, use IRMutator - * - If you want to return other things, consider use ExprFunctor/StmtFunctor - * - Watch out for possible bug pattern if you use IRVisitor to simulate returns. - * - * \code - * - * // This is an example code to show cases for traps in IRVisitor - * // The use case is to count number of Variables in the ir tree. - * class MyCounter : public IRVisitor { - * public: - * int Count(const ObjectRef& n) { - * ret_ = 0; - * this->Visit(n); - * return ret_; - * } - * void Visit_(const Variable* op) final { - * ret_ = 1; - * } - * void Visit_(const Add* op) final { - * ret_ = count(op->a) + count(op->b); - * } - - * private: - * int ret_; - * }; - * MyCounter counter; - * Var x("x"); - * // this returns 2 - * CHECK_EQ(counter.Count(x + x), 2); - * // Think what is the result of the following count - * counter.count(Max::make(x, x)); - * // The result is actually 1 - * // This is because Visit is not overriden for Max - * // so it simply calls Visit for the left and right children - * // and because Count is not called, ret_ is not cleared. - * // There can also be cases where ret_ is forgetten to be set. - * - * // These traps may not happen if we program carefully - * // But it is recommended to use ExprFunctor, which allows direct - * // return the value, this helps us to avoid such problems. - * - * \endcode - */ -class TVM_DLL IRVisitor { - public: - /*! - * \brief recursively visit an IR node - */ - virtual void Visit(const ObjectRef& node) { - static const FVisit& f = vtable(); - if (node.defined()) f(node, this); - } - /*! \brief destructor */ - virtual ~IRVisitor() {} - /*! \brief functor type of visitor */ - using FVisit = NodeFunctor; - /*! \return internal vtable*/ - static FVisit& vtable(); - // overloadable visit function. - virtual void Visit_(const Variable* op); - virtual void Visit_(const LetStmt* op); - virtual void Visit_(const AttrStmt* op); - virtual void Visit_(const IfThenElse* op); - virtual void Visit_(const For* op); - virtual void Visit_(const Allocate* op); - virtual void Visit_(const Load* op); - virtual void Visit_(const Store* op); - virtual void Visit_(const Let* op); - virtual void Visit_(const Free* op); - virtual void Visit_(const Call* op); - virtual void Visit_(const Add* op); - virtual void Visit_(const Sub* op); - virtual void Visit_(const Mul* op); - virtual void Visit_(const Div* op); - virtual void Visit_(const Mod* op); - virtual void Visit_(const FloorDiv* op); - virtual void Visit_(const FloorMod* op); - virtual void Visit_(const Min* op); - virtual void Visit_(const Max* op); - virtual void Visit_(const EQ* op); - virtual void Visit_(const NE* op); - virtual void Visit_(const LT* op); - virtual void Visit_(const LE* op); - virtual void Visit_(const GT* op); - virtual void Visit_(const GE* op); - virtual void Visit_(const And* op); - virtual void Visit_(const Or* op); - virtual void Visit_(const Reduce* op); - virtual void Visit_(const Cast* op); - virtual void Visit_(const Not* op); - virtual void Visit_(const Select* op); - virtual void Visit_(const Ramp* op); - virtual void Visit_(const Shuffle* op); - virtual void Visit_(const Broadcast* op); - virtual void Visit_(const AssertStmt* op); - virtual void Visit_(const ProducerConsumer* op); - virtual void Visit_(const Provide* op); - virtual void Visit_(const Realize* op); - virtual void Visit_(const Prefetch* op); - virtual void Visit_(const Block* op); - virtual void Visit_(const Evaluate* op); - virtual void Visit_(const IntImm* op); - virtual void Visit_(const UIntImm* op); - virtual void Visit_(const FloatImm* op); - virtual void Visit_(const StringImm* op); -}; -} // namespace ir -} // namespace tvm - -#endif // TVM_IR_VISITOR_H_ diff --git a/src/autotvm/feature_visitor.h b/src/autotvm/feature_visitor.h index 32d4e092ce75..685becdc4e3f 100644 --- a/src/autotvm/feature_visitor.h +++ b/src/autotvm/feature_visitor.h @@ -51,12 +51,12 @@ enum AnnotationType { class FeatureVisitor : public StmtExprVisitor { public: // for loop - void VisitStmt_(const For *op); - void VisitStmt_(const AttrStmt *op); + void VisitStmt_(const For* op) final; + void VisitStmt_(const AttrStmt* op) final; // memory access - void VisitExpr_(const Load *op); - void VisitStmt_(const Store *op); + void VisitExpr_(const Load* op) final; + void VisitStmt_(const Store* op) final; using StmtExprVisitor::VisitStmt_; using StmtExprVisitor::VisitExpr_; diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index fcb1d611c3b0..31c803554c33 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -51,7 +51,7 @@ class IndexParser: public ExprVisitor { this->VisitExpr(expr); } - void VisitExpr_(const Variable *op) { + void VisitExpr_(const Variable* op) final { // TODO(lmzheng): handle more index types (multiple occurrence) if (pattern_map.count(op) == 0) { pattern_map[op] = TouchPattern(); @@ -60,7 +60,7 @@ class IndexParser: public ExprVisitor { } } - void VisitExpr_(const Mul *op) { + void VisitExpr_(const Mul* op) final { if (op->a.as()) { if (const auto stride = op->b.as()) { next_stride_ = stride->value; diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h index 027788cfaf03..b456e4b6f86e 100644 --- a/src/autotvm/touch_extractor.h +++ b/src/autotvm/touch_extractor.h @@ -90,31 +90,31 @@ class TouchExtractor : public FeatureVisitor { } // arithmetic stats - void VisitExpr_(const Add *op) { + void VisitExpr_(const Add* op) final { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; FeatureVisitor::VisitExpr_(op); } - void VisitExpr_(const Sub *op) { + void VisitExpr_(const Sub* op) final { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; FeatureVisitor::VisitExpr_(op); } - void VisitExpr_(const Mul *op) { + void VisitExpr_(const Mul* op) final { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].mul_ct++; FeatureVisitor::VisitExpr_(op); } - void VisitExpr_(const Div *op) { + void VisitExpr_(const Div* op) final { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; FeatureVisitor::VisitExpr_(op); } - void VisitExpr_(const Mod *op) { + void VisitExpr_(const Mod* op) final { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; FeatureVisitor::VisitExpr_(op); diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc deleted file mode 100644 index 5ba29fc0b6e9..000000000000 --- a/src/pass/ir_mutator.cc +++ /dev/null @@ -1,482 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ir_mutator.cc - */ -#include -#include -#include -#include "ir_util.h" - -namespace tvm { -namespace ir { - -IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*) - static FMutateExpr inst; return inst; -} - -IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*) - static FMutateStmt inst; return inst; -} - -inline Array MutateArray(Array arr, IRMutator* m) { - return UpdateArray(arr, [&m](const Expr& e) { return m->Mutate(e); }); -} - -inline Array MutateIterVarArr(Array rdom, IRMutator* m) { - std::vector new_dom(rdom.size()); - bool changed = false; - for (size_t i = 0; i < rdom.size(); i++) { - IterVar v = rdom[i]; - Range r = v->dom; - Expr new_min = m->Mutate(r->min); - Expr new_extent = m->Mutate(r->extent); - if (!r->min.same_as(new_min)) changed = true; - if (!r->extent.same_as(new_extent)) changed = true; - new_dom[i] = IterVarNode::make( - Range::make_by_min_extent(new_min, new_extent), - v->var, v->iter_type, v->thread_tag); - } - if (!changed) { - return rdom; - } else { - return Array(new_dom); - } -} - - -// Mutate Stmt - -#define DISPATCH_TO_MUTATE_STMT(OP) \ - set_dispatch([](const ObjectRef& node, const Stmt& s, IRMutator* m) { \ - return m->Mutate_(static_cast(node.get()), s); \ - }) - -Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) { - Expr value = this->Mutate(op->value); - Stmt body = this->Mutate(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { - return s; - } else { - return AttrStmt::make(op->node, op->attr_key, value, body); - } -} - -Stmt IRMutator::Mutate_(const LetStmt* op, const Stmt& s) { - Expr value = this->Mutate(op->value); - Stmt body = this->Mutate(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { - return s; - } else { - return LetStmt::make(op->var, value, body); - } -} - -Stmt IRMutator::Mutate_(const For* op, const Stmt& s) { - Expr min = this->Mutate(op->min); - Expr extent = this->Mutate(op->extent); - Stmt body = this->Mutate(op->body); - if (min.same_as(op->min) && - extent.same_as(op->extent) && - body.same_as(op->body)) { - return s; - } else { - return For::make( - op->loop_var, min, extent, op->for_type, op->device_api, body); - } -} - -Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) { - IRMutator* m = this; - std::vector new_extents; - bool all_extents_unmodified = true; - for (size_t i = 0; i < op->extents.size(); i++) { - new_extents.push_back(m->Mutate(op->extents[i])); - all_extents_unmodified &= new_extents[i].same_as(op->extents[i]); - } - Stmt body = m->Mutate(op->body); - Expr condition = m->Mutate(op->condition); - Expr new_expr; - if (op->new_expr.defined()) { - new_expr = m->Mutate(op->new_expr); - } - if (all_extents_unmodified && - body.same_as(op->body) && - condition.same_as(op->condition) && - new_expr.same_as(op->new_expr)) { - return s; - } else { - return Allocate::make( - op->buffer_var, op->dtype, - new_extents, condition, body, - new_expr, op->free_function); - } -} - -Stmt IRMutator::Mutate_(const IfThenElse* op, const Stmt& s) { - Expr condition = this->Mutate(op->condition); - Stmt then_case = this->Mutate(op->then_case); - Stmt else_case; - if (op->else_case.defined()) { - else_case = this->Mutate(op->else_case); - } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && - else_case.same_as(op->else_case)) { - return s; - } else { - return IfThenElse::make(condition, then_case, else_case); - } -} - -Stmt IRMutator::Mutate_(const Store* op, const Stmt& s) { - Expr value = this->Mutate(op->value); - Expr index = this->Mutate(op->index); - Expr pred = this->Mutate(op->predicate); - if (value.same_as(op->value) && index.same_as(op->index) && pred.same_as(op->predicate)) { - return s; - } else { - return Store::make(op->buffer_var, value, index, pred); - } -} - -Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) { - auto new_args = MutateArray(op->args, this); - auto new_value = this->Mutate(op->value); - if (op->args.same_as(new_args) && op->value.same_as(new_value)) { - return s; - } else { - return Provide::make(op->func, op->value_index, new_value, new_args); - } -} - -Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) { - IRMutator* m = this; - Region new_bounds; - bool bounds_changed = false; - - // Mutate the bounds - for (size_t i = 0; i < op->bounds.size(); i++) { - Expr old_min = op->bounds[i]->min; - Expr old_extent = op->bounds[i]->extent; - Expr new_min = m->Mutate(old_min); - Expr new_extent = m->Mutate(old_extent); - if (!new_min.same_as(old_min)) bounds_changed = true; - if (!new_extent.same_as(old_extent)) bounds_changed = true; - new_bounds.push_back( - Range::make_by_min_extent(new_min, new_extent)); - } - - Stmt body = m->Mutate(op->body); - Expr condition = m->Mutate(op->condition); - if (!bounds_changed && - body.same_as(op->body) && - condition.same_as(op->condition)) { - return s; - } else { - return Realize::make(op->func, op->value_index, - op->dtype, new_bounds, - condition, body); - } -} - -Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) { - IRMutator* m = this; - Region new_bounds; - bool bounds_changed = false; - - // Mutate the bounds - for (size_t i = 0; i < op->bounds.size(); i++) { - Expr old_min = op->bounds[i]->min; - Expr old_extent = op->bounds[i]->extent; - Expr new_min = m->Mutate(old_min); - Expr new_extent = m->Mutate(old_extent); - if (!new_min.same_as(old_min)) bounds_changed = true; - if (!new_extent.same_as(old_extent)) bounds_changed = true; - new_bounds.push_back( - Range::make_by_min_extent(new_min, new_extent)); - } - - if (!bounds_changed) { - return s; - } else { - return Prefetch::make(op->func, op->value_index, - op->dtype, new_bounds); - } -} - -Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) { - Stmt first = this->Mutate(op->first); - Stmt rest = this->Mutate(op->rest); - if (first.same_as(op->first) && - rest.same_as(op->rest)) { - return s; - } else { - return Block::make(first, rest); - } -} - -Stmt IRMutator::Mutate_(const AssertStmt* op, const Stmt& s) { - Expr condition = this->Mutate(op->condition); - Expr message = this->Mutate(op->message); - Stmt body = this->Mutate(op->body); - - if (condition.same_as(op->condition) && - message.same_as(op->message) && - body.same_as(op->body)) { - return s; - } else { - return AssertStmt::make(condition, message, body); - } -} - -Stmt IRMutator::Mutate_(const ProducerConsumer* op, const Stmt& s) { - Stmt body = this->Mutate(op->body); - if (body.same_as(op->body)) { - return s; - } else { - return ProducerConsumer::make(op->func, op->is_producer, body); - } -} - -Stmt IRMutator::Mutate_(const Evaluate* op, const Stmt& s) { - Expr v = this->Mutate(op->value); - if (v.same_as(op->value)) { - return s; - } else { - return Evaluate::make(v); - } -} - -Stmt IRMutator::Mutate_(const Free* op, const Stmt& s) { - return s; -} - -TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) -.DISPATCH_TO_MUTATE_STMT(LetStmt) -.DISPATCH_TO_MUTATE_STMT(AttrStmt) -.DISPATCH_TO_MUTATE_STMT(IfThenElse) -.DISPATCH_TO_MUTATE_STMT(For) -.DISPATCH_TO_MUTATE_STMT(Allocate) -.DISPATCH_TO_MUTATE_STMT(Store) -.DISPATCH_TO_MUTATE_STMT(Free) -.DISPATCH_TO_MUTATE_STMT(AssertStmt) -.DISPATCH_TO_MUTATE_STMT(ProducerConsumer) -.DISPATCH_TO_MUTATE_STMT(Provide) -.DISPATCH_TO_MUTATE_STMT(Realize) -.DISPATCH_TO_MUTATE_STMT(Block) -.DISPATCH_TO_MUTATE_STMT(Evaluate) -.DISPATCH_TO_MUTATE_STMT(Prefetch); - - -// Mutate Expr - -#define DISPATCH_TO_MUTATE_EXPR(OP) \ - set_dispatch([](const ObjectRef& node, const Expr& e, IRMutator* m) { \ - return m->Mutate_(static_cast(node.get()), e); \ - }) - -Expr IRMutator::Mutate_(const Variable* op, const Expr& e) { - return e; -} - -Expr IRMutator::Mutate_(const Load* op, const Expr& e) { - Expr index = this->Mutate(op->index); - Expr pred = this->Mutate(op->predicate); - if (index.same_as(op->index) && pred.same_as(op->predicate)) { - return e; - } else { - return Load::make(op->dtype, op->buffer_var, index, pred); - } -} - -Expr IRMutator::Mutate_(const Let* op, const Expr& e) { - Expr value = this->Mutate(op->value); - Expr body = this->Mutate(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { - return e; - } else { - return Let::make(op->var, value, body); - } -} - -Expr IRMutator::Mutate_(const Call* op, const Expr& e) { - auto new_args = MutateArray(op->args, this); - if (op->args.same_as(new_args)) { - return e; - } else { - return Call::make(op->dtype, op->name, new_args, op->call_type, - op->func, op->value_index); - } -} - -#define DEFINE_BIOP_EXPR_MUTATE_(OP) \ - Expr IRMutator::Mutate_(const OP* op, const Expr& e) { \ - Expr a = this->Mutate(op->a); \ - Expr b = this->Mutate(op->b); \ - if (a.same_as(op->a) && \ - b.same_as(op->b)) { \ - return e; \ - } else { \ - return OP::make(a, b); \ - } \ - } - -DEFINE_BIOP_EXPR_MUTATE_(Add) -DEFINE_BIOP_EXPR_MUTATE_(Sub) -DEFINE_BIOP_EXPR_MUTATE_(Mul) -DEFINE_BIOP_EXPR_MUTATE_(Div) -DEFINE_BIOP_EXPR_MUTATE_(Mod) -DEFINE_BIOP_EXPR_MUTATE_(FloorDiv) -DEFINE_BIOP_EXPR_MUTATE_(FloorMod) -DEFINE_BIOP_EXPR_MUTATE_(Min) -DEFINE_BIOP_EXPR_MUTATE_(Max) -DEFINE_BIOP_EXPR_MUTATE_(EQ) -DEFINE_BIOP_EXPR_MUTATE_(NE) -DEFINE_BIOP_EXPR_MUTATE_(LT) -DEFINE_BIOP_EXPR_MUTATE_(LE) -DEFINE_BIOP_EXPR_MUTATE_(GT) -DEFINE_BIOP_EXPR_MUTATE_(GE) -DEFINE_BIOP_EXPR_MUTATE_(And) -DEFINE_BIOP_EXPR_MUTATE_(Or) - -Expr IRMutator::Mutate_(const Reduce* op, const Expr& e) { - Array new_axis = MutateIterVarArr(op->axis, this); - Array new_source = MutateArray(op->source, this); - Expr new_cond = this->Mutate(op->condition); - if (op->axis.same_as(new_axis) && - op->source.same_as(new_source) && - op->condition.same_as(new_cond)) { - return e; - } else { - return Reduce::make( - op->combiner, new_source, new_axis, new_cond, op->value_index); - } -} - -Expr IRMutator::Mutate_(const Cast* op, const Expr& e) { - Expr value = this->Mutate(op->value); - if (value.same_as(op->value)) { - return e; - } else { - return Cast::make(op->dtype, value); - } -} - -Expr IRMutator::Mutate_(const Not* op, const Expr& e) { - Expr a = this->Mutate(op->a); - if (a.same_as(op->a)) { - return e; - } else { - return Not::make(a); - } -} - -Expr IRMutator::Mutate_(const Select* op, const Expr& e) { - Expr cond = this->Mutate(op->condition); - Expr t = this->Mutate(op->true_value); - Expr f = this->Mutate(op->false_value); - if (cond.same_as(op->condition) && - t.same_as(op->true_value) && - f.same_as(op->false_value)) { - return e; - } else { - return Select::make(cond, t, f); - } -} - -Expr IRMutator::Mutate_(const Ramp* op, const Expr& e) { - Expr base = this->Mutate(op->base); - Expr stride = this->Mutate(op->stride); - if (base.same_as(op->base) && - stride.same_as(op->stride)) { - return e; - } else { - return Ramp::make(base, stride, op->lanes); - } -} - -Expr IRMutator::Mutate_(const Broadcast* op, const Expr& e) { - Expr value = this->Mutate(op->value); - if (value.same_as(op->value)) { - return e; - } else { - return Broadcast::make(value, op->lanes); - } -} - -Expr IRMutator::Mutate_(const Shuffle* op, const Expr& e) { - auto new_vec = MutateArray(op->vectors, this); - if (new_vec.same_as(op->vectors)) { - return e; - } else { - return Shuffle::make(new_vec, op->indices); - } -} - -#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ - Expr IRMutator::Mutate_(const OP *op, const Expr& e) { \ - return e; \ - } - -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm) - -TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) -.DISPATCH_TO_MUTATE_EXPR(Variable) -.DISPATCH_TO_MUTATE_EXPR(Load) -.DISPATCH_TO_MUTATE_EXPR(Let) -.DISPATCH_TO_MUTATE_EXPR(Call) -.DISPATCH_TO_MUTATE_EXPR(Add) -.DISPATCH_TO_MUTATE_EXPR(Sub) -.DISPATCH_TO_MUTATE_EXPR(Mul) -.DISPATCH_TO_MUTATE_EXPR(Div) -.DISPATCH_TO_MUTATE_EXPR(Mod) -.DISPATCH_TO_MUTATE_EXPR(FloorDiv) -.DISPATCH_TO_MUTATE_EXPR(FloorMod) -.DISPATCH_TO_MUTATE_EXPR(Min) -.DISPATCH_TO_MUTATE_EXPR(Max) -.DISPATCH_TO_MUTATE_EXPR(EQ) -.DISPATCH_TO_MUTATE_EXPR(NE) -.DISPATCH_TO_MUTATE_EXPR(LT) -.DISPATCH_TO_MUTATE_EXPR(LE) -.DISPATCH_TO_MUTATE_EXPR(GT) -.DISPATCH_TO_MUTATE_EXPR(GE) -.DISPATCH_TO_MUTATE_EXPR(And) -.DISPATCH_TO_MUTATE_EXPR(Or) -.DISPATCH_TO_MUTATE_EXPR(Reduce) -.DISPATCH_TO_MUTATE_EXPR(Cast) -.DISPATCH_TO_MUTATE_EXPR(Not) -.DISPATCH_TO_MUTATE_EXPR(Select) -.DISPATCH_TO_MUTATE_EXPR(Ramp) -.DISPATCH_TO_MUTATE_EXPR(Broadcast) -.DISPATCH_TO_MUTATE_EXPR(IntImm) -.DISPATCH_TO_MUTATE_EXPR(UIntImm) -.DISPATCH_TO_MUTATE_EXPR(FloatImm) -.DISPATCH_TO_MUTATE_EXPR(StringImm) -.DISPATCH_TO_MUTATE_EXPR(Shuffle); - -} // namespace ir -} // namespace tvm diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc deleted file mode 100644 index af6ea5252166..000000000000 --- a/src/pass/ir_visitor.cc +++ /dev/null @@ -1,272 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ir_visitor.cc - */ -#include -#include -#include - -namespace tvm { -namespace ir { - -IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*) - static FVisit inst; return inst; -} - -inline void VisitArray(const Array& arr, IRVisitor* v) { - for (size_t i = 0; i < arr.size(); i++) { - v->Visit(arr[i]); - } -} - -inline void VisitRDom(const Array& rdom, IRVisitor* v) { - for (size_t i = 0; i < rdom.size(); i++) { - Range r = rdom[i]->dom; - v->Visit(r->min); - v->Visit(r->extent); - } -} - -void IRVisitor::Visit_(const Variable* op) {} - -void IRVisitor::Visit_(const LetStmt* op) { - this->Visit(op->value); - this->Visit(op->body); -} - -void IRVisitor::Visit_(const AttrStmt* op) { - this->Visit(op->value); - this->Visit(op->body); -} - -void IRVisitor::Visit_(const For* op) { - IRVisitor* v = this; - v->Visit(op->min); - v->Visit(op->extent); - v->Visit(op->body); -} - -void IRVisitor::Visit_(const Allocate* op) { - IRVisitor* v = this; - for (size_t i = 0; i < op->extents.size(); i++) { - v->Visit(op->extents[i]); - } - v->Visit(op->body); - v->Visit(op->condition); - if (op->new_expr.defined()) { - v->Visit(op->new_expr); - } -} - -void IRVisitor::Visit_(const Load* op) { - this->Visit(op->index); - this->Visit(op->predicate); -} - -void IRVisitor::Visit_(const Store* op) { - this->Visit(op->value); - this->Visit(op->index); - this->Visit(op->predicate); -} - -void IRVisitor::Visit_(const IfThenElse* op) { - this->Visit(op->condition); - this->Visit(op->then_case); - if (op->else_case.defined()) { - this->Visit(op->else_case); - } -} - -void IRVisitor::Visit_(const Let* op) { - this->Visit(op->value); - this->Visit(op->body); -} - -void IRVisitor::Visit_(const Free* op) {} - -void IRVisitor::Visit_(const Call* op) { - VisitArray(op->args, this); -} - -#define DEFINE_BINOP_VISIT_(OP) \ - void IRVisitor::Visit_(const OP* op) { \ - this->Visit(op->a); \ - this->Visit(op->b); \ - } - -DEFINE_BINOP_VISIT_(Add) -DEFINE_BINOP_VISIT_(Sub) -DEFINE_BINOP_VISIT_(Mul) -DEFINE_BINOP_VISIT_(Div) -DEFINE_BINOP_VISIT_(Mod) -DEFINE_BINOP_VISIT_(FloorDiv) -DEFINE_BINOP_VISIT_(FloorMod) -DEFINE_BINOP_VISIT_(Min) -DEFINE_BINOP_VISIT_(Max) -DEFINE_BINOP_VISIT_(EQ) -DEFINE_BINOP_VISIT_(NE) -DEFINE_BINOP_VISIT_(LT) -DEFINE_BINOP_VISIT_(LE) -DEFINE_BINOP_VISIT_(GT) -DEFINE_BINOP_VISIT_(GE) -DEFINE_BINOP_VISIT_(And) -DEFINE_BINOP_VISIT_(Or) - -void IRVisitor::Visit_(const Reduce* op) { - VisitRDom(op->axis, this); - VisitArray(op->source, this); - this->Visit(op->condition); -} - -void IRVisitor::Visit_(const Cast* op) { - this->Visit(op->value); -} - -void IRVisitor::Visit_(const Not* op) { - this->Visit(op->a); -} - -void IRVisitor::Visit_(const Select* op) { - this->Visit(op->condition); - this->Visit(op->true_value); - this->Visit(op->false_value); -} - -void IRVisitor::Visit_(const Ramp* op) { - this->Visit(op->base); - this->Visit(op->stride); -} - -void IRVisitor::Visit_(const Shuffle* op) { - for (const auto& elem : op->indices) - this->Visit(elem); - for (const auto& elem : op->vectors) - this->Visit(elem); -} - -void IRVisitor::Visit_(const Broadcast* op) { - this->Visit(op->value); -} - -void IRVisitor::Visit_(const AssertStmt* op) { - this->Visit(op->condition); - this->Visit(op->message); - this->Visit(op->body); -} - -void IRVisitor::Visit_(const ProducerConsumer* op) { - this->Visit(op->body); -} - -void IRVisitor::Visit_(const Provide* op) { - VisitArray(op->args, this); - this->Visit(op->value); -} - -void IRVisitor::Visit_(const Realize* op) { - for (size_t i = 0; i < op->bounds.size(); i++) { - this->Visit(op->bounds[i]->min); - this->Visit(op->bounds[i]->extent); - } - - this->Visit(op->body); - this->Visit(op->condition); -} - -void IRVisitor::Visit_(const Prefetch* op) { - for (size_t i = 0; i < op->bounds.size(); i++) { - this->Visit(op->bounds[i]->min); - this->Visit(op->bounds[i]->extent); - } -} - -void IRVisitor::Visit_(const Block* op) { - this->Visit(op->first); - this->Visit(op->rest); -} - -void IRVisitor::Visit_(const Evaluate* op) { - this->Visit(op->value); -} - -#define DEFINE_OP_NO_VISIT_(OP) \ - void IRVisitor::Visit_(const OP* op) {} - -DEFINE_OP_NO_VISIT_(IntImm) -DEFINE_OP_NO_VISIT_(UIntImm) -DEFINE_OP_NO_VISIT_(FloatImm) -DEFINE_OP_NO_VISIT_(StringImm) - -#define DISPATCH_TO_VISIT(OP) \ - set_dispatch([](const ObjectRef& node, IRVisitor* v) { \ - v->Visit_(static_cast(node.get())); \ - }) - -TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) -.DISPATCH_TO_VISIT(Variable) -.DISPATCH_TO_VISIT(LetStmt) -.DISPATCH_TO_VISIT(AttrStmt) -.DISPATCH_TO_VISIT(IfThenElse) -.DISPATCH_TO_VISIT(For) -.DISPATCH_TO_VISIT(Allocate) -.DISPATCH_TO_VISIT(Load) -.DISPATCH_TO_VISIT(Store) -.DISPATCH_TO_VISIT(Let) -.DISPATCH_TO_VISIT(Free) -.DISPATCH_TO_VISIT(Call) -.DISPATCH_TO_VISIT(Add) -.DISPATCH_TO_VISIT(Sub) -.DISPATCH_TO_VISIT(Mul) -.DISPATCH_TO_VISIT(Div) -.DISPATCH_TO_VISIT(Mod) -.DISPATCH_TO_VISIT(FloorDiv) -.DISPATCH_TO_VISIT(FloorMod) -.DISPATCH_TO_VISIT(Min) -.DISPATCH_TO_VISIT(Max) -.DISPATCH_TO_VISIT(EQ) -.DISPATCH_TO_VISIT(NE) -.DISPATCH_TO_VISIT(LT) -.DISPATCH_TO_VISIT(LE) -.DISPATCH_TO_VISIT(GT) -.DISPATCH_TO_VISIT(GE) -.DISPATCH_TO_VISIT(And) -.DISPATCH_TO_VISIT(Or) -.DISPATCH_TO_VISIT(Reduce) -.DISPATCH_TO_VISIT(Cast) -.DISPATCH_TO_VISIT(Not) -.DISPATCH_TO_VISIT(Select) -.DISPATCH_TO_VISIT(Ramp) -.DISPATCH_TO_VISIT(Shuffle) -.DISPATCH_TO_VISIT(Broadcast) -.DISPATCH_TO_VISIT(AssertStmt) -.DISPATCH_TO_VISIT(ProducerConsumer) -.DISPATCH_TO_VISIT(Provide) -.DISPATCH_TO_VISIT(Realize) -.DISPATCH_TO_VISIT(Block) -.DISPATCH_TO_VISIT(Evaluate) -.DISPATCH_TO_VISIT(IntImm) -.DISPATCH_TO_VISIT(UIntImm) -.DISPATCH_TO_VISIT(FloatImm) -.DISPATCH_TO_VISIT(StringImm) -.DISPATCH_TO_VISIT(Prefetch); - -} // namespace ir -} // namespace tvm diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index ea854f52518e..a10ddd413c20 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -41,6 +41,19 @@ TEST(IRF, Basic) { CHECK_EQ(f(z, 2), 4); } +TEST(IRF, CountVar) { + using namespace tvm; + int n_var = 0; + Var x("x"), y; + + auto z = x + 1 + y + y; + ir::PostOrderVisit(z, [&n_var](const ObjectRef& n) { + if (n.as()) ++n_var; + }); + CHECK_EQ(n_var, 2); +} + + TEST(IRF, ExprTransform) { using namespace tvm; using namespace tvm::ir; diff --git a/tests/cpp/ir_mutator_test.cc b/tests/cpp/ir_mutator_test.cc deleted file mode 100644 index 6f73b5fd06ff..000000000000 --- a/tests/cpp/ir_mutator_test.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include - -namespace { -using namespace tvm; -using namespace tvm::ir; - -// replace variable to constant -class IRVar2Const : public IRMutator { - public: - Var var; - int int_val; - Expr Mutate(Expr expr) final { - static const FMutateExpr& f = IRVar2Const::vtable_expr(); - return (f.can_dispatch(expr) ? - f(expr, expr, this) : IRMutator::Mutate(expr)); - } - static FMutateExpr &vtable_expr(); -}; - -// implement vtable -IRMutator::FMutateExpr &IRVar2Const::vtable_expr() { // NOLINT(*) - static FMutateExpr inst; return inst; -} - -TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr) -.set_dispatch([](const ObjectRef& ref, const Expr &e, IRMutator* m) { - IRVar2Const* vm = static_cast(m); - if (e.same_as(vm->var)) { - return Expr(IntImm::make(DataType::Int(32), vm->int_val)); - } else { - return e; - } - }); - -} // namespace - -TEST(IRMutator, Basic) { - using namespace tvm::ir; - using namespace tvm; - Var x("x"), y; - auto z = x + y; - IRVar2Const mu; - mu.var = y; - mu.int_val = 10; - auto zz = mu.Mutate(z); - std::ostringstream os; - os << zz; - CHECK(os.str() == "(x + 10)"); -} - -int main(int argc, char ** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/ir_visitor_test.cc b/tests/cpp/ir_visitor_test.cc deleted file mode 100644 index 1f34b2549d0d..000000000000 --- a/tests/cpp/ir_visitor_test.cc +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include - -TEST(IRVisitor, CountVar) { - using namespace tvm; - int n_var = 0; - Var x("x"), y; - - auto z = x + 1 + y + y; - ir::PostOrderVisit(z, [&n_var](const ObjectRef& n) { - if (n.as()) ++n_var; - }); - CHECK_EQ(n_var, 2); -} - -int main(int argc, char ** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -}