diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 9b2632f87b3cf..ba2d1f3a3e829 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -287,6 +287,283 @@ class StmtFunctor { #undef EXPR_FUNCTOR_DEFAULT #undef STMT_FUNCTOR_DEFAULT +/*! + * \brief ExprVisitor + */ +class TVM_DLL ExprVisitor : + public ExprFunctor { + public: + using ExprFunctor::operator(); + + protected: + /*! + * \brief Call into visit expr. + * \param expr The expr to visit. + */ + void Visit(const Expr& expr) { + this->VisitExpr(expr); + } + // list of functions to override. + void VisitExpr_(const Variable* op) override; + void VisitExpr_(const Load* op) override; + void VisitExpr_(const Let* op) override; + void VisitExpr_(const Call* op) override; + void VisitExpr_(const Add* op) override; + void VisitExpr_(const Sub* op) override; + void VisitExpr_(const Mul* op) override; + void VisitExpr_(const Div* op) override; + void VisitExpr_(const Mod* op) override; + void VisitExpr_(const FloorDiv* op) override; + void VisitExpr_(const FloorMod* op) override; + void VisitExpr_(const Min* op) override; + void VisitExpr_(const Max* op) override; + void VisitExpr_(const EQ* op) override; + void VisitExpr_(const NE* op) override; + void VisitExpr_(const LT* op) override; + void VisitExpr_(const LE* op) override; + void VisitExpr_(const GT* op) override; + void VisitExpr_(const GE* op) override; + void VisitExpr_(const And* op) override; + void VisitExpr_(const Or* op) override; + void VisitExpr_(const Reduce* op) override; + void VisitExpr_(const Cast* op) override; + void VisitExpr_(const Not* op) override; + void VisitExpr_(const Select* op) override; + void VisitExpr_(const Ramp* op) override; + void VisitExpr_(const Broadcast* op) override; + void VisitExpr_(const Shuffle* op) override; + void VisitExpr_(const IntImm* op) override; + void VisitExpr_(const UIntImm* op) override; + void VisitExpr_(const FloatImm* op) override; + void VisitExpr_(const StringImm* op) override; +}; + +/*! + * \brief ExprMutator that mutates expressions. + */ +class TVM_DLL ExprMutator : + protected ExprFunctor { + public: + using ExprFunctor::operator(); + + protected: + /*! + * \brief Internal mutator that everyone calls. + * \note To override mutate's behavior, override VisitExpr instead. + * \param expr The input expression. + * \return The mutated results. + */ + Expr Mutate(const Expr& expr) { + return this->VisitExpr(expr); + } + // list of functions to override. + Expr VisitExpr_(const Variable* op) override; + Expr VisitExpr_(const Load* op) override; + Expr VisitExpr_(const Let* op) override; + Expr VisitExpr_(const Call* op) override; + Expr VisitExpr_(const Add* op) override; + Expr VisitExpr_(const Sub* op) override; + Expr VisitExpr_(const Mul* op) override; + Expr VisitExpr_(const Div* op) override; + Expr VisitExpr_(const Mod* op) override; + Expr VisitExpr_(const FloorDiv* op) override; + Expr VisitExpr_(const FloorMod* op) override; + Expr VisitExpr_(const Min* op) override; + Expr VisitExpr_(const Max* op) override; + Expr VisitExpr_(const EQ* op) override; + Expr VisitExpr_(const NE* op) override; + Expr VisitExpr_(const LT* op) override; + Expr VisitExpr_(const LE* op) override; + Expr VisitExpr_(const GT* op) override; + Expr VisitExpr_(const GE* op) override; + Expr VisitExpr_(const And* op) override; + Expr VisitExpr_(const Or* op) override; + Expr VisitExpr_(const Reduce* op) override; + Expr VisitExpr_(const Cast* op) override; + Expr VisitExpr_(const Not* op) override; + Expr VisitExpr_(const Select* op) override; + Expr VisitExpr_(const Ramp* op) override; + Expr VisitExpr_(const Broadcast* op) override; + Expr VisitExpr_(const Shuffle* op) override; + Expr VisitExpr_(const IntImm* op) override; + Expr VisitExpr_(const UIntImm* op) override; + Expr VisitExpr_(const FloatImm* op) override; + Expr VisitExpr_(const StringImm* op) override; +}; + +/*! + * \brief StmtVisitor. + */ +class TVM_DLL StmtVisitor : + protected StmtFunctor { + public: + using StmtFunctor::operator(); + + protected: + /*! + * \brief Call into VisitStmt. + * \param stmt The stmt to visit. + */ + void Visit(const Stmt& stmt) { + this->VisitStmt(stmt); + } + /*! + * \brief Visitor to Exprs, can be overriden + * to do recursive changes to Exprs. + * \note A common pattern is to call ExprVisitor here, + * or have a class sub-class both StmtVisitor and ExprVisitor + * and redirect Visit to ExprMutator::VisitExpr(Expr) + */ + virtual void Visit(const Expr& e) {} + // statement visitor + void VisitStmt_(const AttrStmt* op) override; + void VisitStmt_(const IfThenElse* op) override; + void VisitStmt_(const LetStmt* op) override; + void VisitStmt_(const For* op) override; + void VisitStmt_(const Allocate* op) override; + void VisitStmt_(const Store* op) override; + void VisitStmt_(const Free* op) override; + void VisitStmt_(const AssertStmt* op) override; + void VisitStmt_(const ProducerConsumer* op) override; + void VisitStmt_(const Provide* op) override; + void VisitStmt_(const Realize* op) override; + void VisitStmt_(const Prefetch* op) override; + void VisitStmt_(const Block* op) override; + void VisitStmt_(const Evaluate* op) override; +}; + +/*! + * \brief StmtMutator that mutates the statements. + */ +class TVM_DLL StmtMutator : + protected StmtFunctor { + public: + /*! + * \brief Mutate stmt. + * \param stmt The input statement to be mutated. + * \return The result of the call + * \note It is important that stmt is passed by value. + * so copy on write can be triggered correctly. + * do mutator(std::move(stmt)) or when copy elison is triggered. + */ + Stmt operator()(Stmt stmt) { + allow_copy_on_write_ = true; + return Mutate(stmt); + } + + protected: + // We perform copy on write optimizations on the StmtMutator + // so that an unique copy of parent can be mutated inplace + // when some of its children changed. + // We only do such optimization for Stmt nests(instead of Exprs) for now + // as Stmt's parent state is more likely remain unchanged when one of + // its child block changes. + /*! + * \brief Internal state to indicate whether copy on write is enabled. + * COW is enabled iff all the parents of the node are unique. + */ + bool allow_copy_on_write_{false}; + /*! + * \brief Perform copy on write on node. + * + * If CopyOnWrite is allowed, directly return + * a strong reference to the node container. + * Otherwise, return a copy of the node. + * + * \return The result object pointer. + */ + template + ObjectPtr CopyOnWrite(const TNode* node) { + if (allow_copy_on_write_) { + // return the old node. + return runtime::GetObjectPtr(const_cast(node)); + } else { + // Make a new copy of the node. + // need to rely on the default copy constructor + return runtime::make_object(*node); + } + } + /*! + * \brief Internal mutator that everyone calls. + * \note To override mutate's behavior, override VisitExpr instead. + * \param stmt The input stmt. + * \return The mutated results. + */ + Stmt Mutate(const Stmt& stmt) { + if (allow_copy_on_write_ && !stmt.unique()) { + allow_copy_on_write_ = false; + Stmt ret = this->VisitStmt(stmt); + allow_copy_on_write_ = true; + return ret; + } else { + return this->VisitStmt(stmt); + } + } + /*! + * \brief Visitor to Exprs, can be overriden + * to do recursive changes to Exprs. + * \note A common pattern is to call ExprMutator here, + * or have a class sub-class both StmtMutator and ExprMutator + * and redirect Mutate to ExprMutator::Mutate(Expr) + */ + virtual Expr Mutate(const Expr& e); + // statement visitor + Stmt VisitStmt_(const AttrStmt* op) override; + Stmt VisitStmt_(const IfThenElse* op) override; + Stmt VisitStmt_(const LetStmt* op) override; + Stmt VisitStmt_(const For* op) override; + Stmt VisitStmt_(const Allocate* op) override; + Stmt VisitStmt_(const Store* op) override; + Stmt VisitStmt_(const Free* op) override; + Stmt VisitStmt_(const AssertStmt* op) override; + Stmt VisitStmt_(const ProducerConsumer* op) override; + Stmt VisitStmt_(const Provide* op) override; + Stmt VisitStmt_(const Realize* op) override; + Stmt VisitStmt_(const Prefetch* op) override; + Stmt VisitStmt_(const Block* op) override; + Stmt VisitStmt_(const Evaluate* op) override; + // internal helper. + class Internal; +}; + +/*! + * \brief Visitor that recursively visit stmts and exprs on them. + */ +class StmtExprVisitor : + public StmtVisitor, + public ExprVisitor { + public: + using StmtVisitor::operator(); + using ExprVisitor::operator(); + + protected: + using StmtVisitor::Visit; + using ExprVisitor::Visit; + + void Visit(const Expr& e) final { + return ExprVisitor::Visit(e); + } +}; + +/*! + * \brief Mutator that recursively mutates stmts and exprs on them. + */ +class StmtExprMutator : + public StmtMutator, + public ExprMutator { + public: + using StmtMutator::operator(); + using ExprMutator::operator(); + + protected: + using StmtMutator::Mutate; + using ExprMutator::Mutate; + + Expr Mutate(const Expr& e) final { + return ExprMutator::Mutate(e); + } +}; + } // namespace ir } // namespace tvm #endif // TVM_IR_FUNCTOR_EXT_H_ diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 5460ae0f4ba9a..702cea3ce8fd0 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -123,6 +123,7 @@ class TVM_DLL IRMutator { virtual Expr Mutate_(const Shuffle* op, const Expr& e); }; + /*! * \brief recursively visit the ir in post DFS order node, and transform it * @@ -138,7 +139,7 @@ class TVM_DLL IRMutator { * If it is not empty, preorder/postorder will only be called * when the IRNode's type key is in the list. */ -Stmt IRTransform(const Stmt& node, +Stmt IRTransform(Stmt node, const runtime::PackedFunc& preorder, const runtime::PackedFunc& postorder, const Array& only_enable = {}); diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index d20fb288039cb..0d5ab376d50a3 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -284,6 +284,48 @@ class Array : public ObjectRef { inline bool empty() const { return size() == 0; } + /*! + * \brief Helper function to apply fmutate to mutate an array. + * \param fmutate The transformation function T -> T. + * \tparam F the type of the mutation function. + * \note This function performs copy on write optimization. + */ + template + inline void MutateByApply(F fmutate) { + ArrayNode* ptr = static_cast(data_.get()); + if (ptr == nullptr) return; + if (data_.unique()) { + // Copy on write optimization. + // Perform inplace update because this is an unique copy. + for (size_t i = 0; i < ptr->data.size(); ++i) { + // It is important to use move here + // to make prevent the element's ref count from increasing + // so fmutate itself can perform copy-on-write optimization + T old_elem = DowncastNoCheck(std::move(ptr->data[i])); + T new_elem = fmutate(std::move(old_elem)); + ptr->data[i] = std::move(new_elem); + } + } else { + // lazily trigger copy if there is element change. + ObjectPtr copy; + for (size_t i = 0; i < ptr->data.size(); ++i) { + T old_elem = DowncastNoCheck(ptr->data[i]); + T new_elem = fmutate(old_elem); + if (!new_elem.same_as(ptr->data[i])) { + // copy the old array + if (copy == nullptr) { + copy = runtime::make_object(*ptr); + } + copy->data[i] = std::move(new_elem); + } + } + // replace the data with the new copy. + if (copy != nullptr) { + data_ = std::move(copy); + } + } + } + /*! \brief specify container node */ using ContainerType = ArrayNode; diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc new file mode 100644 index 0000000000000..0386f4ae6f44c --- /dev/null +++ b/src/pass/ir_functor.cc @@ -0,0 +1,729 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file ir_functor.cc + */ +#include +#include +#include + +namespace tvm { +namespace ir { + +// visitor to implement apply +class IRApplyVisit : + public StmtExprVisitor { + public: + explicit IRApplyVisit(std::function f) : f_(f) {} + + void VisitExpr(const Expr& node) { + if (visited_.count(node.get()) != 0) return; + visited_.insert(node.get()); + ExprVisitor::VisitExpr(node); + f_(node); + } + + void VisitStmt(const Stmt& node) { + if (visited_.count(node.get()) != 0) return; + visited_.insert(node.get()); + StmtVisitor::VisitStmt(node); + f_(node); + } + + private: + std::function f_; + std::unordered_set visited_; +}; + +void PostOrderVisit(const ObjectRef& node, + std::function fvisit) { + if (node.as()) { + IRApplyVisit visitor(fvisit); + visitor(Downcast(node)); + } else { + IRApplyVisit visitor(fvisit); + visitor(Downcast(node)); + } +} + +class IRTransformer final : + public StmtExprMutator { + public: + IRTransformer(const runtime::PackedFunc& f_preorder, + const runtime::PackedFunc& f_postorder, + const std::unordered_set& only_enable) + : f_preorder_(f_preorder), + f_postorder_(f_postorder), + only_enable_(only_enable) { + } + Stmt VisitStmt(const Stmt& stmt) final { + return MutateInternal(stmt, [this](const Stmt& s){ + return StmtMutator::VisitStmt(s); + }); + } + Expr VisitExpr(const Expr& expr) final { + return MutateInternal(expr, [this](const Expr& e) { + return ExprMutator::VisitExpr(e); + }); + } + + private: + template + T MutateInternal(const T& node, F fmutate) { + if (only_enable_.size() && + !only_enable_.count(node->type_index())) { + return fmutate(node); + } + if (f_preorder_ != nullptr) { + T pre = f_preorder_(node); + if (pre.defined()) return pre; + } + T new_node = fmutate(node); + if (f_postorder_ != nullptr) { + T post = f_postorder_(new_node); + if (post.defined()) return post; + } + return new_node; + } + // The functions + const runtime::PackedFunc& f_preorder_; + const runtime::PackedFunc& f_postorder_; + // type indices enabled. + const std::unordered_set& only_enable_; +}; + +Stmt IRTransform(Stmt ir_node, + const runtime::PackedFunc& f_preorder, + const runtime::PackedFunc& f_postorder, + const Array& only_enable) { + std::unordered_set only_type_index; + for (Expr s : only_enable) { + only_type_index.insert(Object::TypeKey2Index(s.as()->value.c_str())); + } + IRTransformer transform(f_preorder, f_postorder, only_type_index); + return transform(std::move(ir_node)); +} + +// Implementation of Visitors +template +inline void VisitArray(const Array& arr, F fvisit) { + for (size_t i = 0; i < arr.size(); i++) { + fvisit(arr[i]); + } +} + +void StmtVisitor::VisitStmt_(const LetStmt* op) { + this->Visit(op->value); + this->Visit(op->body); +} + +void StmtVisitor::VisitStmt_(const AttrStmt* op) { + this->Visit(op->value); + this->Visit(op->body); +} + +void StmtVisitor::VisitStmt_(const For* op) { + this->Visit(op->min); + this->Visit(op->extent); + this->Visit(op->body); +} + +void StmtVisitor::VisitStmt_(const Allocate* op) { + VisitArray(op->extents, [this](const Expr& e) { this->Visit(e); }); + + for (size_t i = 0; i < op->extents.size(); i++) { + this->Visit(op->extents[i]); + } + this->Visit(op->body); + this->Visit(op->condition); + if (op->new_expr.defined()) { + this->Visit(op->new_expr); + } +} + +void StmtVisitor::VisitStmt_(const Store* op) { + this->Visit(op->value); + this->Visit(op->index); + this->Visit(op->predicate); +} + +void StmtVisitor::VisitStmt_(const IfThenElse* op) { + this->Visit(op->condition); + this->Visit(op->then_case); + if (op->else_case.defined()) { + this->Visit(op->else_case); + } +} + +void StmtVisitor::VisitStmt_(const Free* op) {} + +void StmtVisitor::VisitStmt_(const AssertStmt* op) { + this->Visit(op->condition); + this->Visit(op->message); + this->Visit(op->body); +} + +void StmtVisitor::VisitStmt_(const ProducerConsumer* op) { + this->Visit(op->body); +} + +void StmtVisitor::VisitStmt_(const Provide* op) { + VisitArray(op->args, [this](const Expr& e) { this->Visit(e); }); + this->Visit(op->value); +} + +void StmtVisitor::VisitStmt_(const Realize* op) { + VisitArray(op->bounds, [this](const Range& r) { + this->Visit(r->min); + this->Visit(r->extent); + }); + this->Visit(op->body); + this->Visit(op->condition); +} + +void StmtVisitor::VisitStmt_(const Prefetch* op) { + VisitArray(op->bounds, [this](const Range& r) { + this->Visit(r->min); + this->Visit(r->extent); + }); +} + +void StmtVisitor::VisitStmt_(const Block* op) { + this->Visit(op->first); + this->Visit(op->rest); +} + +void StmtVisitor::VisitStmt_(const Evaluate* op) { + this->Visit(op->value); +} + +void ExprVisitor::VisitExpr_(const Variable* op) {} + +void ExprVisitor::VisitExpr_(const Load* op) { + this->Visit(op->index); + this->Visit(op->predicate); +} + +void ExprVisitor::VisitExpr_(const Let* op) { + this->Visit(op->value); + this->Visit(op->body); +} + +void ExprVisitor::VisitExpr_(const Call* op) { + VisitArray(op->args, [this](const Expr& e) { this->Visit(e); }); +} + +#define DEFINE_BINOP_VISIT_(OP) \ + void ExprVisitor::VisitExpr_(const OP* op) { \ + this->Visit(op->a); \ + this->Visit(op->b); \ + } + +DEFINE_BINOP_VISIT_(Add); +DEFINE_BINOP_VISIT_(Sub); +DEFINE_BINOP_VISIT_(Mul); +DEFINE_BINOP_VISIT_(Div); +DEFINE_BINOP_VISIT_(Mod); +DEFINE_BINOP_VISIT_(FloorDiv); +DEFINE_BINOP_VISIT_(FloorMod); +DEFINE_BINOP_VISIT_(Min); +DEFINE_BINOP_VISIT_(Max); +DEFINE_BINOP_VISIT_(EQ); +DEFINE_BINOP_VISIT_(NE); +DEFINE_BINOP_VISIT_(LT); +DEFINE_BINOP_VISIT_(LE); +DEFINE_BINOP_VISIT_(GT); +DEFINE_BINOP_VISIT_(GE); +DEFINE_BINOP_VISIT_(And); +DEFINE_BINOP_VISIT_(Or); + +void ExprVisitor::VisitExpr_(const IntImm* op) {} +void ExprVisitor::VisitExpr_(const UIntImm* op) {} +void ExprVisitor::VisitExpr_(const FloatImm* op) {} +void ExprVisitor::VisitExpr_(const StringImm* op) {} + +void ExprVisitor::VisitExpr_(const Reduce* op) { + VisitArray(op->axis, [this](const IterVar& r) { + this->Visit(r->dom->min); + this->Visit(r->dom->extent); + }); + VisitArray(op->source, [this](const Expr& e) { this->Visit(e); }); + this->Visit(op->condition); +} + +void ExprVisitor::VisitExpr_(const Cast* op) { + this->Visit(op->value); +} + +void ExprVisitor::VisitExpr_(const Not* op) { + this->Visit(op->a); +} + +void ExprVisitor::VisitExpr_(const Select* op) { + this->Visit(op->condition); + this->Visit(op->true_value); + this->Visit(op->false_value); +} + +void ExprVisitor::VisitExpr_(const Ramp* op) { + this->Visit(op->base); + this->Visit(op->stride); +} + +void ExprVisitor::VisitExpr_(const Shuffle* op) { + VisitArray(op->indices, [this](const Expr& e) { this->Visit(e); }); + VisitArray(op->vectors, [this](const Expr& e) { this->Visit(e); }); +} + +void ExprVisitor::VisitExpr_(const Broadcast* op) { + this->Visit(op->value); +} + +// Implementation of mutators +template +inline Array MutateArray(const Array& arr, + F fmutate, + bool allow_copy_on_write = false) { + if (allow_copy_on_write) { + // if we allow copy on write, we can directly + // call the inplace mutate function. + const_cast&>(arr).MutateByApply(fmutate); + return arr; + } else { + Array copy = arr; + copy.MutateByApply(fmutate); + return copy; + } +} + +class StmtMutator::Internal { + public: + static Array Mutate(StmtMutator* self, const Array& arr) { + auto fmutate = [self](const Expr& e) { return self->Mutate(e); }; + return MutateArray(arr, fmutate, self->allow_copy_on_write_); + } + + static Array Mutate(StmtMutator* self, const Array& arr) { + auto fmutate = [self](const Stmt& s) { return self->Mutate(s); }; + return MutateArray(arr, fmutate, self->allow_copy_on_write_); + } + + static Array Mutate(StmtMutator* self, const Array& arr) { + auto fmutate = [self](const Range& r) { + Expr min = self->Mutate(r->min); + Expr extent = self->Mutate(r->extent); + if (min.same_as(r->min) && extent.same_as(r->extent)) { + return r; + } else { + return Range::make_by_min_extent(min, extent); + } + }; + return MutateArray(arr, fmutate, self->allow_copy_on_write_); + } +}; + +Stmt StmtMutator::VisitStmt_(const AttrStmt* op) { + Expr value = this->Mutate(op->value); + Stmt body = this->Mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->value = std::move(value); + n->body = std::move(body); + return Stmt(n); + } +} + +Expr StmtMutator::Mutate(const Expr& e) { + return e; +} + +Stmt StmtMutator::VisitStmt_(const LetStmt* op) { + Expr value = this->Mutate(op->value); + Stmt body = this->Mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->value = std::move(value); + n->body = std::move(body); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const For* op) { + Expr min = this->Mutate(op->min); + Expr extent = this->Mutate(op->extent); + Stmt body = this->Mutate(op->body); + if (min.same_as(op->min) && + extent.same_as(op->extent) && + body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->min = std::move(min); + n->extent = std::move(extent); + n->body = std::move(body); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Allocate* op) { + Array extents = Internal::Mutate(this, op->extents); + Stmt body = this->Mutate(op->body); + Expr condition = this->Mutate(op->condition); + Expr new_expr; + if (op->new_expr.defined()) { + new_expr = this->Mutate(op->new_expr); + } + if (extents.same_as(op->extents) && + body.same_as(op->body) && + condition.same_as(op->condition) && + new_expr.same_as(op->new_expr)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->extents = std::move(extents); + n->body = std::move(body); + n->condition = std::move(condition); + n->new_expr = std::move(new_expr); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const IfThenElse* op) { + Expr condition = this->Mutate(op->condition); + Stmt then_case = this->Mutate(op->then_case); + Stmt else_case; + if (op->else_case.defined()) { + else_case = this->Mutate(op->else_case); + } + if (condition.same_as(op->condition) && + then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->condition = std::move(condition); + n->then_case = std::move(then_case); + n->else_case = std::move(then_case); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Store* op) { + Expr value = this->Mutate(op->value); + Expr index = this->Mutate(op->index); + Expr predicate = this->Mutate(op->predicate); + if (value.same_as(op->value) && + index.same_as(op->index) && + predicate.same_as(op->predicate)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->value = std::move(value); + n->index = std::move(index); + n->predicate = std::move(predicate); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Provide* op) { + Array args = Internal::Mutate(this, op->args); + Expr value = this->Mutate(op->value); + if (args.same_as(op->args) && + value.same_as(op->value)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->args = std::move(args); + n->value = std::move(value); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Realize* op) { + Region bounds = Internal::Mutate(this, op->bounds); + Stmt body = this->Mutate(op->body); + Expr condition = this->Mutate(op->condition); + if (bounds.same_as(op->bounds) && + body.same_as(op->body) && + condition.same_as(op->condition)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->bounds = std::move(bounds); + n->body = std::move(body); + n->condition = std::move(condition); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Prefetch* op) { + Region bounds = Internal::Mutate(this, op->bounds); + if (bounds.same_as(op->bounds)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->bounds = std::move(bounds); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Block* op) { + Stmt first = this->Mutate(op->first); + Stmt rest = this->Mutate(op->rest); + if (first.same_as(op->first) && + rest.same_as(op->rest)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->first = std::move(first); + n->rest = std::move(rest); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const AssertStmt* op) { + Expr condition = this->Mutate(op->condition); + Expr message = this->Mutate(op->message); + Stmt body = this->Mutate(op->body); + + if (condition.same_as(op->condition) && + message.same_as(op->message) && + body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->condition = std::move(condition); + n->message = std::move(message); + n->body = std::move(body); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const ProducerConsumer* op) { + Stmt body = this->Mutate(op->body); + if (body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->body = std::move(body); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Evaluate* op) { + Expr value = this->Mutate(op->value); + if (value.same_as(op->value)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->value = std::move(value); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const Free* op) { + return GetRef(op); +} + + +Expr ExprMutator::VisitExpr_(const Variable* op) { + return GetRef(op); +} + +Expr ExprMutator::VisitExpr_(const Load* op) { + Expr index = this->Mutate(op->index); + Expr predicate = this->Mutate(op->predicate); + if (index.same_as(op->index) && predicate.same_as(op->predicate)) { + return GetRef(op); + } else { + return Load::make(op->dtype, op->buffer_var, index, predicate); + } +} + +Expr ExprMutator::VisitExpr_(const Let* op) { + Expr value = this->Mutate(op->value); + Expr body = this->Mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return GetRef(op); + } else { + return Let::make(op->var, value, body); + } +} + +Expr ExprMutator::VisitExpr_(const Call* op) { + auto fmutate = [this](const Expr& e) { return this->Mutate(e); }; + auto args = op->args; + args.MutateByApply(fmutate); + + if (args.same_as(op->args)) { + return GetRef(op); + } else { + return Call::make(op->dtype, + op->name, + args, + op->call_type, + op->func, + op->value_index); + } +} + +#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ + Expr ExprMutator::VisitExpr_(const OP *op) { \ + return GetRef(op); \ + } + +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm) + +#define DEFINE_BIOP_EXPR_MUTATE_(OP) \ + Expr ExprMutator::VisitExpr_(const OP* op) { \ + Expr a = this->Mutate(op->a); \ + Expr b = this->Mutate(op->b); \ + if (a.same_as(op->a) && \ + b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + return OP::make(a, b); \ + } \ + } + +DEFINE_BIOP_EXPR_MUTATE_(Add); +DEFINE_BIOP_EXPR_MUTATE_(Sub); +DEFINE_BIOP_EXPR_MUTATE_(Mul); +DEFINE_BIOP_EXPR_MUTATE_(Div); +DEFINE_BIOP_EXPR_MUTATE_(Mod); +DEFINE_BIOP_EXPR_MUTATE_(FloorDiv); +DEFINE_BIOP_EXPR_MUTATE_(FloorMod); +DEFINE_BIOP_EXPR_MUTATE_(Min); +DEFINE_BIOP_EXPR_MUTATE_(Max); +DEFINE_BIOP_EXPR_MUTATE_(EQ); +DEFINE_BIOP_EXPR_MUTATE_(NE); +DEFINE_BIOP_EXPR_MUTATE_(LT); +DEFINE_BIOP_EXPR_MUTATE_(LE); +DEFINE_BIOP_EXPR_MUTATE_(GT); +DEFINE_BIOP_EXPR_MUTATE_(GE); +DEFINE_BIOP_EXPR_MUTATE_(And); +DEFINE_BIOP_EXPR_MUTATE_(Or); + +Expr ExprMutator::VisitExpr_(const Reduce* op) { + Array axis = op->axis; + Array source = op->source; + + auto fitervar = [this](const IterVar& v) { + Range r = v->dom; + Expr min = this->Mutate(r->min); + Expr extent = this->Mutate(r->extent); + if (min.same_as(r->min) && + extent.same_as(r->extent)) { + return v; + } else { + return IterVarNode::make( + Range::make_by_min_extent(min, extent), + v->var, v->iter_type, v->thread_tag); + } + }; + axis.MutateByApply(fitervar); + + auto fexpr = [this](const Expr& e) { return this->Mutate(e); }; + source.MutateByApply(fexpr); + + Expr condition = this->Mutate(op->condition); + + if (axis.same_as(op->axis) && + source.same_as(op->source) && + condition.same_as(op->condition)) { + return GetRef(op); + } else { + return Reduce::make( + op->combiner, source, axis, condition, op->value_index); + } +} + +Expr ExprMutator::VisitExpr_(const Cast* op) { + Expr value = this->Mutate(op->value); + if (value.same_as(op->value)) { + return GetRef(op); + } else { + return Cast::make(op->dtype, value); + } +} + +Expr ExprMutator::VisitExpr_(const Not* op) { + Expr a = this->Mutate(op->a); + if (a.same_as(op->a)) { + return GetRef(op); + } else { + return Not::make(a); + } +} + +Expr ExprMutator::VisitExpr_(const Select* op) { + Expr condition = this->Mutate(op->condition); + Expr true_value = this->Mutate(op->true_value); + Expr false_value = this->Mutate(op->false_value); + if (condition.same_as(op->condition) && + true_value.same_as(op->true_value) && + false_value.same_as(op->false_value)) { + return GetRef(op); + } else { + return Select::make(condition, true_value, false_value); + } +} + +Expr ExprMutator::VisitExpr_(const Ramp* op) { + Expr base = this->Mutate(op->base); + Expr stride = this->Mutate(op->stride); + if (base.same_as(op->base) && + stride.same_as(op->stride)) { + return GetRef(op); + } else { + return Ramp::make(base, stride, op->lanes); + } +} + +Expr ExprMutator::VisitExpr_(const Broadcast* op) { + Expr value = this->Mutate(op->value); + if (value.same_as(op->value)) { + return GetRef(op); + } else { + return Broadcast::make(value, op->lanes); + } +} + +Expr ExprMutator::VisitExpr_(const Shuffle* op) { + auto fexpr = [this](const Expr& e) { return this->Mutate(e); }; + auto vectors = MutateArray(op->vectors, fexpr); + if (vectors.same_as(op->vectors)) { + return GetRef(op); + } else { + return Shuffle::make(vectors, op->indices); + } +} + +} // namespace ir +} // namespace tvm diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index b300989dd2fd4..5ba29fc0b6e9a 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -28,59 +28,6 @@ namespace tvm { namespace ir { -class IRTransformer final : public IRMutator { - public: - IRTransformer(const runtime::PackedFunc& f_preorder, - const runtime::PackedFunc& f_postorder, - const std::unordered_set& only_enable) - : f_preorder_(f_preorder), - f_postorder_(f_postorder), - only_enable_(only_enable) { - } - Stmt Mutate(Stmt stmt) final { - return MutateInternal(stmt); - } - Expr Mutate(Expr expr) final { - return MutateInternal(expr); - } - - private: - template - T MutateInternal(T node) { - if (only_enable_.size() && - !only_enable_.count(node->type_index())) { - return IRMutator::Mutate(node); - } - if (f_preorder_ != nullptr) { - T pre = f_preorder_(node); - if (pre.defined()) return pre; - } - node = IRMutator::Mutate(node); - if (f_postorder_ != nullptr) { - T post = f_postorder_(node); - if (post.defined()) return post; - } - return node; - } - // The functions - const runtime::PackedFunc& f_preorder_; - const runtime::PackedFunc& f_postorder_; - // type indices enabled. - const std::unordered_set& only_enable_; -}; - -Stmt IRTransform(const Stmt& ir_node, - const runtime::PackedFunc& f_preorder, - const runtime::PackedFunc& f_postorder, - const Array& only_enable) { - std::unordered_set only_type_index; - for (Expr s : only_enable) { - only_type_index.insert(Object::TypeKey2Index(s.as()->value.c_str())); - } - return IRTransformer(f_preorder, f_postorder, only_type_index) - .Mutate(ir_node); -} - IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*) static FMutateExpr inst; return inst; } diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index 467cd5de2ef71..af6ea52521666 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -26,26 +26,6 @@ namespace tvm { namespace ir { -// visitor to implement apply -class IRApplyVisit : public IRVisitor { - public: - explicit IRApplyVisit(std::function f) : f_(f) {} - - void Visit(const ObjectRef& node) final { - if (visited_.count(node.get()) != 0) return; - visited_.insert(node.get()); - IRVisitor::Visit(node); - f_(node); - } - - private: - std::function f_; - std::unordered_set visited_; -}; - -void PostOrderVisit(const ObjectRef& node, std::function fvisit) { - IRApplyVisit(fvisit).Visit(node); -} IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*) static FVisit inst; return inst; diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 5636958a5b267..85be76b9cb479 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -31,7 +31,6 @@ TEST(IRF, Basic) { auto z = x + 1; NodeFunctor f; - LOG(INFO) << "x"; f.set_dispatch([](const ObjectRef& n, int b) { return b; }); @@ -101,6 +100,68 @@ TEST(IRF, ExprVisit) { CHECK_EQ(v.count, 1); } +TEST(IRF, StmtMutator) { + using namespace tvm; + using namespace tvm::ir; + Var x("x"); + + class MyVisitor + : public ir::StmtMutator, + public ir::ExprMutator { + public: + using StmtMutator::operator(); + using ExprMutator::operator(); + + protected: + // implementation + Expr VisitExpr_(const Add* op) final { + return op->a; + } + Expr Mutate(const Expr& expr) final { + return ExprMutator::Mutate(expr); + } + }; + auto fmaketest = [&]() { + auto z = x + 1; + Stmt body = Evaluate::make(z); + Var buffer("b", DataType::Handle()); + return Allocate::make(buffer, DataType::Float(32), {1, z}, const_true(), body); + }; + + MyVisitor v; + { + auto body = fmaketest(); + Stmt body2 = Evaluate::make(1); + Stmt bref = body.as()->body; + auto* extentptr = body.as()->extents.get(); + Array arr{std::move(body), body2, body2}; + auto* arrptr = arr.get(); + arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); + CHECK(arr.get() == arrptr); + // inplace update body + CHECK(arr[0].as()->extents[1].same_as(x)); + CHECK(arr[0].as()->extents.get() == extentptr); + // copy because there is additional refs + CHECK(!arr[0].as()->body.same_as(bref)); + CHECK(arr[0].as()->body.as()->value.same_as(x)); + CHECK(bref.as()->value.as()); + } + { + Array arr{fmaketest()}; + // mutate array get reference by another one, triiger copy. + Array arr2 = arr; + auto* arrptr = arr.get(); + arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); + CHECK(arr.get() != arrptr); + CHECK(arr[0].as()->extents[1].same_as(x)); + CHECK(!arr2[0].as()->extents[1].same_as(x)); + // mutate but no content change. + arr2 = arr; + arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); + CHECK(arr2.get() == arr.get()); + } +} + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe";