From 90e63919ffec9a026a31c06d1426243bb8427e5d Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Wed, 22 Jul 2020 16:12:23 +0800 Subject: [PATCH 01/14] Add follow split and follow fused split Signed-off-by: jingbang.yjb Conflicts: src/auto_scheduler/compute_dag.cc src/auto_scheduler/transform_step.cc src/auto_scheduler/transform_step.h tests/python/unittest/test_auto_scheduler_loop_state.py --- python/tvm/auto_scheduler/loop_state.py | 60 ++++++ src/auto_scheduler/compute_dag.cc | 4 +- src/auto_scheduler/loop_state.cc | 42 ++++ src/auto_scheduler/loop_state.h | 4 +- src/auto_scheduler/transform_step.cc | 202 +++++++++++++++++- src/auto_scheduler/transform_step.h | 90 +++++++- .../test_auto_scheduler_loop_state.py | 39 ++++ .../unittest/test_auto_scheduler_measure.py | 11 +- 8 files changed, 441 insertions(+), 11 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 0311ee39c37d..163a2e342197 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -116,6 +116,15 @@ def stages(self): stages : List[Stage] """ return self.state_object.stages + + @property + def transform_steps(self): + """ + Returns + ------- + transform_steps : List[transform_steps] + """ + return self.state_object.transform_steps @property def stage_ops(self): @@ -293,6 +302,57 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): self._resolve_stage_id(stage), iterator, lengths, inner_to_outer) return res + + def follow_split(self, stage, iterator, src_step_id, n_split): + """ + Parameters + ---------- + iterator : Iterator + The iterator to split + src_step_id : int + The index of the split step to follow in the history + n_split : int + The number of split level + + Returns + ------- + res_its : List[Iterator] + The splitted new Iterators + """ + + self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, + self._resolve_stage_id(stage), + iterator, + src_step_id, n_split) + return res + + def follow_fused_split(self, stage, iterator, src_step_ids, level, + factor_or_nparts): + """ + Parameters + ---------- + iterator : Iterator + The iterator to split + src_step_ids : List[int] + The indices of the split steps to follow in the history + level : int + Use the length in this split level + factor_or_nparts : bool + True to use `factor` for split from inner to outer, + False to use `nparts` for split from outer to inner + + Returns + ------- + res_its : List[Iterator] + The splitted new Iterators + """ + + self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, + self._resolve_stage_id(stage), + iterator, + src_step_ids, level, + factor_or_nparts) + return res def compute_at(self, stage, target_stage, target_iter): """ Schedule primitive corresponds to te.compute_at. diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 0d964cb63513..7b987be49e4e 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -254,7 +254,7 @@ std::pair> ComputeDAG::ApplySteps( // Apply the history steps to TVM schedule // Call each step's ApplyToSchedule method for (const auto& step : transform_steps) { - StepApplyToSchedule(step, stages, stage_to_axes, &schedule); + StepApplyToSchedule(step, stages, stage_to_axes, &schedule, transform_steps); } return std::make_pair(schedule, operator->()->tensors); @@ -298,7 +298,7 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const } // Call each step's PrintAsPythonAPI method for (const auto& step : transform_steps) { - ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule); + ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule, transform_steps); } return ss.str(); diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 18cc6c2537f3..8db726d17f95 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -239,6 +239,28 @@ Iterator State::vectorize(int stage_id, const Iterator& it) { return step->ApplyToState(this); } +Array State::follow_split(int stage_id, const Iterator& it, + int src_step_id, int n_split) { + const Stage& stage = operator->()->stages[stage_id]; + + FollowSplitStep step = FollowSplitStep( + stage_id, GetIndex(stage->iters, it), src_step_id, n_split); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + +Array State::follow_fused_split( + int stage_id, const Iterator& it, const Array& src_step_ids, + int level, bool factor_or_nparts) { + const Stage& stage = operator->()->stages[stage_id]; + + FollowFusedSplitStep step = + FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), + src_step_ids, level, factor_or_nparts); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + Iterator State::fuse(int stage_id, const Array& iters) { const Stage& stage = operator->()->stages[stage_id]; Array indices; @@ -436,6 +458,26 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize") return Array{state, res}; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit") +.set_body_typed([](State state, int stage_id, const Iterator& it, + int src_step_id, int n_split) { + const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); + return Array{state, Array(res)}; +}); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit") +.set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts) { + Array array_src_step_ids; + for (const auto& i : src_step_ids) { + array_src_step_ids.push_back(i->value); + } + const auto& res = state.follow_fused_split( + stage_id, it, array_src_step_ids, level, factor_or_nparts); + return Array{state, Array(res)}; +}); + TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse") .set_body_typed([](State state, int stage_id, const Array& iters) { const auto& res = state.fuse(stage_id, iters); diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index fb07e5a0a32e..d8ea9d467753 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -353,8 +353,10 @@ class State : public ObjectRef { */ Array split(int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer = true); + + Array follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split); - /********** Step APIs working on multiple stages **********/ + Array follow_fused_split(int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts); /*! * \brief Schedule primitive corresponds to te.compute_at. diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index e6450125772b..d8fbbb221dab 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -96,7 +96,12 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { return CacheReadStep(reader); } else if (name == CacheWriteStepNode::record_prefix_str) { return CacheWriteStep(reader); - } else { + } else if (name == FollowSplitStepNode::record_prefix_str) { + return FollowSplitStep(reader); + } else if (name == FollowFusedSplitStepNode::record_prefix_str) { + return FollowFusedSplitStep(reader); + } + else { LOG(FATAL) << "Invalid step format: " << name; } return Step(); @@ -121,13 +126,17 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state, dag); } else if (auto ps = step.as()) { ps->ApplyToState(state, dag); + } else if(auto ps = step.as()){ + ps->ApplyToState(state); + } else if(auto ps = step.as()){ + ps->ApplyToState(state); } else { LOG(FATAL) << "Invalid step: " << step; } } void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule) { + te::Schedule* schedule, const Array& transform_steps) { if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -146,13 +155,17 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes ps->ApplyToSchedule(stages, stage_to_axes, schedule); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes, schedule); + } else if(auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); + } else if(auto ps = step.as()){ + ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); } else { LOG(FATAL) << "Invalid Step: " << step; } } String StepPrintAsPythonAPI(const Step& step, Array* stages, - StageToAxesMap* stage_to_axes, te::Schedule* schedule) { + StageToAxesMap* stage_to_axes, te::Schedule* schedule, const Array& transform_steps) { if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -171,6 +184,10 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); + } else if(auto ps = step.as()){ + return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); + } else if(auto ps = step.as()){ + return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); } else { LOG(FATAL) << "Invalid Step: " << step; } @@ -776,6 +793,185 @@ String SplitStepNode::PrintAsPythonAPI(Array* stages, return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } +/********** Follow Split **********/ +FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, + int src_step_id, int n_split) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->src_step_id = src_step_id; + node->n_split = n_split; + data_ = std::move(node); +} + +void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(iter_id); + writer->WriteArrayItem(src_step_id); + writer->WriteArrayItem(n_split); +} + +void FollowSplitStepNode::ExtractSplitLengths( + const Array& transform_steps, + Array>* lengths) const { + CHECK_LT(src_step_id, transform_steps.size()); + auto ps = transform_steps[src_step_id].as(); + CHECK(ps != nullptr); + + // get lengths from src step + lengths->reserve(n_split); + int j = 0; + for (; j < n_split - 1; ++j) { + lengths->push_back(ps->lengths[j]); + } + PrimExpr last_factor = 1; + for (; j < static_cast(ps->lengths.size()); ++j) { + if (ps->lengths[j]) { + last_factor *= ps->lengths[j].value(); + } else { + last_factor = PrimExpr(); + break; + } + } + if (last_factor.defined()) { + lengths->push_back(Downcast(last_factor)); + } else { + lengths->push_back(NullOpt); + } +} + +FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->iter_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->src_step_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->n_split); + + data_ = std::move(node); +} + +Array FollowSplitStepNode::ApplyToState(State* state) const { + Array> lengths; + ExtractSplitLengths((*state)->transform_steps, &lengths); + return ApplySplitToState(state, stage_id, iter_id, lengths, true); +} + +Array FollowSplitStepNode::ApplyToSchedule( + Array *stages, StageToAxesMap *stage_to_axes, + const Array& transform_steps) const { + Array> lengths; + ExtractSplitLengths(transform_steps, &lengths); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, + lengths, true); +} + +String FollowSplitStepNode::PrintAsPythonAPI( + Array *stages, StageToAxesMap *stage_to_axes, const Array& transform_steps) const { + Array> lengths; + ExtractSplitLengths(transform_steps, &lengths); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, + lengths, true); +} + +/********** Follow Fused Split **********/ +FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id, + const Array& src_step_ids, int level, bool factor_or_nparts) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->src_step_ids = src_step_ids;; + node->level = level; + node->factor_or_nparts = factor_or_nparts; + data_ = std::move(node); +} + +FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->iter_id); + s = reader->NextArrayItem(); + CHECK(s); + std::vector int_list; + reader->Read(&int_list); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->level); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->factor_or_nparts); + + ::tvm::Array<::tvm::Integer> src_step_ids; + for (const auto& i : int_list) { + src_step_ids.push_back(i); + } + node->src_step_ids = src_step_ids; + data_ = std::move(node); +} + +void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(iter_id); + writer->WriteArrayItem(IntArrayToVector(src_step_ids)); + writer->WriteArrayItem(level); + writer->WriteArrayItem(static_cast(factor_or_nparts)); +} + +Optional FollowFusedSplitStepNode::ExtractSplitLength( + const Array& transform_steps) const { + PrimExpr ret(1); + + for (int src_step_id : src_step_ids) { + CHECK_LT(src_step_id, transform_steps.size()); + auto ps = transform_steps[src_step_id].as(); + CHECK(ps != nullptr); + if (ps->lengths[level] && ret.defined()) { + ret *= ps->lengths[level].value(); + } else { + return NullOpt; + } + } + return Downcast(ret); +} + +Array FollowFusedSplitStepNode::ApplyToState(State* state) const { + const Optional& length = ExtractSplitLength((*state)->transform_steps); + return ApplySplitToState(state, stage_id, iter_id, {length}, factor_or_nparts); +} + +Array FollowFusedSplitStepNode::ApplyToSchedule( + Array *stages, StageToAxesMap *stage_to_axes, + const Array& transform_steps) const { + const Optional& length = ExtractSplitLength(transform_steps); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, + {length}, factor_or_nparts); +} + +String FollowFusedSplitStepNode::PrintAsPythonAPI( + Array *stages, StageToAxesMap *stage_to_axes, const Array& transform_steps) const { + const Optional& length = ExtractSplitLength(transform_steps); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, + {length}, factor_or_nparts); +} + + /********** Primitives working on multiple stages **********/ /********** Compute At **********/ diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index e1746189c29e..0a1044f38988 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -207,7 +207,7 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag); * CacheRead/CacheWrite step) */ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule); + te::Schedule* schedule, const Array& transform_steps); /*! * \brief Print the step as equivalent python schedule API. @@ -219,8 +219,8 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes * \return Python schedule code. */ String StepPrintAsPythonAPI(const Step& step, Array* stages, - StageToAxesMap* stage_to_axes, te::Schedule* schedule); - + StageToAxesMap* stage_to_axes, te::Schedule* schedule, + const Array& transform_steps); /********** Primitives working on single stage **********/ /*! @@ -411,6 +411,90 @@ class ReorderStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); }; +/*! \brief Similar to SplitStepNode, but use split factor from another stepf + * (i.e. Follow another split step) */ +class FollowSplitStepNode: public StepNode { + public: + int iter_id; // The id of the iter to split + int src_step_id; // The index of the split step to follow in the history + int n_split; // The number of split level + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + void ExtractSplitLengths( const Array& transform_steps, + Array>* lengths) const; + + Array ApplyToState(State* state) const; + + Array ApplyToSchedule(Array *stages, + StageToAxesMap *stage_to_axes, + const Array& transform_steps) const; + + String PrintAsPythonAPI(Array *stages, + StageToAxesMap *stage_to_axes, + const Array& transform_steps) const; + static constexpr const char* record_prefix_str = "FSP"; + static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); +}; + +/*! + * \brief Managed reference to FollowSplitStepNode. + * \sa FollowSplitStepNode + */ +class FollowSplitStep : public Step { + public: + FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); + + explicit FollowSplitStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); +}; + +/*! \brief Similar to FollowSplitStep, but use split factors from multiple steps. + * \Note This can be used for the split in cooperative fetching + */ +class FollowFusedSplitStepNode: public StepNode { + public: + int iter_id; // The id of the iter to split + Array src_step_ids; // The indices of the split steps to follow in the history + int level; // Use the length in this split level + bool factor_or_nparts; // If this is true, use factor. Otherwise, use nparts + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + Optional ExtractSplitLength( const Array& transform_steps) const; + + Array ApplyToState(State* state) const; + + Array ApplyToSchedule(Array *stages, + StageToAxesMap *stage_to_axes, + const Array& transform_steps) const; + + String PrintAsPythonAPI(Array *stages, + StageToAxesMap *stage_to_axes, + const Array& transform_steps) const; + + static constexpr const char* record_prefix_str = "FFSP"; + static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); +}; + +/*! + * \brief Managed reference to FollowFusedSplitStepNode. + * \sa FollowFusedSplitStepNode + */ +class FollowFusedSplitStep : public Step { + public: + FollowFusedSplitStep(int stage_id, int iter_id, + const Array& src_step_ids, + int level, bool factor_or_nparts); + + explicit FollowFusedSplitStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); +}; + /*! * \brief Split step that corresponds to te::Stage::split with additional * support of multiple-level of factors diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index 8c9d635b526c..0243b40904df 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -142,6 +142,44 @@ def test_compute_at_root_inline(): assert s0[conv].iters[5].range.extent == 7 assert s0[conv].iters[6].range.extent == 7 +def test_follow_split_follow_fused_split(): + A, B, C = matmul_auto_scheduler_test(512, 512, 512) + dag = auto_scheduler.ComputeDAG([A, B, C]) + s0 = dag.get_init_state() + + C_global = s0.cache_write(C, "global") + its0 = s0.split(C, s0[C].iters[0], [4, 2, 8, 4], True) + split_step0 = len(s0.transform_steps) - 1 + for level in range(1, 6): + tmp = s0.copy() + tmp.follow_split(C_global, tmp[C_global].iters[0], split_step0, level) + for i in range(0, level): + assert tmp[C].iters[i].range.extent == \ + tmp[C_global].iters[i].range.extent + + its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8]) + split_step1 = len(s0.transform_steps) - 1 + its = [] + for i0, i1 in zip(its0, its1): + its.append(i0) + its.append(i1) + s0.reorder(C, its) + for i in range(0, 5): + s0.fuse(C, [s0[C].iters[i], s0[C].iters[i + 1]]) + + for level in range(0, 4): + tmp = s0.copy() + tmp.follow_fused_split(C_global, tmp[C_global].iters[0], + [split_step0, split_step1], level, False) + assert tmp[C].iters[level + 1].range.extent == \ + tmp[C_global].iters[0].range.extent + + for level in range(0, 4): + tmp = s0.copy() + tmp.follow_fused_split(C_global, tmp[C_global].iters[0], + [split_step0, split_step1], level, True) + assert tmp[C].iters[level + 1].range.extent == \ + tmp[C_global].iters[1].range.extent def test_cache_read_write(): N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( @@ -422,3 +460,4 @@ def test_cache_read_write(): test_split_fuse_reorder_annotation() test_compute_at_root_inline() test_cache_read_write() + test_follow_split_follow_fused_split() diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 5f2f87ad9baa..039dff921144 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -37,10 +37,17 @@ def test_record(): k = te.reduce_axis((0, 512), name='k') E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E') F = topi.nn.relu(E) - dag = auto_scheduler.ComputeDAG([A, B, F]) s = dag.get_init_state() - + + # Follow split + C_global = s0.cache_write(C, "global") + split_step0 = len(s0.transform_steps) - 1 + split_step1 = len(s0.transform_steps) - 1 + s.follow_split(C_global, tmp[C_global].iters[0], split_step0, 0) + # Follow fused split + s.follow_fused_split(C_global, tmp[C_global].iters[0], + [split_step0, split_step1], 0, False) # Split its0 = s.split(C, s[C].iters[0], [4, 8, 8]) its1 = s.split(C, s[C].iters[4], [8, 4, 4]) From e144082c1db5b9c62aebbe10909989b8dd069a02 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Wed, 22 Jul 2020 20:01:29 +0800 Subject: [PATCH 02/14] add loop_state.py Signed-off-by: jingbang.yjb --- python/tvm/auto_scheduler/loop_state.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 163a2e342197..0bdd35aed40f 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -116,7 +116,7 @@ def stages(self): stages : List[Stage] """ return self.state_object.stages - + @property def transform_steps(self): """ @@ -302,7 +302,7 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): self._resolve_stage_id(stage), iterator, lengths, inner_to_outer) return res - + def follow_split(self, stage, iterator, src_step_id, n_split): """ Parameters @@ -320,12 +320,12 @@ def follow_split(self, stage, iterator, src_step_id, n_split): The splitted new Iterators """ - self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, - self._resolve_stage_id(stage), + self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, + self._resolve_stage_id(stage), iterator, src_step_id, n_split) return res - + def follow_fused_split(self, stage, iterator, src_step_ids, level, factor_or_nparts): """ @@ -347,9 +347,9 @@ def follow_fused_split(self, stage, iterator, src_step_ids, level, The splitted new Iterators """ - self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, + self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, self._resolve_stage_id(stage), - iterator, + iterator, src_step_ids, level, factor_or_nparts) return res From c4a344cbf524ffc383488988d01ddf8087963745 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Thu, 23 Jul 2020 11:57:08 +0800 Subject: [PATCH 03/14] Add some doc strings for Follow_Split and Follow_fused_split Signed-off-by: jingbang.yjb --- python/tvm/auto_scheduler/loop_state.py | 31 ++++++++++++--------- src/auto_scheduler/loop_state.h | 22 +++++++++++++-- src/auto_scheduler/transform_step.cc | 37 ++++++++++++------------- 3 files changed, 56 insertions(+), 34 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 0bdd35aed40f..5904ed9cd588 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -305,19 +305,23 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): def follow_split(self, stage, iterator, src_step_id, n_split): """ + Schedule primitive corresponds to te.follow_split. Parameters ---------- + stage : Union[int, Operation, Tensor] + The Stage to be split, which can be specified by the integer index, Operation, + or output tensor of the stage. iterator : Iterator - The iterator to split - src_step_id : int - The index of the split step to follow in the history - n_split : int - The number of split level + The iterator to split. + src_step_id : Int + The index of the split step to follow in the history. + n_split : Int + The number of split level. Returns ------- res_its : List[Iterator] - The splitted new Iterators + The splitted new Iterators. """ self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, @@ -329,22 +333,23 @@ def follow_split(self, stage, iterator, src_step_id, n_split): def follow_fused_split(self, stage, iterator, src_step_ids, level, factor_or_nparts): """ + Schedule primitive corresponds to te.follow_fused_split. Parameters ---------- iterator : Iterator - The iterator to split + The iterator to split. src_step_ids : List[int] - The indices of the split steps to follow in the history - level : int - Use the length in this split level - factor_or_nparts : bool + The indices of the split steps to follow in the history. + level : Int + Use the length in this split level. + factor_or_nparts : Bool True to use `factor` for split from inner to outer, - False to use `nparts` for split from outer to inner + False to use `nparts` for split from outer to inner. Returns ------- res_its : List[Iterator] - The splitted new Iterators + The splitted new Iterators. """ self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index d8ea9d467753..cc2fd2d698cb 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -353,9 +353,27 @@ class State : public ObjectRef { */ Array split(int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer = true); - - Array follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split); + /********** Step APIs working on multiple stages **********/ + /*! + * \brief Schedule primitive corresponds to te.follow_split. + * \param stage_id The index of the stage to be split. + * \param it The iterator to be split. + * \param src_step_id The index of the split step to follow in the history. + * \param n_split The number of split level. + * \return The splitted new Iterators. + */ + Array follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split); + /*! + * \brief Schedule primitive corresponds to te.follow_split. + * \param stage_id The index of the stage to be split. + * \param it The iterator to be split. + * \param src_step_ids The indices of the split steps to follow in the history. + * \param level Use the length in this split level. + * \param factor_or_nparts True to use `factor` for split from inner to outer, + False to use `nparts` for split from outer to inner. + * \return The splitted new Iterators. + */ Array follow_fused_split(int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts); /*! diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index d8fbbb221dab..715bbe97eb09 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -86,6 +86,10 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { return ReorderStep(reader); } else if (name == SplitStepNode::record_prefix_str) { return SplitStep(reader); + } else if (name == FollowSplitStepNode::record_prefix_str) { + return FollowSplitStep(reader); + } else if (name == FollowFusedSplitStepNode::record_prefix_str) { + return FollowFusedSplitStep(reader); } else if (name == ComputeAtStepNode::record_prefix_str) { return ComputeAtStep(reader); } else if (name == ComputeInlineStepNode::record_prefix_str) { @@ -96,12 +100,7 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { return CacheReadStep(reader); } else if (name == CacheWriteStepNode::record_prefix_str) { return CacheWriteStep(reader); - } else if (name == FollowSplitStepNode::record_prefix_str) { - return FollowSplitStep(reader); - } else if (name == FollowFusedSplitStepNode::record_prefix_str) { - return FollowFusedSplitStep(reader); - } - else { + } else { LOG(FATAL) << "Invalid step format: " << name; } return Step(); @@ -116,6 +115,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); + } else if(auto ps = step.as()){ + ps->ApplyToState(state); + } else if(auto ps = step.as()){ + ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); } else if (auto ps = step.as()) { @@ -126,10 +129,6 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state, dag); } else if (auto ps = step.as()) { ps->ApplyToState(state, dag); - } else if(auto ps = step.as()){ - ps->ApplyToState(state); - } else if(auto ps = step.as()){ - ps->ApplyToState(state); } else { LOG(FATAL) << "Invalid step: " << step; } @@ -145,6 +144,10 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); + } else if(auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); + } else if(auto ps = step.as()){ + ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -155,10 +158,6 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes ps->ApplyToSchedule(stages, stage_to_axes, schedule); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes, schedule); - } else if(auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); - } else if(auto ps = step.as()){ - ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); } else { LOG(FATAL) << "Invalid Step: " << step; } @@ -174,6 +173,10 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else if(auto ps = step.as()){ + return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); + } else if(auto ps = step.as()){ + return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -184,11 +187,7 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); - } else if(auto ps = step.as()){ - return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); - } else if(auto ps = step.as()){ - return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); - } else { + } else { LOG(FATAL) << "Invalid Step: " << step; } return ""; From d3969b8d5d8152c2644771e537d7796858c9df42 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Thu, 23 Jul 2020 13:54:01 +0800 Subject: [PATCH 04/14] Check code using c-lint Signed-off-by: jingbang.yjb --- src/auto_scheduler/loop_state.cc | 50 ++++++++++----------- src/auto_scheduler/loop_state.h | 4 +- src/auto_scheduler/transform_step.cc | 66 ++++++++++++++-------------- src/auto_scheduler/transform_step.h | 47 +++++++++----------- 4 files changed, 80 insertions(+), 87 deletions(-) diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 8db726d17f95..7aaddf129f6d 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -239,24 +239,23 @@ Iterator State::vectorize(int stage_id, const Iterator& it) { return step->ApplyToState(this); } -Array State::follow_split(int stage_id, const Iterator& it, - int src_step_id, int n_split) { +Array State::follow_split(int stage_id, const Iterator& it, int src_step_id, + int n_split) { const Stage& stage = operator->()->stages[stage_id]; - FollowSplitStep step = FollowSplitStep( - stage_id, GetIndex(stage->iters, it), src_step_id, n_split); + FollowSplitStep step = + FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } -Array State::follow_fused_split( - int stage_id, const Iterator& it, const Array& src_step_ids, - int level, bool factor_or_nparts) { +Array State::follow_fused_split(int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts) { const Stage& stage = operator->()->stages[stage_id]; - FollowFusedSplitStep step = - FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), - src_step_ids, level, factor_or_nparts); + FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), + src_step_ids, level, factor_or_nparts); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } @@ -459,24 +458,23 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize") }); TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit") -.set_body_typed([](State state, int stage_id, const Iterator& it, - int src_step_id, int n_split) { - const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); - return Array{state, Array(res)}; -}); + .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id, + int n_split) { + const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); + return Array{state, Array(res)}; + }); TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit") -.set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& src_step_ids, int level, - bool factor_or_nparts) { - Array array_src_step_ids; - for (const auto& i : src_step_ids) { - array_src_step_ids.push_back(i->value); - } - const auto& res = state.follow_fused_split( - stage_id, it, array_src_step_ids, level, factor_or_nparts); - return Array{state, Array(res)}; -}); + .set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& src_step_ids, int level, bool factor_or_nparts) { + Array array_src_step_ids; + for (const auto& i : src_step_ids) { + array_src_step_ids.push_back(i->value); + } + const auto& res = + state.follow_fused_split(stage_id, it, array_src_step_ids, level, factor_or_nparts); + return Array{state, Array(res)}; + }); TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse") .set_body_typed([](State state, int stage_id, const Array& iters) { diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index cc2fd2d698cb..0d8823bb3ef6 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -374,7 +374,9 @@ class State : public ObjectRef { False to use `nparts` for split from outer to inner. * \return The splitted new Iterators. */ - Array follow_fused_split(int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts); + Array follow_fused_split(int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts); /*! * \brief Schedule primitive corresponds to te.compute_at. diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 715bbe97eb09..5b26c5e98e3a 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -115,9 +115,9 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); - } else if(auto ps = step.as()){ + } else if (auto ps = step.as()) { ps->ApplyToState(state); - } else if(auto ps = step.as()){ + } else if (auto ps = step.as()) { ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); @@ -144,9 +144,9 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); - } else if(auto ps = step.as()) { + } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); - } else if(auto ps = step.as()){ + } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); @@ -164,7 +164,8 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes } String StepPrintAsPythonAPI(const Step& step, Array* stages, - StageToAxesMap* stage_to_axes, te::Schedule* schedule, const Array& transform_steps) { + StageToAxesMap* stage_to_axes, te::Schedule* schedule, + const Array& transform_steps) { if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -173,9 +174,9 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); - } else if(auto ps = step.as()){ + } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); - } else if(auto ps = step.as()){ + } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); @@ -187,7 +188,7 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); - } else { + } else { LOG(FATAL) << "Invalid Step: " << step; } return ""; @@ -793,8 +794,7 @@ String SplitStepNode::PrintAsPythonAPI(Array* stages, } /********** Follow Split **********/ -FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, - int src_step_id, int n_split) { +FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; @@ -812,9 +812,8 @@ void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { writer->WriteArrayItem(n_split); } -void FollowSplitStepNode::ExtractSplitLengths( - const Array& transform_steps, - Array>* lengths) const { +void FollowSplitStepNode::ExtractSplitLengths(const Array& transform_steps, + Array>* lengths) const { CHECK_LT(src_step_id, transform_steps.size()); auto ps = transform_steps[src_step_id].as(); CHECK(ps != nullptr); @@ -866,30 +865,30 @@ Array FollowSplitStepNode::ApplyToState(State* state) const { return ApplySplitToState(state, stage_id, iter_id, lengths, true); } -Array FollowSplitStepNode::ApplyToSchedule( - Array *stages, StageToAxesMap *stage_to_axes, - const Array& transform_steps) const { +Array FollowSplitStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes, + const Array& transform_steps) const { Array> lengths; ExtractSplitLengths(transform_steps, &lengths); - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - lengths, true); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true); } -String FollowSplitStepNode::PrintAsPythonAPI( - Array *stages, StageToAxesMap *stage_to_axes, const Array& transform_steps) const { +String FollowSplitStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes, + const Array& transform_steps) const { Array> lengths; ExtractSplitLengths(transform_steps, &lengths); - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - lengths, true); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true); } /********** Follow Fused Split **********/ FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id, - const Array& src_step_ids, int level, bool factor_or_nparts) { + const Array& src_step_ids, int level, + bool factor_or_nparts) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; - node->src_step_ids = src_step_ids;; + node->src_step_ids = src_step_ids; node->level = level; node->factor_or_nparts = factor_or_nparts; data_ = std::move(node); @@ -955,22 +954,21 @@ Array FollowFusedSplitStepNode::ApplyToState(State* state) const { return ApplySplitToState(state, stage_id, iter_id, {length}, factor_or_nparts); } -Array FollowFusedSplitStepNode::ApplyToSchedule( - Array *stages, StageToAxesMap *stage_to_axes, - const Array& transform_steps) const { +Array FollowFusedSplitStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes, + const Array& transform_steps) const { const Optional& length = ExtractSplitLength(transform_steps); - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - {length}, factor_or_nparts); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, {length}, factor_or_nparts); } -String FollowFusedSplitStepNode::PrintAsPythonAPI( - Array *stages, StageToAxesMap *stage_to_axes, const Array& transform_steps) const { +String FollowFusedSplitStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes, + const Array& transform_steps) const { const Optional& length = ExtractSplitLength(transform_steps); - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - {length}, factor_or_nparts); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, {length}, + factor_or_nparts); } - /********** Primitives working on multiple stages **********/ /********** Compute At **********/ diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 0a1044f38988..3d1456eeb088 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -413,25 +413,23 @@ class ReorderStep : public Step { /*! \brief Similar to SplitStepNode, but use split factor from another stepf * (i.e. Follow another split step) */ -class FollowSplitStepNode: public StepNode { +class FollowSplitStepNode : public StepNode { public: int iter_id; // The id of the iter to split int src_step_id; // The index of the split step to follow in the history int n_split; // The number of split level - + void WriteToRecord(dmlc::JSONWriter* writer) const final; - void ExtractSplitLengths( const Array& transform_steps, - Array>* lengths) const; + void ExtractSplitLengths(const Array& transform_steps, + Array>* lengths) const; Array ApplyToState(State* state) const; - Array ApplyToSchedule(Array *stages, - StageToAxesMap *stage_to_axes, - const Array& transform_steps) const; + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; - String PrintAsPythonAPI(Array *stages, - StageToAxesMap *stage_to_axes, + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const; static constexpr const char* record_prefix_str = "FSP"; static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep"; @@ -445,7 +443,7 @@ class FollowSplitStepNode: public StepNode { class FollowSplitStep : public Step { public: FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); - + explicit FollowSplitStep(dmlc::JSONReader* reader); TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); @@ -454,26 +452,24 @@ class FollowSplitStep : public Step { /*! \brief Similar to FollowSplitStep, but use split factors from multiple steps. * \Note This can be used for the split in cooperative fetching */ -class FollowFusedSplitStepNode: public StepNode { +class FollowFusedSplitStepNode : public StepNode { public: - int iter_id; // The id of the iter to split + int iter_id; // The id of the iter to split Array src_step_ids; // The indices of the split steps to follow in the history - int level; // Use the length in this split level - bool factor_or_nparts; // If this is true, use factor. Otherwise, use nparts - + int level; // Use the length in this split level + bool factor_or_nparts; // If this is true, use factor. Otherwise, use nparts + void WriteToRecord(dmlc::JSONWriter* writer) const final; - Optional ExtractSplitLength( const Array& transform_steps) const; + Optional ExtractSplitLength(const Array& transform_steps) const; Array ApplyToState(State* state) const; - Array ApplyToSchedule(Array *stages, - StageToAxesMap *stage_to_axes, - const Array& transform_steps) const; + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; - String PrintAsPythonAPI(Array *stages, - StageToAxesMap *stage_to_axes, - const Array& transform_steps) const; + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; static constexpr const char* record_prefix_str = "FFSP"; static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep"; @@ -486,10 +482,9 @@ class FollowFusedSplitStepNode: public StepNode { */ class FollowFusedSplitStep : public Step { public: - FollowFusedSplitStep(int stage_id, int iter_id, - const Array& src_step_ids, - int level, bool factor_or_nparts); - + FollowFusedSplitStep(int stage_id, int iter_id, const Array& src_step_ids, int level, + bool factor_or_nparts); + explicit FollowFusedSplitStep(dmlc::JSONReader* reader); TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); From f20952538ae3b3eb209772d092bf800abeea255b Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Thu, 23 Jul 2020 17:34:32 +0800 Subject: [PATCH 05/14] Add more doc strings and change the order for follow split. Signed-off-by: jingbang.yjb --- python/tvm/auto_scheduler/loop_state.py | 43 ++++-- src/auto_scheduler/loop_state.cc | 80 +++++------ src/auto_scheduler/loop_state.h | 1 - src/auto_scheduler/transform_step.h | 174 +++++++++++++----------- 4 files changed, 170 insertions(+), 128 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 5904ed9cd588..9072c46a7dab 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -304,8 +304,17 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): return res def follow_split(self, stage, iterator, src_step_id, n_split): - """ - Schedule primitive corresponds to te.follow_split. + """ Schedule primitive extends to split step. + + This step is used to follow a former SplitStep, keeps their iterator structures to be same. + + Example cases: + With subgraph: Dense -> Relu + Some tiling structures are used in Relu stage and we intend to compute the Dense + stage at Relu. + The follow_split is used here to keep their outer most few iterators the same for + applying compute at. + Parameters ---------- stage : Union[int, Operation, Tensor] @@ -313,9 +322,9 @@ def follow_split(self, stage, iterator, src_step_id, n_split): or output tensor of the stage. iterator : Iterator The iterator to split. - src_step_id : Int + src_step_id : int The index of the split step to follow in the history. - n_split : Int + n_split : int The number of split level. Returns @@ -332,17 +341,35 @@ def follow_split(self, stage, iterator, src_step_id, n_split): def follow_fused_split(self, stage, iterator, src_step_ids, level, factor_or_nparts): - """ - Schedule primitive corresponds to te.follow_fused_split. + """ Schedule primitive extends to split step. + + This step is used to follow several former SplitSteps and FuseSteps. + + Example cases: + With subgraph in GPU schedule: Input -> Dense + for i.0@j.0 = ... : Bind to blockIdx.x + for i.1@j.1 = ... : Bind to threadIdx.x + for i.2@j.2 = ... + Input_shared = Input ... + for k = ... + Dense = ... + We intend to apply cooperative fetching with the Input stage, while the threadIdx.x + axis is binded to a iterator generated by split & fuse step. + The follow_fused_step is used here to figure out the final extent of the threadIdx.x + binded iterator. + Parameters ---------- + stage : Union[int, Operation, Tensor] + The Stage to be split, which can be specified by the integer index, Operation, + or output tensor of the stage. iterator : Iterator The iterator to split. src_step_ids : List[int] The indices of the split steps to follow in the history. - level : Int + level : int Use the length in this split level. - factor_or_nparts : Bool + factor_or_nparts : bool True to use `factor` for split from inner to outer, False to use `nparts` for split from outer to inner. diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 7aaddf129f6d..804256917dec 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -239,27 +239,6 @@ Iterator State::vectorize(int stage_id, const Iterator& it) { return step->ApplyToState(this); } -Array State::follow_split(int stage_id, const Iterator& it, int src_step_id, - int n_split) { - const Stage& stage = operator->()->stages[stage_id]; - - FollowSplitStep step = - FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split); - CopyOnWrite()->transform_steps.push_back(step); - return step->ApplyToState(this); -} - -Array State::follow_fused_split(int stage_id, const Iterator& it, - const Array& src_step_ids, int level, - bool factor_or_nparts) { - const Stage& stage = operator->()->stages[stage_id]; - - FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), - src_step_ids, level, factor_or_nparts); - CopyOnWrite()->transform_steps.push_back(step); - return step->ApplyToState(this); -} - Iterator State::fuse(int stage_id, const Array& iters) { const Stage& stage = operator->()->stages[stage_id]; Array indices; @@ -290,6 +269,27 @@ Array State::split(int stage_id, const Iterator& it, return step->ApplyToState(this); } +Array State::follow_split(int stage_id, const Iterator& it, int src_step_id, + int n_split) { + const Stage& stage = operator->()->stages[stage_id]; + + FollowSplitStep step = + FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + +Array State::follow_fused_split(int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts) { + const Stage& stage = operator->()->stages[stage_id]; + + FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), + src_step_ids, level, factor_or_nparts); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { const Stage& target_stage = operator->()->stages[target_stage_id]; ComputeAtStep step = @@ -457,25 +457,6 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize") return Array{state, res}; }); -TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit") - .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id, - int n_split) { - const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); - return Array{state, Array(res)}; - }); - -TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit") - .set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& src_step_ids, int level, bool factor_or_nparts) { - Array array_src_step_ids; - for (const auto& i : src_step_ids) { - array_src_step_ids.push_back(i->value); - } - const auto& res = - state.follow_fused_split(stage_id, it, array_src_step_ids, level, factor_or_nparts); - return Array{state, Array(res)}; - }); - TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse") .set_body_typed([](State state, int stage_id, const Array& iters) { const auto& res = state.fuse(stage_id, iters); @@ -495,6 +476,25 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit") return Array{state, res}; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit") + .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id, + int n_split) { + const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); + return Array{state, Array(res)}; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit") + .set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& src_step_ids, int level, bool factor_or_nparts) { + Array array_src_step_ids; + for (const auto& i : src_step_ids) { + array_src_step_ids.push_back(i->value); + } + const auto& res = + state.follow_fused_split(stage_id, it, array_src_step_ids, level, factor_or_nparts); + return Array{state, Array(res)}; + }); + TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt") .set_body_typed([](State state, int stage_id, int target_stage_id, const Iterator& target_iter) { diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index 0d8823bb3ef6..da63b8a11d37 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -377,7 +377,6 @@ class State : public ObjectRef { Array follow_fused_split(int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts); - /*! * \brief Schedule primitive corresponds to te.compute_at. * \param stage_id The index of the stage to be compute at. diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 3d1456eeb088..0ddeb9b20c3f 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -411,85 +411,6 @@ class ReorderStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); }; -/*! \brief Similar to SplitStepNode, but use split factor from another stepf - * (i.e. Follow another split step) */ -class FollowSplitStepNode : public StepNode { - public: - int iter_id; // The id of the iter to split - int src_step_id; // The index of the split step to follow in the history - int n_split; // The number of split level - - void WriteToRecord(dmlc::JSONWriter* writer) const final; - - void ExtractSplitLengths(const Array& transform_steps, - Array>* lengths) const; - - Array ApplyToState(State* state) const; - - Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, - const Array& transform_steps) const; - - String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, - const Array& transform_steps) const; - static constexpr const char* record_prefix_str = "FSP"; - static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); -}; - -/*! - * \brief Managed reference to FollowSplitStepNode. - * \sa FollowSplitStepNode - */ -class FollowSplitStep : public Step { - public: - FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); - - explicit FollowSplitStep(dmlc::JSONReader* reader); - - TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); -}; - -/*! \brief Similar to FollowSplitStep, but use split factors from multiple steps. - * \Note This can be used for the split in cooperative fetching - */ -class FollowFusedSplitStepNode : public StepNode { - public: - int iter_id; // The id of the iter to split - Array src_step_ids; // The indices of the split steps to follow in the history - int level; // Use the length in this split level - bool factor_or_nparts; // If this is true, use factor. Otherwise, use nparts - - void WriteToRecord(dmlc::JSONWriter* writer) const final; - - Optional ExtractSplitLength(const Array& transform_steps) const; - - Array ApplyToState(State* state) const; - - Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, - const Array& transform_steps) const; - - String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, - const Array& transform_steps) const; - - static constexpr const char* record_prefix_str = "FFSP"; - static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); -}; - -/*! - * \brief Managed reference to FollowFusedSplitStepNode. - * \sa FollowFusedSplitStepNode - */ -class FollowFusedSplitStep : public Step { - public: - FollowFusedSplitStep(int stage_id, int iter_id, const Array& src_step_ids, int level, - bool factor_or_nparts); - - explicit FollowFusedSplitStep(dmlc::JSONReader* reader); - - TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); -}; - /*! * \brief Split step that corresponds to te::Stage::split with additional * support of multiple-level of factors @@ -569,6 +490,101 @@ class SplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; +/*! \brief Similar to SplitStepNode, but use split factor from another stepf + * (i.e. Follow another split step) */ +class FollowSplitStepNode : public StepNode { + public: + /*! \brief The id of the iter to split. */ + int iter_id; + /*! \brief The index of the split step to follow in the history. */ + int src_step_id; + /*! \brief The number of split level. */ + int n_split; + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Extract split lengths. + * \param transform_steps An array record all transform steps. + * \param lengths The multiple split factors. Can be None to be filled by search policy. + */ + void ExtractSplitLengths(const Array& transform_steps, + Array>* lengths) const; + + Array ApplyToState(State* state) const; + + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; + + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; + + static constexpr const char* record_prefix_str = "FSP"; + + static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); +}; + +/*! + * \brief Managed reference to FollowSplitStepNode. + * \sa FollowSplitStepNode + */ +class FollowSplitStep : public Step { + public: + FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); + + explicit FollowSplitStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); +}; + +/*! \brief Similar to FollowSplitStep, but use split factors from multiple steps. + * \Note This can be used for the split in cooperative fetching + */ +class FollowFusedSplitStepNode : public StepNode { + public: + int iter_id; // The id of the iter to split + Array src_step_ids; // The indices of the split steps to follow in the history + int level; // Use the length in this split level + bool factor_or_nparts; // If this is true, use factor. Otherwise, use nparts + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Extract split length. + * \param transform_steps An array record all transform steps. + * \return Split factor. + */ + Optional ExtractSplitLength(const Array& transform_steps) const; + + Array ApplyToState(State* state) const; + + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; + + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; + + static constexpr const char* record_prefix_str = "FFSP"; + + static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); +}; + +/*! + * \brief Managed reference to FollowFusedSplitStepNode. + * \sa FollowFusedSplitStepNode + */ +class FollowFusedSplitStep : public Step { + public: + FollowFusedSplitStep(int stage_id, int iter_id, const Array& src_step_ids, int level, + bool factor_or_nparts); + + explicit FollowFusedSplitStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); +}; + /********** Primitives working on multiple stages **********/ /*! \brief Compute at step that corresponds to te::Stage::compute_at */ From 50f7c4a50d91c95814790c8ffc9f4a8299e26e6c Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Thu, 23 Jul 2020 19:11:18 +0800 Subject: [PATCH 06/14] Add record test for follow_split and follow_fused_split Signed-off-by: jingbang.yjb --- .../unittest/test_auto_scheduler_measure.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 039dff921144..4071bd7f42ac 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -22,8 +22,7 @@ from tvm import te, auto_scheduler import tempfile -from test_auto_scheduler_common import get_tiled_matmul - +from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul def test_record(): if not tvm.runtime.enabled("llvm"): @@ -37,20 +36,28 @@ def test_record(): k = te.reduce_axis((0, 512), name='k') E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E') F = topi.nn.relu(E) + I, H, G = matmul_auto_scheduler_test(512, 512, 512) dag = auto_scheduler.ComputeDAG([A, B, F]) s = dag.get_init_state() - - # Follow split - C_global = s0.cache_write(C, "global") - split_step0 = len(s0.transform_steps) - 1 - split_step1 = len(s0.transform_steps) - 1 - s.follow_split(C_global, tmp[C_global].iters[0], split_step0, 0) - # Follow fused split - s.follow_fused_split(C_global, tmp[C_global].iters[0], - [split_step0, split_step1], 0, False) + dag_follow = auto_scheduler.ComputeDAG([G, H, I]) + s_follow = dag_follow.get_init_state() + # Split its0 = s.split(C, s[C].iters[0], [4, 8, 8]) its1 = s.split(C, s[C].iters[4], [8, 4, 4]) + # Follow split + G_global = s_follow.cache_write(G, "global") + its2 = s_follow.split(G, s_follow[G].iters[0], [4, 2, 8, 4], True) + split_step0 = len(s_follow.transform_steps) - 1 + tmp = s_follow.copy() + tmp.follow_split(G_global, tmp[G_global].iters[0], split_step0, 1) + its3 = s_follow.split(G, s_follow[G].iters[5], [2, 2, 4, 8]) + split_step1 = len(s_follow.transform_steps) - 1 + + # Follow fused split + tmp = s_follow.copy() + tmp.follow_fused_split(G_global, tmp[G_global].iters[0], + [split_step0, split_step1], 0, False) # Reorder s.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8], its1[3]]) From 7bf8dd5190890212809b78f4bb5cf96f4f24e720 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Fri, 24 Jul 2020 07:58:26 +0800 Subject: [PATCH 07/14] Add record test for follow_split Signed-off-by: jingbang.yjb --- src/auto_scheduler/loop_state.h | 5 ++-- .../unittest/test_auto_scheduler_measure.py | 25 +++++++------------ 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index da63b8a11d37..cf0f35dce899 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -353,8 +353,6 @@ class State : public ObjectRef { */ Array split(int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer = true); - /********** Step APIs working on multiple stages **********/ - /*! * \brief Schedule primitive corresponds to te.follow_split. * \param stage_id The index of the stage to be split. @@ -377,6 +375,9 @@ class State : public ObjectRef { Array follow_fused_split(int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts); + + /********** Step APIs working on multiple stages **********/ + /*! * \brief Schedule primitive corresponds to te.compute_at. * \param stage_id The index of the stage to be compute at. diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 4071bd7f42ac..72b8804f0b4d 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -36,28 +36,16 @@ def test_record(): k = te.reduce_axis((0, 512), name='k') E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E') F = topi.nn.relu(E) - I, H, G = matmul_auto_scheduler_test(512, 512, 512) + k = te.reduce_axis((0, 512), name='k') + G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G') + H = topi.nn.relu(G) + dag = auto_scheduler.ComputeDAG([A, B, F]) s = dag.get_init_state() - dag_follow = auto_scheduler.ComputeDAG([G, H, I]) - s_follow = dag_follow.get_init_state() # Split its0 = s.split(C, s[C].iters[0], [4, 8, 8]) its1 = s.split(C, s[C].iters[4], [8, 4, 4]) - # Follow split - G_global = s_follow.cache_write(G, "global") - its2 = s_follow.split(G, s_follow[G].iters[0], [4, 2, 8, 4], True) - split_step0 = len(s_follow.transform_steps) - 1 - tmp = s_follow.copy() - tmp.follow_split(G_global, tmp[G_global].iters[0], split_step0, 1) - its3 = s_follow.split(G, s_follow[G].iters[5], [2, 2, 4, 8]) - split_step1 = len(s_follow.transform_steps) - 1 - - # Follow fused split - tmp = s_follow.copy() - tmp.follow_fused_split(G_global, tmp[G_global].iters[0], - [split_step0, split_step1], 0, False) # Reorder s.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8], its1[3]]) @@ -85,6 +73,11 @@ def test_record(): s.compute_at(D_global, E, s[E].iters[2]) # Cache Write s.cache_write(D, "shared") + #follow_split + s.split(F, s[F].iters[0], [2]) + split_step0 = len(s.transform_steps) - 1 + s.follow_split(F, s[F].iters[0], split_step0, 1) + #follow_fused_split target = tvm.target.create("llvm") task = auto_scheduler.SearchTask(dag, "test", target) From 98d943b8e3332a197f1b501725d2fd6e7a5016c8 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Fri, 24 Jul 2020 08:21:59 +0800 Subject: [PATCH 08/14] Add record test for follow_fused_split. Signed-off-by: jingbang.yjb --- tests/python/unittest/test_auto_scheduler_measure.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 72b8804f0b4d..6ec88b714c67 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -40,7 +40,8 @@ def test_record(): G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G') H = topi.nn.relu(G) - dag = auto_scheduler.ComputeDAG([A, B, F]) + #dag = auto_scheduler.ComputeDAG([A, B, F]) + dag = auto_scheduler.ComputeDAG([A, B, G]) s = dag.get_init_state() # Split @@ -74,10 +75,13 @@ def test_record(): # Cache Write s.cache_write(D, "shared") #follow_split - s.split(F, s[F].iters[0], [2]) + its2 = s.split(G, s[G].iters[0], [4, 2, 8, 4], True) split_step0 = len(s.transform_steps) - 1 - s.follow_split(F, s[F].iters[0], split_step0, 1) + s.follow_split(G, s[G].iters[0], split_step0, 1) #follow_fused_split + its1 = s.split(G, s[G].iters[5], [2, 2, 4, 8]) + split_step1 = len(s.transform_steps) - 1 + s.follow_fused_split(G, s[G].iters[0], [split_step0, split_step1], 0, False) target = tvm.target.create("llvm") task = auto_scheduler.SearchTask(dag, "test", target) From 296cb367f0b54e08cf4634cd4ceefc89a12f23b3 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Fri, 24 Jul 2020 08:39:08 +0800 Subject: [PATCH 09/14] Add test record for follow_fused_split 1. delete a comment 2. add "fuse" between follow_split and follow_fused_split Signed-off-by: jingbang.yjb --- tests/python/unittest/test_auto_scheduler_measure.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 6ec88b714c67..e3b31ed9d936 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -40,7 +40,6 @@ def test_record(): G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G') H = topi.nn.relu(G) - #dag = auto_scheduler.ComputeDAG([A, B, F]) dag = auto_scheduler.ComputeDAG([A, B, G]) s = dag.get_init_state() @@ -79,8 +78,14 @@ def test_record(): split_step0 = len(s.transform_steps) - 1 s.follow_split(G, s[G].iters[0], split_step0, 1) #follow_fused_split - its1 = s.split(G, s[G].iters[5], [2, 2, 4, 8]) + its3 = s.split(G, s[G].iters[5], [2, 2, 4, 8]) split_step1 = len(s.transform_steps) - 1 + its = [] + for i0, i1 in zip(its2, its3): + its.append(i0) + its.append(i1) + for i in range(0, 5): + s.fuse(G, [s[G].iters[i], s[G].iters[i + 1]]) s.follow_fused_split(G, s[G].iters[0], [split_step0, split_step1], 0, False) target = tvm.target.create("llvm") From a7b12946961f0bec88678346c1ce54781f7055af Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Fri, 24 Jul 2020 09:06:15 +0800 Subject: [PATCH 10/14] Add doc strings for some functions and variables Signed-off-by: jingbang.yjb --- src/auto_scheduler/transform_step.h | 73 +++++++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 5 deletions(-) diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 0ddeb9b20c3f..21a3d71b7322 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -511,11 +511,28 @@ class FollowSplitStepNode : public StepNode { void ExtractSplitLengths(const Array& transform_steps, Array>* lengths) const; + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to State. + */ Array ApplyToState(State* state) const; + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \param transform_steps An array record all transform steps. + */ Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const; + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \param transform_steps An array record all transform steps. + * \return Python schedule code. + */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const; @@ -531,8 +548,20 @@ class FollowSplitStepNode : public StepNode { */ class FollowSplitStep : public Step { public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be split. + * \param iter_id The index of the iterator to be split. + * \param src_step_id The index of the split step to follow in the history. + * \param n_split The number of split level. + */ FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); - + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ explicit FollowSplitStep(dmlc::JSONReader* reader); TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); @@ -543,10 +572,14 @@ class FollowSplitStep : public Step { */ class FollowFusedSplitStepNode : public StepNode { public: - int iter_id; // The id of the iter to split - Array src_step_ids; // The indices of the split steps to follow in the history - int level; // Use the length in this split level - bool factor_or_nparts; // If this is true, use factor. Otherwise, use nparts + /*! \brief The id of the iter to split. */ + int iter_id; + /*! \brief The indices of the split steps to follow in the history. */ + Array src_step_ids; + /*! \brief Use the length in this split level. */ + int level; + /*! \brief If this is true, use factor. Otherwise, use nparts. */ + bool factor_or_nparts; void WriteToRecord(dmlc::JSONWriter* writer) const final; @@ -557,11 +590,28 @@ class FollowFusedSplitStepNode : public StepNode { */ Optional ExtractSplitLength(const Array& transform_steps) const; + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to State. + */ Array ApplyToState(State* state) const; + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \param transform_steps An array record all transform steps. + */ Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const; + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \param transform_steps An array record all transform steps. + * \return Python schedule code. + */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const; @@ -577,9 +627,22 @@ class FollowFusedSplitStepNode : public StepNode { */ class FollowFusedSplitStep : public Step { public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be split. + * \param iter_id The index of the iterator to be split. + * \param src_step_ids An array of index for split step to follow in the history. + * \param level Use the length in this split level. + * \param factor_or_nparts If this is true, use factor. Otherwise, use nparts. + */ FollowFusedSplitStep(int stage_id, int iter_id, const Array& src_step_ids, int level, bool factor_or_nparts); + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ explicit FollowFusedSplitStep(dmlc::JSONReader* reader); TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); From 5220a681402289e42af2aae338ea8092bbe62ceb Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Fri, 24 Jul 2020 09:15:49 +0800 Subject: [PATCH 11/14] Fix the code format in src/auto_scheduler/transform_step.h Signed-off-by: jingbang.yjb --- src/auto_scheduler/transform_step.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 21a3d71b7322..23c6321334a5 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -556,7 +556,7 @@ class FollowSplitStep : public Step { * \param n_split The number of split level. */ FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); - + /*! * \brief The constructor used to read a step record from JSONReader and create the * corresponding step. @@ -573,7 +573,7 @@ class FollowSplitStep : public Step { class FollowFusedSplitStepNode : public StepNode { public: /*! \brief The id of the iter to split. */ - int iter_id; + int iter_id; /*! \brief The indices of the split steps to follow in the history. */ Array src_step_ids; /*! \brief Use the length in this split level. */ From 1a87244b3e750414058477c8aa000ae6ac29715e Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Fri, 24 Jul 2020 12:35:13 +0800 Subject: [PATCH 12/14] Fix follow_split and follow_fused_split record test. Signed-off-by: jingbang.yjb --- .../unittest/test_auto_scheduler_measure.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index e3b31ed9d936..6d6fce79d58c 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -27,7 +27,7 @@ def test_record(): if not tvm.runtime.enabled("llvm"): return - + #pdb.set_trace() A = te.placeholder((512, 512), name='A') B = te.placeholder((512, 512), name='B') k = te.reduce_axis((0, 512), name='k') @@ -39,8 +39,9 @@ def test_record(): k = te.reduce_axis((0, 512), name='k') G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G') H = topi.nn.relu(G) + I = topi.nn.relu(H) - dag = auto_scheduler.ComputeDAG([A, B, G]) + dag = auto_scheduler.ComputeDAG([A, B, I]) s = dag.get_init_state() # Split @@ -76,17 +77,19 @@ def test_record(): #follow_split its2 = s.split(G, s[G].iters[0], [4, 2, 8, 4], True) split_step0 = len(s.transform_steps) - 1 - s.follow_split(G, s[G].iters[0], split_step0, 1) + s.follow_split(G, s[G].iters[5], split_step0, 4) #follow_fused_split - its3 = s.split(G, s[G].iters[5], [2, 2, 4, 8]) + its2 = s.split(H, s[H].iters[0], [4, 2, 8, 4], True) split_step1 = len(s.transform_steps) - 1 + its3 = s.split(H, s[H].iters[5], [2, 4, 2, 4], True) + split_step2 = len(s.transform_steps) - 1 its = [] for i0, i1 in zip(its2, its3): its.append(i0) its.append(i1) for i in range(0, 5): - s.fuse(G, [s[G].iters[i], s[G].iters[i + 1]]) - s.follow_fused_split(G, s[G].iters[0], [split_step0, split_step1], 0, False) + s.fuse(H, [s[H].iters[i], s[H].iters[i + 1]]) + s.follow_fused_split(I, s[I].iters[0], [split_step1, split_step2], 0, False) target = tvm.target.create("llvm") task = auto_scheduler.SearchTask(dag, "test", target) From 849bda34ff16ace649674344cdeb10d82449ac6c Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Mon, 27 Jul 2020 08:50:04 +0800 Subject: [PATCH 13/14] Add a blank line in transform.h Signed-off-by: jingbang.yjb --- src/auto_scheduler/transform_step.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 23c6321334a5..ebcbb7dceb63 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -221,6 +221,7 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes String StepPrintAsPythonAPI(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, te::Schedule* schedule, const Array& transform_steps); + /********** Primitives working on single stage **********/ /*! From 0f47cfcc630d3306f2d62e03f2efa38536c22c3d Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Fri, 24 Jul 2020 09:15:49 +0800 Subject: [PATCH 14/14] Add record test for follow_split and follow_fused_split. Signed-off-by: jingbang.yjb --- src/auto_scheduler/transform_step.h | 5 +++-- .../unittest/test_auto_scheduler_measure.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 21a3d71b7322..ebcbb7dceb63 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -221,6 +221,7 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes String StepPrintAsPythonAPI(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, te::Schedule* schedule, const Array& transform_steps); + /********** Primitives working on single stage **********/ /*! @@ -556,7 +557,7 @@ class FollowSplitStep : public Step { * \param n_split The number of split level. */ FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); - + /*! * \brief The constructor used to read a step record from JSONReader and create the * corresponding step. @@ -573,7 +574,7 @@ class FollowSplitStep : public Step { class FollowFusedSplitStepNode : public StepNode { public: /*! \brief The id of the iter to split. */ - int iter_id; + int iter_id; /*! \brief The indices of the split steps to follow in the history. */ Array src_step_ids; /*! \brief Use the length in this split level. */ diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index e3b31ed9d936..6d6fce79d58c 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -27,7 +27,7 @@ def test_record(): if not tvm.runtime.enabled("llvm"): return - + #pdb.set_trace() A = te.placeholder((512, 512), name='A') B = te.placeholder((512, 512), name='B') k = te.reduce_axis((0, 512), name='k') @@ -39,8 +39,9 @@ def test_record(): k = te.reduce_axis((0, 512), name='k') G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G') H = topi.nn.relu(G) + I = topi.nn.relu(H) - dag = auto_scheduler.ComputeDAG([A, B, G]) + dag = auto_scheduler.ComputeDAG([A, B, I]) s = dag.get_init_state() # Split @@ -76,17 +77,19 @@ def test_record(): #follow_split its2 = s.split(G, s[G].iters[0], [4, 2, 8, 4], True) split_step0 = len(s.transform_steps) - 1 - s.follow_split(G, s[G].iters[0], split_step0, 1) + s.follow_split(G, s[G].iters[5], split_step0, 4) #follow_fused_split - its3 = s.split(G, s[G].iters[5], [2, 2, 4, 8]) + its2 = s.split(H, s[H].iters[0], [4, 2, 8, 4], True) split_step1 = len(s.transform_steps) - 1 + its3 = s.split(H, s[H].iters[5], [2, 4, 2, 4], True) + split_step2 = len(s.transform_steps) - 1 its = [] for i0, i1 in zip(its2, its3): its.append(i0) its.append(i1) for i in range(0, 5): - s.fuse(G, [s[G].iters[i], s[G].iters[i + 1]]) - s.follow_fused_split(G, s[G].iters[0], [split_step0, split_step1], 0, False) + s.fuse(H, [s[H].iters[i], s[H].iters[i + 1]]) + s.follow_fused_split(I, s[I].iters[0], [split_step1, split_step2], 0, False) target = tvm.target.create("llvm") task = auto_scheduler.SearchTask(dag, "test", target)