From 6a62beb2b3f6f7fb7ea5dbb5ef3902af0b362f98 Mon Sep 17 00:00:00 2001 From: Ziheng Jiang Date: Wed, 8 Feb 2017 22:33:52 -0800 Subject: [PATCH] [FUSION] add 'void AutoFuseEwise(Schedule sch)' (#36) * [FUSION] add Fusion(Schedule) * [FUSION] rename to AutoFuseEwise, detect whether the stage has been scheduled * [FUSION] change to visitor pattern * [FUSION] rename filename * [FUSION] fine-tune the interface * [FUSION] typo * move elem_wise to schedule * rename test function --- include/tvm/ir_pass.h | 1 - include/tvm/schedule.h | 11 +++ include/tvm/schedule_pass.h | 7 ++ python/tvm/schedule.py | 2 +- src/api/api_schedule.cc | 5 ++ src/schedule/auto_inline_elem_wise.cc | 76 +++++++++++++++++++ .../unittest/test_schedule_schedule_ops.py | 16 ++++ 7 files changed, 116 insertions(+), 2 deletions(-) create mode 100644 src/schedule/auto_inline_elem_wise.cc diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index b11486d9023a..9e3e1b0a1d53 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -167,7 +167,6 @@ Array SplitHostDevice(LoweredFunc func); */ LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope); - } // namespace ir } // namespace tvm diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index f115dbc6f18f..a7cd58c96524 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -123,6 +123,12 @@ class Stage : public NodeRef { IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner, Expr x_factor, Expr y_factor); + /*! + * \brief whether the stage has been scheduled. + * \return whether the stage has been scheduled. + */ + inline bool is_scheduled() const; + // declare container type using ContainerType = StageNode; }; @@ -353,6 +359,11 @@ inline StageNode* Stage::operator->() { return static_cast(node_.get()); } +inline bool Stage::is_scheduled() const { + const StageNode* n = operator->(); + return !(n->relations.empty() && n->attach_type == kNone); +} + inline const ScheduleNode* Schedule::operator->() const { return static_cast(node_.get()); } diff --git a/include/tvm/schedule_pass.h b/include/tvm/schedule_pass.h index 57e442c5c15e..c4e82cde139b 100644 --- a/include/tvm/schedule_pass.h +++ b/include/tvm/schedule_pass.h @@ -33,6 +33,13 @@ Map InferBound(Schedule sch); */ Stmt ScheduleOps(Schedule s, Map dom_map); +/*! + * \brief To automatically inline the element-wise operations. + * + * \param sch The schedule to be inlined. + */ +void AutoInlineElemWise(Schedule sch); + } // namespace schedule } // namespace tvm #endif // TVM_SCHEDULE_PASS_H_ diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 3fd7f9730d46..fee0fb3b1274 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -135,7 +135,7 @@ def compute_root(self): parent : Stage The parent stage """ - _api_internal._StageComputeInline(self) + _api_internal._StageComputeRoot(self) def reorder(self, *args): """reorder the arguments in the specified order. diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index a4462117d494..882ff94bde21 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -13,6 +13,11 @@ namespace tvm { namespace schedule { +TVM_REGISTER_API(_schedule_AutoInlineElemWise) +.set_body([](TVMArgs args, TVMRetValue* ret) { + AutoInlineElemWise(args[0]); + }); + #define REGISTER_SCHEDULE_PASS1(PassName) \ TVM_REGISTER_API(_schedule_## PassName) \ .set_body([](TVMArgs args, TVMRetValue *ret) { \ diff --git a/src/schedule/auto_inline_elem_wise.cc b/src/schedule/auto_inline_elem_wise.cc new file mode 100644 index 000000000000..66816c955acb --- /dev/null +++ b/src/schedule/auto_inline_elem_wise.cc @@ -0,0 +1,76 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file auto_inline_elem_wise.cc + */ +#include +#include + +namespace tvm { +namespace ir { + +class ElemWiseDetector : public IRVisitor { + public: + explicit ElemWiseDetector(Array axis) : axis_(axis) {} + + void Visit(const NodeRef& e) final { + if (!is_elem_wise_) return; + IRVisitor::Visit(e); + } + + void Visit_(const Call* op) final { + Array axis = op->args; + if (axis_.size() != axis.size()) { + is_elem_wise_ = false; + return; + } + + for (size_t i = 0; i < axis_.size(); ++i) { + // const Variable *v1 = axis_[i]->var.as(); + // const Variable *v2 = axis[i].as(); + if (!axis[i].same_as(axis_[i]->var)) { + // if (!(v1 && v2) || (v1 != v2)) { + is_elem_wise_ = false; + return; + } + } + IRVisitor::Visit_(op); + } + + bool is_elem_wise_{true}; + + private: + Array axis_; +}; + + +bool IsElemWise(const Operation& op) { + if (const ComputeOpNode* compute = op.as()) { + ElemWiseDetector v = ElemWiseDetector(compute->axis); + v.Visit(compute->body); + return v.is_elem_wise_; + } + return false; +} + +} // namespace ir + +namespace schedule { + +void AutoInlineElemWise(Schedule sch) { + for (Stage s : sch->stages) { + if (!s.is_scheduled() && ir::IsElemWise(s->op)) { + bool is_root = false; + for (auto r : sch->roots) { + if (r == s->op) { + is_root = true; + break; + } + } + if (!is_root) + s.compute_inline(); + } + } +} + +} // namespace schedule +} // namespace tvm diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index feed951e295f..9689a1c34fc4 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -42,8 +42,24 @@ def test_schedule2(): stmt = tvm.schedule.ScheduleOps(s, bounds) print(stmt) +def test_auto_inline(): + m = tvm.Var('m') + n = tvm.Var('n') + A = tvm.placeholder((m, n), name='A') + B = tvm.placeholder((m, n), name='B') + C = tvm.placeholder((m, n), name='C') + T1 = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='T1') + T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2') + + s = tvm.Schedule(T2.op) + tvm.schedule.AutoInlineElemWise(s) + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + print(stmt) + if __name__ == "__main__": test_schedule0() test_schedule1() test_schedule2() + test_auto_inline()