From 9eef824113858e900bec6c1d33a92972f6834e5b Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 31 Dec 2019 14:30:51 -0800 Subject: [PATCH] [IR] Unify approach to Visitor/Mutator under Functor IRMutator and IRVisitor were the main data structures for doing low level IR visiting. As the project evolves, we start to introduce more powerful variants such as StmtFunctor and ExprFunctor. This PR brings new classes that allows us to migrate the visitor mutator to be sub-class of these functors. List of changes: - Create separate class for ExprMutator and StmtMutator, following convention used in relay. - Introduce copy-on-write to StmtMutator that can later benefit the statement mutations if we use move semantics and keep a single copy of stmt. - Move two generic visit mutate util to use the new classes. We will send followup PRs to migrate the existing passes that use the legacy visitors to the new one. --- include/tvm/ir_functor_ext.h | 259 +++++++++++++ include/tvm/ir_mutator.h | 3 +- include/tvm/node/container.h | 42 ++ src/pass/ir_functor.cc | 717 +++++++++++++++++++++++++++++++++++ src/pass/ir_mutator.cc | 53 --- src/pass/ir_visitor.cc | 20 - tests/cpp/ir_functor_test.cc | 93 ++++- 7 files changed, 1112 insertions(+), 75 deletions(-) create mode 100644 src/pass/ir_functor.cc diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 9b2632f87b3c..f7a0f2a3b61c 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -287,6 +287,265 @@ class StmtFunctor { #undef EXPR_FUNCTOR_DEFAULT #undef STMT_FUNCTOR_DEFAULT +/*! + * \brief ExprVisitor + */ +class TVM_DLL ExprVisitor : + public ExprFunctor { + public: + using ExprFunctor::operator(); + + protected: + using ExprFunctor::VisitExpr; + // list of functions to override. + void VisitExpr_(const Variable* op) override; + void VisitExpr_(const Load* op) override; + void VisitExpr_(const Let* op) override; + void VisitExpr_(const Call* op) override; + void VisitExpr_(const Add* op) override; + void VisitExpr_(const Sub* op) override; + void VisitExpr_(const Mul* op) override; + void VisitExpr_(const Div* op) override; + void VisitExpr_(const Mod* op) override; + void VisitExpr_(const FloorDiv* op) override; + void VisitExpr_(const FloorMod* op) override; + void VisitExpr_(const Min* op) override; + void VisitExpr_(const Max* op) override; + void VisitExpr_(const EQ* op) override; + void VisitExpr_(const NE* op) override; + void VisitExpr_(const LT* op) override; + void VisitExpr_(const LE* op) override; + void VisitExpr_(const GT* op) override; + void VisitExpr_(const GE* op) override; + void VisitExpr_(const And* op) override; + void VisitExpr_(const Or* op) override; + void VisitExpr_(const Reduce* op) override; + void VisitExpr_(const Cast* op) override; + void VisitExpr_(const Not* op) override; + void VisitExpr_(const Select* op) override; + void VisitExpr_(const Ramp* op) override; + void VisitExpr_(const Broadcast* op) override; + void VisitExpr_(const Shuffle* op) override; + void VisitExpr_(const IntImm* op) override; + void VisitExpr_(const UIntImm* op) override; + void VisitExpr_(const FloatImm* op) override; + void VisitExpr_(const StringImm* op) override; +}; + +/*! + * \brief ExprMutator that mutates expressions. + */ +class TVM_DLL ExprMutator : + protected ExprFunctor { + public: + using ExprFunctor::operator(); + + protected: + using ExprFunctor::VisitExpr; + // list of functions to override. + Expr VisitExpr_(const Variable* op) override; + Expr VisitExpr_(const Load* op) override; + Expr VisitExpr_(const Let* op) override; + Expr VisitExpr_(const Call* op) override; + Expr VisitExpr_(const Add* op) override; + Expr VisitExpr_(const Sub* op) override; + Expr VisitExpr_(const Mul* op) override; + Expr VisitExpr_(const Div* op) override; + Expr VisitExpr_(const Mod* op) override; + Expr VisitExpr_(const FloorDiv* op) override; + Expr VisitExpr_(const FloorMod* op) override; + Expr VisitExpr_(const Min* op) override; + Expr VisitExpr_(const Max* op) override; + Expr VisitExpr_(const EQ* op) override; + Expr VisitExpr_(const NE* op) override; + Expr VisitExpr_(const LT* op) override; + Expr VisitExpr_(const LE* op) override; + Expr VisitExpr_(const GT* op) override; + Expr VisitExpr_(const GE* op) override; + Expr VisitExpr_(const And* op) override; + Expr VisitExpr_(const Or* op) override; + Expr VisitExpr_(const Reduce* op) override; + Expr VisitExpr_(const Cast* op) override; + Expr VisitExpr_(const Not* op) override; + Expr VisitExpr_(const Select* op) override; + Expr VisitExpr_(const Ramp* op) override; + Expr VisitExpr_(const Broadcast* op) override; + Expr VisitExpr_(const Shuffle* op) override; + Expr VisitExpr_(const IntImm* op) override; + Expr VisitExpr_(const UIntImm* op) override; + Expr VisitExpr_(const FloatImm* op) override; + Expr VisitExpr_(const StringImm* op) override; +}; + +/*! + * \brief StmtVisitor. + */ +class TVM_DLL StmtVisitor : + protected StmtFunctor { + public: + using StmtFunctor::operator(); + + protected: + using StmtFunctor::VisitStmt; + /*! + * \brief Visitor to Exprs, can be overriden + * to do recursive changes to Exprs. + * \note A common pattern is to call ExprVisitor here, + * or have a class sub-class both StmtVisitor and ExprVisitor + * and redirect Visit to ExprMutator::VisitExpr(Expr) + */ + virtual void VisitExpr(const Expr& e) {} + // statement visitor + void VisitStmt_(const AttrStmt* op) override; + void VisitStmt_(const IfThenElse* op) override; + void VisitStmt_(const LetStmt* op) override; + void VisitStmt_(const For* op) override; + void VisitStmt_(const Allocate* op) override; + void VisitStmt_(const Store* op) override; + void VisitStmt_(const Free* op) override; + void VisitStmt_(const AssertStmt* op) override; + void VisitStmt_(const ProducerConsumer* op) override; + void VisitStmt_(const Provide* op) override; + void VisitStmt_(const Realize* op) override; + void VisitStmt_(const Prefetch* op) override; + void VisitStmt_(const Block* op) override; + void VisitStmt_(const Evaluate* op) override; +}; + +/*! + * \brief StmtMutator that mutates the statements. + */ +class TVM_DLL StmtMutator : + protected StmtFunctor { + public: + /*! + * \brief Mutate stmt. + * \param stmt The input statement to be mutated. + * \return The result of the call + * \note It is important that stmt is passed by value. + * so copy on write can be triggered correctly. + * do mutator(std::move(stmt)) or when copy elison is triggered. + */ + Stmt operator()(Stmt stmt) { + allow_copy_on_write_ = true; + return VisitStmt(stmt); + } + + protected: + // We perform copy on write optimizations on the StmtMutator + // so that an unique copy of parent can be mutated inplace + // when some of its children changed. + // We only do such optimization for Stmt nests(instead of Exprs) for now + // as Stmt's parent state is more likely remain unchanged when one of + // its child block changes. + /*! + * \brief Internal state to indicate whether copy on write is enabled. + * COW is enabled iff all the parents of the node are unique. + */ + bool allow_copy_on_write_{false}; + /*! + * \brief Perform copy on write on node. + * + * If CopyOnWrite is allowed, directly return + * a strong reference to the node container. + * Otherwise, return a copy of the node. + * + * \return The result object pointer. + */ + template + ObjectPtr CopyOnWrite(const TNode* node) { + if (allow_copy_on_write_) { + // return the old node. + return runtime::GetObjectPtr(const_cast(node)); + } else { + // Make a new copy of the node. + // need to rely on the default copy constructor + return runtime::make_object(*node); + } + } + /*! + * \brief Internal mutator that everyone calls. + * \note To override mutate's behavior, override VisitExpr instead. + * \param stmt The input stmt. + * \return The mutated results. + */ + Stmt VisitStmt(const Stmt& stmt) override { + if (allow_copy_on_write_ && !stmt.unique()) { + allow_copy_on_write_ = false; + Stmt ret = StmtFunctor::VisitStmt(stmt); + allow_copy_on_write_ = true; + return ret; + } else { + return StmtFunctor::VisitStmt(stmt); + } + } + /*! + * \brief Visitor to Exprs, can be overriden + * to do recursive changes to Exprs. + * \note A common pattern is to call ExprMutator here, + * or have a class sub-class both StmtMutator and ExprMutator + * and redirect Mutate to ExprMutator::Mutate(Expr) + */ + virtual Expr VisitExpr(const Expr& e) { + return e; + } + // statement visitor + Stmt VisitStmt_(const AttrStmt* op) override; + Stmt VisitStmt_(const IfThenElse* op) override; + Stmt VisitStmt_(const LetStmt* op) override; + Stmt VisitStmt_(const For* op) override; + Stmt VisitStmt_(const Allocate* op) override; + Stmt VisitStmt_(const Store* op) override; + Stmt VisitStmt_(const Free* op) override; + Stmt VisitStmt_(const AssertStmt* op) override; + Stmt VisitStmt_(const ProducerConsumer* op) override; + Stmt VisitStmt_(const Provide* op) override; + Stmt VisitStmt_(const Realize* op) override; + Stmt VisitStmt_(const Prefetch* op) override; + Stmt VisitStmt_(const Block* op) override; + Stmt VisitStmt_(const Evaluate* op) override; + // internal helper. + class Internal; +}; + +/*! + * \brief Visitor that recursively visit stmts and exprs on them. + */ +class StmtExprVisitor : + public StmtVisitor, + public ExprVisitor { + public: + using StmtVisitor::operator(); + using ExprVisitor::operator(); + + protected: + using StmtVisitor::VisitStmt; + using ExprVisitor::VisitExpr; + + void VisitExpr(const Expr& e) override { + return ExprVisitor::VisitExpr(e); + } +}; + +/*! + * \brief Mutator that recursively mutates stmts and exprs on them. + */ +class StmtExprMutator : + public StmtMutator, + public ExprMutator { + public: + using StmtMutator::operator(); + using ExprMutator::operator(); + + protected: + using StmtMutator::VisitExpr; + using ExprMutator::VisitExpr; + + Expr VisitExpr(const Expr& e) override { + return ExprMutator::VisitExpr(e); + } +}; + } // namespace ir } // namespace tvm #endif // TVM_IR_FUNCTOR_EXT_H_ diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 5460ae0f4ba9..702cea3ce8fd 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -123,6 +123,7 @@ class TVM_DLL IRMutator { virtual Expr Mutate_(const Shuffle* op, const Expr& e); }; + /*! * \brief recursively visit the ir in post DFS order node, and transform it * @@ -138,7 +139,7 @@ class TVM_DLL IRMutator { * If it is not empty, preorder/postorder will only be called * when the IRNode's type key is in the list. */ -Stmt IRTransform(const Stmt& node, +Stmt IRTransform(Stmt node, const runtime::PackedFunc& preorder, const runtime::PackedFunc& postorder, const Array& only_enable = {}); diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index d20fb288039c..0d5ab376d50a 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -284,6 +284,48 @@ class Array : public ObjectRef { inline bool empty() const { return size() == 0; } + /*! + * \brief Helper function to apply fmutate to mutate an array. + * \param fmutate The transformation function T -> T. + * \tparam F the type of the mutation function. + * \note This function performs copy on write optimization. + */ + template + inline void MutateByApply(F fmutate) { + ArrayNode* ptr = static_cast(data_.get()); + if (ptr == nullptr) return; + if (data_.unique()) { + // Copy on write optimization. + // Perform inplace update because this is an unique copy. + for (size_t i = 0; i < ptr->data.size(); ++i) { + // It is important to use move here + // to make prevent the element's ref count from increasing + // so fmutate itself can perform copy-on-write optimization + T old_elem = DowncastNoCheck(std::move(ptr->data[i])); + T new_elem = fmutate(std::move(old_elem)); + ptr->data[i] = std::move(new_elem); + } + } else { + // lazily trigger copy if there is element change. + ObjectPtr copy; + for (size_t i = 0; i < ptr->data.size(); ++i) { + T old_elem = DowncastNoCheck(ptr->data[i]); + T new_elem = fmutate(old_elem); + if (!new_elem.same_as(ptr->data[i])) { + // copy the old array + if (copy == nullptr) { + copy = runtime::make_object(*ptr); + } + copy->data[i] = std::move(new_elem); + } + } + // replace the data with the new copy. + if (copy != nullptr) { + data_ = std::move(copy); + } + } + } + /*! \brief specify container node */ using ContainerType = ArrayNode; diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc new file mode 100644 index 000000000000..079da756ce92 --- /dev/null +++ b/src/pass/ir_functor.cc @@ -0,0 +1,717 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file ir_functor.cc + */ +#include +#include +#include + +namespace tvm { +namespace ir { + +// visitor to implement apply +class IRApplyVisit : + public StmtExprVisitor { + public: + explicit IRApplyVisit(std::function f) : f_(f) {} + + void VisitExpr(const Expr& node) final { + if (visited_.count(node.get()) != 0) return; + visited_.insert(node.get()); + ExprVisitor::VisitExpr(node); + f_(node); + } + + void VisitStmt(const Stmt& node) final { + if (visited_.count(node.get()) != 0) return; + visited_.insert(node.get()); + StmtVisitor::VisitStmt(node); + f_(node); + } + + private: + std::function f_; + std::unordered_set visited_; +}; + +void PostOrderVisit(const ObjectRef& node, + std::function fvisit) { + if (node.as()) { + IRApplyVisit visitor(fvisit); + visitor(Downcast(node)); + } else { + IRApplyVisit visitor(fvisit); + visitor(Downcast(node)); + } +} + +class IRTransformer final : + public StmtExprMutator { + public: + IRTransformer(const runtime::PackedFunc& f_preorder, + const runtime::PackedFunc& f_postorder, + const std::unordered_set& only_enable) + : f_preorder_(f_preorder), + f_postorder_(f_postorder), + only_enable_(only_enable) { + } + Stmt VisitStmt(const Stmt& stmt) final { + return MutateInternal(stmt, [this](const Stmt& s) { + return StmtMutator::VisitStmt(s); + }); + } + Expr VisitExpr(const Expr& expr) final { + return MutateInternal(expr, [this](const Expr& e) { + return ExprMutator::VisitExpr(e); + }); + } + + private: + template + T MutateInternal(const T& node, F fmutate) { + if (only_enable_.size() && + !only_enable_.count(node->type_index())) { + return fmutate(node); + } + if (f_preorder_ != nullptr) { + T pre = f_preorder_(node); + if (pre.defined()) return pre; + } + T new_node = fmutate(node); + if (f_postorder_ != nullptr) { + T post = f_postorder_(new_node); + if (post.defined()) return post; + } + return new_node; + } + // The functions + const runtime::PackedFunc& f_preorder_; + const runtime::PackedFunc& f_postorder_; + // type indices enabled. + const std::unordered_set& only_enable_; +}; + +Stmt IRTransform(Stmt ir_node, + const runtime::PackedFunc& f_preorder, + const runtime::PackedFunc& f_postorder, + const Array& only_enable) { + std::unordered_set only_type_index; + for (Expr s : only_enable) { + only_type_index.insert(Object::TypeKey2Index(s.as()->value.c_str())); + } + IRTransformer transform(f_preorder, f_postorder, only_type_index); + return transform(std::move(ir_node)); +} + +// Implementation of Visitors +template +inline void VisitArray(const Array& arr, F fvisit) { + for (size_t i = 0; i < arr.size(); i++) { + fvisit(arr[i]); + } +} + +void StmtVisitor::VisitStmt_(const LetStmt* op) { + this->VisitExpr(op->value); + this->VisitStmt(op->body); +} + +void StmtVisitor::VisitStmt_(const AttrStmt* op) { + this->VisitExpr(op->value); + this->VisitStmt(op->body); +} + +void StmtVisitor::VisitStmt_(const For* op) { + this->VisitExpr(op->min); + this->VisitExpr(op->extent); + this->VisitStmt(op->body); +} + +void StmtVisitor::VisitStmt_(const Allocate* op) { + VisitArray(op->extents, [this](const Expr& e) { this->VisitExpr(e); }); + this->VisitStmt(op->body); + this->VisitExpr(op->condition); + if (op->new_expr.defined()) { + this->VisitExpr(op->new_expr); + } +} + +void StmtVisitor::VisitStmt_(const Store* op) { + this->VisitExpr(op->value); + this->VisitExpr(op->index); + this->VisitExpr(op->predicate); +} + +void StmtVisitor::VisitStmt_(const IfThenElse* op) { + this->VisitExpr(op->condition); + this->VisitStmt(op->then_case); + if (op->else_case.defined()) { + this->VisitStmt(op->else_case); + } +} + +void StmtVisitor::VisitStmt_(const Free* op) {} + +void StmtVisitor::VisitStmt_(const AssertStmt* op) { + this->VisitExpr(op->condition); + this->VisitExpr(op->message); + this->VisitStmt(op->body); +} + +void StmtVisitor::VisitStmt_(const ProducerConsumer* op) { + this->VisitStmt(op->body); +} + +void StmtVisitor::VisitStmt_(const Provide* op) { + VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); }); + this->VisitExpr(op->value); +} + +void StmtVisitor::VisitStmt_(const Realize* op) { + VisitArray(op->bounds, [this](const Range& r) { + this->VisitExpr(r->min); + this->VisitExpr(r->extent); + }); + this->VisitStmt(op->body); + this->VisitExpr(op->condition); +} + +void StmtVisitor::VisitStmt_(const Prefetch* op) { + VisitArray(op->bounds, [this](const Range& r) { + this->VisitExpr(r->min); + this->VisitExpr(r->extent); + }); +} + +void StmtVisitor::VisitStmt_(const Block* op) { + this->VisitStmt(op->first); + this->VisitStmt(op->rest); +} + +void StmtVisitor::VisitStmt_(const Evaluate* op) { + this->VisitExpr(op->value); +} + +void ExprVisitor::VisitExpr_(const Variable* op) {} + +void ExprVisitor::VisitExpr_(const Load* op) { + this->VisitExpr(op->index); + this->VisitExpr(op->predicate); +} + +void ExprVisitor::VisitExpr_(const Let* op) { + this->VisitExpr(op->value); + this->VisitExpr(op->body); +} + +void ExprVisitor::VisitExpr_(const Call* op) { + VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); }); +} + +#define DEFINE_BINOP_VISIT_(OP) \ + void ExprVisitor::VisitExpr_(const OP* op) { \ + this->VisitExpr(op->a); \ + this->VisitExpr(op->b); \ + } + +DEFINE_BINOP_VISIT_(Add); +DEFINE_BINOP_VISIT_(Sub); +DEFINE_BINOP_VISIT_(Mul); +DEFINE_BINOP_VISIT_(Div); +DEFINE_BINOP_VISIT_(Mod); +DEFINE_BINOP_VISIT_(FloorDiv); +DEFINE_BINOP_VISIT_(FloorMod); +DEFINE_BINOP_VISIT_(Min); +DEFINE_BINOP_VISIT_(Max); +DEFINE_BINOP_VISIT_(EQ); +DEFINE_BINOP_VISIT_(NE); +DEFINE_BINOP_VISIT_(LT); +DEFINE_BINOP_VISIT_(LE); +DEFINE_BINOP_VISIT_(GT); +DEFINE_BINOP_VISIT_(GE); +DEFINE_BINOP_VISIT_(And); +DEFINE_BINOP_VISIT_(Or); + +void ExprVisitor::VisitExpr_(const IntImm* op) {} +void ExprVisitor::VisitExpr_(const UIntImm* op) {} +void ExprVisitor::VisitExpr_(const FloatImm* op) {} +void ExprVisitor::VisitExpr_(const StringImm* op) {} + +void ExprVisitor::VisitExpr_(const Reduce* op) { + VisitArray(op->axis, [this](const IterVar& r) { + this->VisitExpr(r->dom->min); + this->VisitExpr(r->dom->extent); + }); + VisitArray(op->source, [this](const Expr& e) { this->VisitExpr(e); }); + this->VisitExpr(op->condition); +} + +void ExprVisitor::VisitExpr_(const Cast* op) { + this->VisitExpr(op->value); +} + +void ExprVisitor::VisitExpr_(const Not* op) { + this->VisitExpr(op->a); +} + +void ExprVisitor::VisitExpr_(const Select* op) { + this->VisitExpr(op->condition); + this->VisitExpr(op->true_value); + this->VisitExpr(op->false_value); +} + +void ExprVisitor::VisitExpr_(const Ramp* op) { + this->VisitExpr(op->base); + this->VisitExpr(op->stride); +} + +void ExprVisitor::VisitExpr_(const Shuffle* op) { + VisitArray(op->indices, [this](const Expr& e) { this->VisitExpr(e); }); + VisitArray(op->vectors, [this](const Expr& e) { this->VisitExpr(e); }); +} + +void ExprVisitor::VisitExpr_(const Broadcast* op) { + this->VisitExpr(op->value); +} + +// Implementation of mutators +template +inline Array MutateArray(const Array& arr, + F fmutate, + bool allow_copy_on_write = false) { + if (allow_copy_on_write) { + // if we allow copy on write, we can directly + // call the inplace mutate function. + const_cast&>(arr).MutateByApply(fmutate); + return arr; + } else { + Array copy = arr; + copy.MutateByApply(fmutate); + return copy; + } +} + +class StmtMutator::Internal { + public: + static Array Mutate(StmtMutator* self, const Array& arr) { + auto fmutate = [self](const Expr& e) { return self->VisitExpr(e); }; + return MutateArray(arr, fmutate, self->allow_copy_on_write_); + } + + static Array Mutate(StmtMutator* self, const Array& arr) { + auto fmutate = [self](const Stmt& s) { return self->VisitStmt(s); }; + return MutateArray(arr, fmutate, self->allow_copy_on_write_); + } + + static Array Mutate(StmtMutator* self, const Array& arr) { + auto fmutate = [self](const Range& r) { + Expr min = self->VisitExpr(r->min); + Expr extent = self->VisitExpr(r->extent); + if (min.same_as(r->min) && extent.same_as(r->extent)) { + return r; + } else { + return Range::make_by_min_extent(min, extent); + } + }; + return MutateArray(arr, fmutate, self->allow_copy_on_write_); + } +}; + +Stmt StmtMutator::VisitStmt_(const AttrStmt* op) { + Expr value = this->VisitExpr(op->value); + Stmt body = this->VisitStmt(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->value = std::move(value); + n->body = std::move(body); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const LetStmt* op) { + Expr value = this->VisitExpr(op->value); + Stmt body = this->VisitStmt(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->value = std::move(value); + n->body = std::move(body); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const For* op) { + Expr min = this->VisitExpr(op->min); + Expr extent = this->VisitExpr(op->extent); + Stmt body = this->VisitStmt(op->body); + if (min.same_as(op->min) && + extent.same_as(op->extent) && + body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->min = std::move(min); + n->extent = std::move(extent); + n->body = std::move(body); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Allocate* op) { + Array extents = Internal::Mutate(this, op->extents); + Stmt body = this->VisitStmt(op->body); + Expr condition = this->VisitExpr(op->condition); + Expr new_expr; + if (op->new_expr.defined()) { + new_expr = this->VisitExpr(op->new_expr); + } + if (extents.same_as(op->extents) && + body.same_as(op->body) && + condition.same_as(op->condition) && + new_expr.same_as(op->new_expr)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->extents = std::move(extents); + n->body = std::move(body); + n->condition = std::move(condition); + n->new_expr = std::move(new_expr); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const IfThenElse* op) { + Expr condition = this->VisitExpr(op->condition); + Stmt then_case = this->VisitStmt(op->then_case); + Stmt else_case; + if (op->else_case.defined()) { + else_case = this->VisitStmt(op->else_case); + } + if (condition.same_as(op->condition) && + then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->condition = std::move(condition); + n->then_case = std::move(then_case); + n->else_case = std::move(then_case); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Store* op) { + Expr value = this->VisitExpr(op->value); + Expr index = this->VisitExpr(op->index); + Expr predicate = this->VisitExpr(op->predicate); + if (value.same_as(op->value) && + index.same_as(op->index) && + predicate.same_as(op->predicate)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->value = std::move(value); + n->index = std::move(index); + n->predicate = std::move(predicate); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Provide* op) { + Array args = Internal::Mutate(this, op->args); + Expr value = this->VisitExpr(op->value); + if (args.same_as(op->args) && + value.same_as(op->value)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->args = std::move(args); + n->value = std::move(value); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Realize* op) { + Region bounds = Internal::Mutate(this, op->bounds); + Stmt body = this->VisitStmt(op->body); + Expr condition = this->VisitExpr(op->condition); + if (bounds.same_as(op->bounds) && + body.same_as(op->body) && + condition.same_as(op->condition)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->bounds = std::move(bounds); + n->body = std::move(body); + n->condition = std::move(condition); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Prefetch* op) { + Region bounds = Internal::Mutate(this, op->bounds); + if (bounds.same_as(op->bounds)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->bounds = std::move(bounds); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Block* op) { + Stmt first = this->VisitStmt(op->first); + Stmt rest = this->VisitStmt(op->rest); + if (first.same_as(op->first) && + rest.same_as(op->rest)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->first = std::move(first); + n->rest = std::move(rest); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const AssertStmt* op) { + Expr condition = this->VisitExpr(op->condition); + Expr message = this->VisitExpr(op->message); + Stmt body = this->VisitStmt(op->body); + + if (condition.same_as(op->condition) && + message.same_as(op->message) && + body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->condition = std::move(condition); + n->message = std::move(message); + n->body = std::move(body); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const ProducerConsumer* op) { + Stmt body = this->VisitStmt(op->body); + if (body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->body = std::move(body); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Evaluate* op) { + Expr value = this->VisitExpr(op->value); + if (value.same_as(op->value)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->value = std::move(value); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Free* op) { + return GetRef(op); +} + + +Expr ExprMutator::VisitExpr_(const Variable* op) { + return GetRef(op); +} + +Expr ExprMutator::VisitExpr_(const Load* op) { + Expr index = this->VisitExpr(op->index); + Expr predicate = this->VisitExpr(op->predicate); + if (index.same_as(op->index) && predicate.same_as(op->predicate)) { + return GetRef(op); + } else { + return Load::make(op->dtype, op->buffer_var, index, predicate); + } +} + +Expr ExprMutator::VisitExpr_(const Let* op) { + Expr value = this->VisitExpr(op->value); + Expr body = this->VisitExpr(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return GetRef(op); + } else { + return Let::make(op->var, value, body); + } +} + +Expr ExprMutator::VisitExpr_(const Call* op) { + auto fmutate = [this](const Expr& e) { return this->VisitExpr(e); }; + Array args = MutateArray(op->args, fmutate); + + if (args.same_as(op->args)) { + return GetRef(op); + } else { + return Call::make(op->dtype, + op->name, + args, + op->call_type, + op->func, + op->value_index); + } +} + +#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ + Expr ExprMutator::VisitExpr_(const OP *op) { \ + return GetRef(op); \ + } + +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm) + +#define DEFINE_BIOP_EXPR_MUTATE_(OP) \ + Expr ExprMutator::VisitExpr_(const OP* op) { \ + Expr a = this->VisitExpr(op->a); \ + Expr b = this->VisitExpr(op->b); \ + if (a.same_as(op->a) && \ + b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + return OP::make(a, b); \ + } \ + } + +DEFINE_BIOP_EXPR_MUTATE_(Add); +DEFINE_BIOP_EXPR_MUTATE_(Sub); +DEFINE_BIOP_EXPR_MUTATE_(Mul); +DEFINE_BIOP_EXPR_MUTATE_(Div); +DEFINE_BIOP_EXPR_MUTATE_(Mod); +DEFINE_BIOP_EXPR_MUTATE_(FloorDiv); +DEFINE_BIOP_EXPR_MUTATE_(FloorMod); +DEFINE_BIOP_EXPR_MUTATE_(Min); +DEFINE_BIOP_EXPR_MUTATE_(Max); +DEFINE_BIOP_EXPR_MUTATE_(EQ); +DEFINE_BIOP_EXPR_MUTATE_(NE); +DEFINE_BIOP_EXPR_MUTATE_(LT); +DEFINE_BIOP_EXPR_MUTATE_(LE); +DEFINE_BIOP_EXPR_MUTATE_(GT); +DEFINE_BIOP_EXPR_MUTATE_(GE); +DEFINE_BIOP_EXPR_MUTATE_(And); +DEFINE_BIOP_EXPR_MUTATE_(Or); + +Expr ExprMutator::VisitExpr_(const Reduce* op) { + auto fitervar = [this](const IterVar& v) { + Range r = v->dom; + Expr min = this->VisitExpr(r->min); + Expr extent = this->VisitExpr(r->extent); + if (min.same_as(r->min) && + extent.same_as(r->extent)) { + return v; + } else { + return IterVarNode::make( + Range::make_by_min_extent(min, extent), + v->var, v->iter_type, v->thread_tag); + } + }; + Array axis = MutateArray(op->axis, fitervar); + + auto fexpr = [this](const Expr& e) { return this->VisitExpr(e); }; + Array source = MutateArray(op->source, fexpr); + + Expr condition = this->VisitExpr(op->condition); + + if (axis.same_as(op->axis) && + source.same_as(op->source) && + condition.same_as(op->condition)) { + return GetRef(op); + } else { + return Reduce::make( + op->combiner, source, axis, condition, op->value_index); + } +} + +Expr ExprMutator::VisitExpr_(const Cast* op) { + Expr value = this->VisitExpr(op->value); + if (value.same_as(op->value)) { + return GetRef(op); + } else { + return Cast::make(op->dtype, value); + } +} + +Expr ExprMutator::VisitExpr_(const Not* op) { + Expr a = this->VisitExpr(op->a); + if (a.same_as(op->a)) { + return GetRef(op); + } else { + return Not::make(a); + } +} + +Expr ExprMutator::VisitExpr_(const Select* op) { + Expr condition = this->VisitExpr(op->condition); + Expr true_value = this->VisitExpr(op->true_value); + Expr false_value = this->VisitExpr(op->false_value); + if (condition.same_as(op->condition) && + true_value.same_as(op->true_value) && + false_value.same_as(op->false_value)) { + return GetRef(op); + } else { + return Select::make(condition, true_value, false_value); + } +} + +Expr ExprMutator::VisitExpr_(const Ramp* op) { + Expr base = this->VisitExpr(op->base); + Expr stride = this->VisitExpr(op->stride); + if (base.same_as(op->base) && + stride.same_as(op->stride)) { + return GetRef(op); + } else { + return Ramp::make(base, stride, op->lanes); + } +} + +Expr ExprMutator::VisitExpr_(const Broadcast* op) { + Expr value = this->VisitExpr(op->value); + if (value.same_as(op->value)) { + return GetRef(op); + } else { + return Broadcast::make(value, op->lanes); + } +} + +Expr ExprMutator::VisitExpr_(const Shuffle* op) { + auto fexpr = [this](const Expr& e) { return this->VisitExpr(e); }; + auto vectors = MutateArray(op->vectors, fexpr); + if (vectors.same_as(op->vectors)) { + return GetRef(op); + } else { + return Shuffle::make(vectors, op->indices); + } +} + +} // namespace ir +} // namespace tvm diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index b300989dd2fd..5ba29fc0b6e9 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -28,59 +28,6 @@ namespace tvm { namespace ir { -class IRTransformer final : public IRMutator { - public: - IRTransformer(const runtime::PackedFunc& f_preorder, - const runtime::PackedFunc& f_postorder, - const std::unordered_set& only_enable) - : f_preorder_(f_preorder), - f_postorder_(f_postorder), - only_enable_(only_enable) { - } - Stmt Mutate(Stmt stmt) final { - return MutateInternal(stmt); - } - Expr Mutate(Expr expr) final { - return MutateInternal(expr); - } - - private: - template - T MutateInternal(T node) { - if (only_enable_.size() && - !only_enable_.count(node->type_index())) { - return IRMutator::Mutate(node); - } - if (f_preorder_ != nullptr) { - T pre = f_preorder_(node); - if (pre.defined()) return pre; - } - node = IRMutator::Mutate(node); - if (f_postorder_ != nullptr) { - T post = f_postorder_(node); - if (post.defined()) return post; - } - return node; - } - // The functions - const runtime::PackedFunc& f_preorder_; - const runtime::PackedFunc& f_postorder_; - // type indices enabled. - const std::unordered_set& only_enable_; -}; - -Stmt IRTransform(const Stmt& ir_node, - const runtime::PackedFunc& f_preorder, - const runtime::PackedFunc& f_postorder, - const Array& only_enable) { - std::unordered_set only_type_index; - for (Expr s : only_enable) { - only_type_index.insert(Object::TypeKey2Index(s.as()->value.c_str())); - } - return IRTransformer(f_preorder, f_postorder, only_type_index) - .Mutate(ir_node); -} - IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*) static FMutateExpr inst; return inst; } diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index 467cd5de2ef7..af6ea5252166 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -26,26 +26,6 @@ namespace tvm { namespace ir { -// visitor to implement apply -class IRApplyVisit : public IRVisitor { - public: - explicit IRApplyVisit(std::function f) : f_(f) {} - - void Visit(const ObjectRef& node) final { - if (visited_.count(node.get()) != 0) return; - visited_.insert(node.get()); - IRVisitor::Visit(node); - f_(node); - } - - private: - std::function f_; - std::unordered_set visited_; -}; - -void PostOrderVisit(const ObjectRef& node, std::function fvisit) { - IRApplyVisit(fvisit).Visit(node); -} IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*) static FVisit inst; return inst; diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 5636958a5b26..5f0860153587 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -31,7 +31,6 @@ TEST(IRF, Basic) { auto z = x + 1; NodeFunctor f; - LOG(INFO) << "x"; f.set_dispatch([](const ObjectRef& n, int b) { return b; }); @@ -101,6 +100,98 @@ TEST(IRF, ExprVisit) { CHECK_EQ(v.count, 1); } + +TEST(IRF, StmtVisitor) { + using namespace tvm; + using namespace tvm::ir; + Var x("x"); + class MyVisitor + : public StmtExprVisitor { + public: + int count = 0; + // implementation + void VisitExpr_(const Variable* op) final { + ++count; + } + }; + MyVisitor v; + auto fmaketest = [&]() { + auto z = x + 1; + Stmt body = Evaluate::make(z); + Var buffer("b", DataType::Handle()); + return Allocate::make(buffer, DataType::Float(32), {z, z}, const_true(), body); + }; + v(fmaketest()); + CHECK_EQ(v.count, 3); +} + +TEST(IRF, StmtMutator) { + using namespace tvm; + using namespace tvm::ir; + Var x("x"); + + class MyVisitor + : public ir::StmtMutator, + public ir::ExprMutator { + public: + using StmtMutator::operator(); + using ExprMutator::operator(); + + protected: + // implementation + Expr VisitExpr_(const Add* op) final { + return op->a; + } + Expr VisitExpr(const Expr& expr) final { + return ExprMutator::VisitExpr(expr); + } + }; + auto fmaketest = [&]() { + auto z = x + 1; + Stmt body = Evaluate::make(z); + Var buffer("b", DataType::Handle()); + return Allocate::make(buffer, DataType::Float(32), {1, z}, const_true(), body); + }; + + MyVisitor v; + { + auto body = fmaketest(); + Stmt body2 = Evaluate::make(1); + Stmt bref = body.as()->body; + auto* extentptr = body.as()->extents.get(); + Array arr{std::move(body), body2, body2}; + auto* arrptr = arr.get(); + arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); + CHECK(arr.get() == arrptr); + // inplace update body + CHECK(arr[0].as()->extents[1].same_as(x)); + CHECK(arr[0].as()->extents.get() == extentptr); + // copy because there is additional refs + CHECK(!arr[0].as()->body.same_as(bref)); + CHECK(arr[0].as()->body.as()->value.same_as(x)); + CHECK(bref.as()->value.as()); + } + { + Array arr{fmaketest()}; + // mutate array get reference by another one, triiger copy. + Array arr2 = arr; + auto* arrptr = arr.get(); + arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); + CHECK(arr.get() != arrptr); + CHECK(arr[0].as()->extents[1].same_as(x)); + CHECK(!arr2[0].as()->extents[1].same_as(x)); + // mutate but no content change. + arr2 = arr; + arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); + CHECK(arr2.get() == arr.get()); + } + { + auto body = Evaluate::make(Call::make(DataType::Int(32), "xyz", {x + 1}, Call::Extern)); + auto res = v(std::move(body)); + CHECK(res.as()->value.as()->args[0].same_as(x)); + } +} + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe";