diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h index 1c8ea770e2f84..9850620d484d4 100644 --- a/include/tvm/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -359,6 +359,29 @@ class State : public ObjectRef { TVM_DLL Array split(int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer = true); + /*! + * \brief Schedule primitive extends to split step. + * \param stage_id The index of the stage to be split. + * \param it The iterator to be split. + * \param src_step_id The index of the split step to be followed in the history. + * \param n_split The number of split level. + * \return The splitted new Iterators. + */ + TVM_DLL Array follow_split(int stage_id, const Iterator& it, int src_step_id, + int n_split); + /*! + * \brief Schedule primitive extends to split step. + * \param stage_id The index of the stage to be split. + * \param it The iterator to be split. + * \param src_step_ids The indices of the split steps to be followed in the history. + * \param level Use the length in this split level. + * \param factor_or_nparts True to use `factor` for split from inner to outer, + False to use `nparts` for split from outer to inner. + * \return The splitted new Iterators. + */ + TVM_DLL Array follow_fused_split(int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts); /********** Step APIs working on multiple stages **********/ diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h index 83d6e298a7d7d..f91505ce05e4a 100644 --- a/include/tvm/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -202,9 +202,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag); * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \param schedule A mutable pointer to a `te::Schedule`. This is required by some steps which need * `te::Schedule` API. (e.g. CacheRead/CacheWrite step) + * \param transform_steps An array record all transform steps. */ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule); + te::Schedule* schedule, const Array& transform_steps); /*! * \brief Print the step as equivalent python schedule API. @@ -213,10 +214,12 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g. * CacheRead/CacheWrite step) + * \param transform_steps An array record all transform steps. * \return Python schedule code. */ String StepPrintAsPythonAPI(const Step& step, Array* stages, - StageToAxesMap* stage_to_axes, te::Schedule* schedule); + StageToAxesMap* stage_to_axes, te::Schedule* schedule, + const Array& transform_steps); /********** Steps working on single stage **********/ @@ -487,6 +490,167 @@ class SplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; +/*! \brief Similar to SplitStepNode, but uses split factors from another step + * (i.e. Follow another split step) */ +class FollowSplitStepNode : public StepNode { + public: + /*! \brief The id of the iter to be split. */ + int iter_id; + /*! \brief The index of the split step to follow in the history. */ + int src_step_id; + /*! \brief The number of split level. 
*/ + int n_split; + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Extract split lengths. + * \param transform_steps An array record all transform steps. + * \param lengths The multiple split factors. Can be None to be filled by search policy. + */ + void ExtractSplitLengths(const Array& transform_steps, + Array>* lengths) const; + + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to state, which will be updated. + */ + Array ApplyToState(State* state) const; + + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param transform_steps An array record all transform steps. + * \return The iterator results after split. + */ + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; + + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param transform_steps An array record all transform steps. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; + + static constexpr const char* record_prefix_str = "FSP"; + + static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); +}; + +/*! + * \brief Managed reference to FollowSplitStepNode. + * \sa FollowSplitStepNode + */ +class FollowSplitStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be split. + * \param iter_id The index of the iterator to be split. + * \param src_step_id The index of the split step to follow in the history. + * \param n_split The number of split level. + */ + FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit FollowSplitStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); +}; + +/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps. + * \note This can be used for the split in cooperative fetching. + */ +class FollowFusedSplitStepNode : public StepNode { + public: + /*! \brief The id of the iter to split. */ + int iter_id; + /*! \brief The indices of the split steps to follow in the history. */ + Array src_step_ids; + /*! \brief Use the length in this split level. */ + int level; + /*! \brief If this is true, use factor. Otherwise, use nparts. */ + bool factor_or_nparts; + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Extract split length. + * \param transform_steps An array record all transform steps. + * \return Split factor. + */ + Optional ExtractSplitLength(const Array& transform_steps) const; + + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to state, which will be updated. + * \return The iterator results after split. + */ + Array ApplyToState(State* state) const; + + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages The `te::Stage`s used in TVM scheduler applying. 
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param transform_steps An array record all transform steps. + * \return The iterator results after split. + */ + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; + + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param transform_steps An array record all transform steps. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; + + static constexpr const char* record_prefix_str = "FFSP"; + + static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); +}; + +/*! + * \brief Managed reference to FollowFusedSplitStepNode. + * \sa FollowFusedSplitStepNode + */ +class FollowFusedSplitStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be split. + * \param iter_id The index of the iterator to be split. + * \param src_step_ids An array of index for split step to follow in the history. + * \param level Use the length in this split level. + * \param factor_or_nparts If this is true, use factor. Otherwise, use nparts. + */ + FollowFusedSplitStep(int stage_id, int iter_id, const Array& src_step_ids, int level, + bool factor_or_nparts); + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit FollowFusedSplitStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); +}; + /********** Steps working on multiple stages **********/ /*! \brief Compute at step that corresponds to te::Stage::compute_at */ diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 8c3a936ccf0cc..9ec26c34a6e69 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -117,6 +117,15 @@ def stages(self): """ return self.state_object.stages + @property + def transform_steps(self): + """ + Returns + ------- + transform_steps : List[transform_steps] + """ + return self.state_object.transform_steps + @property def stage_ops(self): """ @@ -301,6 +310,93 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): iterator, lengths, inner_to_outer) return res + def follow_split(self, stage, iterator, src_step_id, n_split): + """ Schedule primitive extends to split step. + + This step splits the iterator by the same factors as the given SplitStep. + + Notes + ------ + This step is useful in a scenario that we have subgraph Dense -> Relu, + and we want to compute the Dense stage at ReLU. In this case, we need them to have + the same tiling structure of common outer loops. + The follow_split step could be used here to split the Dense stage and makes sure its + splitting factors are the same as the given split step for the ReLU stage. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be split, which can be specified by the integer index, Operation, + or output tensor of the stage. + iterator : Iterator + The iterator to split. + src_step_id : int + The index of the split step to follow in the history. 
+        n_split : int
+            The number of split levels.
+
+        Returns
+        -------
+        res_its : List[Iterator]
+            The new Iterators after splitting.
+        """
+
+        self.state_object, res = _ffi_api.StateFollowSplit(self.state_object,
+                                                           self._resolve_stage_id(stage),
+                                                           iterator,
+                                                           src_step_id, n_split)
+        return res
+
+    def follow_fused_split(self, stage, iterator, src_step_ids, level,
+                           factor_or_nparts):
+        """ The schedule primitive extending the split step, which splits an iterator
+        by the same factors as the given list of SplitSteps and FuseSteps.
+
+        Notes
+        ------
+        This step is useful in a scenario where we have a subgraph in a GPU schedule:
+        Input -> Dense
+        for i.0@j.0 = ... : Bind to blockIdx.x
+            for i.1@j.1 = ... : Bind to threadIdx.x
+                for i.2@j.2 = ...
+                    Input_shared = Input ...
+                    for k = ...
+                        Dense = ...
+        We intend to apply cooperative fetching to the input stage, while the threadIdx.x
+        axis is bound to an iterator generated by split & fuse steps.
+        The follow_fused_split step is used to split the iterator into 2 parts, making sure
+        the split factor matches the final extent of the threadIdx.x-bound iterator.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be split, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        iterator : Iterator
+            The iterator to split.
+        src_step_ids : List[int]
+            The indices of the split steps to follow in the history.
+        level : int
+            Use the length in this split level.
+        factor_or_nparts : bool
+            True to use `factor` for split from inner to outer,
+            False to use `nparts` for split from outer to inner.
+
+        Returns
+        -------
+        res_its : List[Iterator]
+            The new Iterators after splitting.
+        """
+
+        self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object,
+                                                                self._resolve_stage_id(stage),
+                                                                iterator,
+                                                                src_step_ids, level,
+                                                                factor_or_nparts)
+        return res
+
     def compute_at(self, stage, target_stage, target_iter):
         """ Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
         more details.
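
As a quick illustration of the new Python primitives above, a minimal usage sketch
(mirroring the unit test added below; `dag`, `C` and `C_global` stand for any ComputeDAG,
a stage in it, and its cached counterpart):

    # Record an ordinary split, then make the cached stage follow its factors.
    s = dag.get_init_state()
    C_global = s.cache_write(C, "global")
    s.split(C, s[C].iters[0], [4, 2, 8, 4])          # recorded as a SplitStep
    split_step = len(s.transform_steps) - 1          # its index in the transform history
    s.follow_split(C_global, s[C_global].iters[0], split_step, 4)

follow_fused_split works the same way, except that it takes a list of split step indices
plus a level, and uses the product of the factors found at that level across the followed
steps (see test_follow_split_follow_fused_split below).
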
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 2f6e948335c0a..f2815fbdec6bd 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -678,7 +678,7 @@ std::pair> ComputeDAG::ApplySteps( // Apply the history steps to TVM schedule // Call each step's ApplyToSchedule method for (const auto& step : transform_steps) { - StepApplyToSchedule(step, stages, stage_to_axes, &schedule); + StepApplyToSchedule(step, stages, stage_to_axes, &schedule, transform_steps); } return std::make_pair(schedule, operator->()->tensors); @@ -722,7 +722,7 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const } // Call each step's PrintAsPythonAPI method for (const auto& step : transform_steps) { - ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule); + ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule, transform_steps); } return ss.str(); diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 67c6b38845c32..636066ac957f0 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -268,6 +268,25 @@ Array State::split(int stage_id, const Iterator& it, return step->ApplyToState(this); } +Array State::follow_split(int stage_id, const Iterator& it, int src_step_id, + int n_split) { + const Stage& stage = operator->()->stages[stage_id]; + FollowSplitStep step = + FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + +Array State::follow_fused_split(int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts) { + const Stage& stage = operator->()->stages[stage_id]; + FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), + src_step_ids, level, factor_or_nparts); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { const Stage& target_stage = operator->()->stages[target_stage_id]; ComputeAtStep step = @@ -454,6 +473,21 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit") return Array{state, res}; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit") + .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id, + int n_split) { + const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); + return Array{state, Array(res)}; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit") + .set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& src_step_ids, int level, bool factor_or_nparts) { + const auto& res = + state.follow_fused_split(stage_id, it, src_step_ids, level, factor_or_nparts); + return Array{state, Array(res)}; + }); + TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt") .set_body_typed([](State state, int stage_id, int target_stage_id, const Iterator& target_iter) { diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 5c5cc4b2e760f..d43d0af144998 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -85,6 +85,10 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { return ReorderStep(reader); } else if (name == SplitStepNode::record_prefix_str) { return SplitStep(reader); + } else if (name == FollowSplitStepNode::record_prefix_str) { + return FollowSplitStep(reader); 
+ } else if (name == FollowFusedSplitStepNode::record_prefix_str) { + return FollowFusedSplitStep(reader); } else if (name == ComputeAtStepNode::record_prefix_str) { return ComputeAtStep(reader); } else if (name == ComputeInlineStepNode::record_prefix_str) { @@ -111,6 +115,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); } else if (auto ps = step.as()) { @@ -127,7 +135,7 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { } void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule) { + te::Schedule* schedule, const Array& transform_steps) { if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -136,6 +144,10 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -152,7 +164,8 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes } String StepPrintAsPythonAPI(const Step& step, Array* stages, - StageToAxesMap* stage_to_axes, te::Schedule* schedule) { + StageToAxesMap* stage_to_axes, te::Schedule* schedule, + const Array& transform_steps) { if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -161,6 +174,10 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -776,6 +793,193 @@ String SplitStepNode::PrintAsPythonAPI(Array* stages, return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } +/********** Follow Split **********/ +FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->src_step_id = src_step_id; + node->n_split = n_split; + data_ = std::move(node); +} + +void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(iter_id); + writer->WriteArrayItem(src_step_id); + writer->WriteArrayItem(n_split); +} + +void FollowSplitStepNode::ExtractSplitLengths(const Array& transform_steps, + Array>* lengths) const { + // Make sure src_step_id is within the range of transform_steps. 
+ CHECK_LT(src_step_id, transform_steps.size()); + auto ps = transform_steps[src_step_id].as(); + CHECK(ps != nullptr); + + // Make sure the size of ps->lengths is not smaller than n_split-1. + // Note that the number of actual splitting factors of src_step is ps->lengths.size()+1. + CHECK_LE(n_split, ps->lengths.size() + 1); + CHECK(ps != nullptr); + + lengths->reserve(n_split); + int j = 0; + // Get the first (n_split-1) split factors of followed src_step. + for (; j < n_split - 1; ++j) { + lengths->push_back(ps->lengths[j]); + } + + // Get the last split factor of src_step for splitting level if n_split is smaller than + // ps->lengths.size()+1. + PrimExpr last_factor = 1; + for (; j < static_cast(ps->lengths.size()); ++j) { + if (ps->lengths[j]) { + last_factor *= ps->lengths[j].value(); + } else { + last_factor = PrimExpr(); + break; + } + } + if (last_factor.defined()) { + lengths->push_back(Downcast(last_factor)); + } else { + lengths->push_back(NullOpt); + } +} + +FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->iter_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->src_step_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->n_split); + + data_ = std::move(node); +} + +Array FollowSplitStepNode::ApplyToState(State* state) const { + Array> lengths; + ExtractSplitLengths((*state)->transform_steps, &lengths); + return ApplySplitToState(state, stage_id, iter_id, lengths, true); +} + +Array FollowSplitStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes, + const Array& transform_steps) const { + Array> lengths; + ExtractSplitLengths(transform_steps, &lengths); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true); +} + +String FollowSplitStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes, + const Array& transform_steps) const { + Array> lengths; + ExtractSplitLengths(transform_steps, &lengths); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true); +} + +/********** Follow Fused Split **********/ +FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id, + const Array& src_step_ids, int level, + bool factor_or_nparts) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->src_step_ids = src_step_ids; + node->level = level; + node->factor_or_nparts = factor_or_nparts; + data_ = std::move(node); +} + +FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->iter_id); + s = reader->NextArrayItem(); + CHECK(s); + std::vector int_list; + reader->Read(&int_list); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->level); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->factor_or_nparts); + + ::tvm::Array<::tvm::Integer> src_step_ids; + for (const auto& i : int_list) { + src_step_ids.push_back(i); + } + node->src_step_ids = src_step_ids; + data_ = std::move(node); +} + +void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + 
writer->WriteArrayItem(iter_id); + writer->WriteArrayItem(IntArrayToVector(src_step_ids)); + writer->WriteArrayItem(level); + writer->WriteArrayItem(static_cast(factor_or_nparts)); +} + +Optional FollowFusedSplitStepNode::ExtractSplitLength( + const Array& transform_steps) const { + PrimExpr ret(1); + + for (int src_step_id : src_step_ids) { + // Make sure the src_step_id is within the range of transform_steps. + CHECK_LT(src_step_id, transform_steps.size()); + auto ps = transform_steps[src_step_id].as(); + CHECK(ps != nullptr); + // Multiple the splitting factor on corresponding splitting level of src_steps. + if (ps->lengths[level] && ret.defined()) { + ret *= ps->lengths[level].value(); + } else { + return NullOpt; + } + } + return Downcast(ret); +} + +Array FollowFusedSplitStepNode::ApplyToState(State* state) const { + const Optional& length = ExtractSplitLength((*state)->transform_steps); + return ApplySplitToState(state, stage_id, iter_id, {length}, factor_or_nparts); +} + +Array FollowFusedSplitStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes, + const Array& transform_steps) const { + const Optional& length = ExtractSplitLength(transform_steps); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, {length}, factor_or_nparts); +} + +String FollowFusedSplitStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes, + const Array& transform_steps) const { + const Optional& length = ExtractSplitLength(transform_steps); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, {length}, + factor_or_nparts); +} + /********** Steps working on multiple stages **********/ /********** Compute At **********/ diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index 8282d4a40e5ef..e35dfe3b0d539 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -85,7 +85,6 @@ def test_split_fuse_reorder_annotation(): assert res == s1[C].iters[5] assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vectorize"] - def test_compute_at_root_inline(): dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(N=1, H=224, W=224, CI=3, CO=64, kernel_size=7, strides=2, padding=3)) @@ -142,7 +141,6 @@ def test_compute_at_root_inline(): assert s0[conv].iters[5].range.extent == 7 assert s0[conv].iters[6].range.extent == 7 - def test_cache_read_write(): N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( 1, 1), (1, 1) @@ -417,7 +415,47 @@ def test_cache_read_write(): for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters): assert it0.range == it1.range +def test_follow_split_follow_fused_split(): + A, B, C = matmul_auto_scheduler_test(512, 512, 512) + dag = auto_scheduler.ComputeDAG([A, B, C]) + s0 = dag.get_init_state() + + C_global = s0.cache_write(C, "global") + its0 = s0.split(C, s0[C].iters[0], [4, 2, 8, 4], True) + split_step0 = len(s0.transform_steps) - 1 + for level in range(1, 6): + tmp = s0.copy() + tmp.follow_split(C_global, tmp[C_global].iters[0], split_step0, level) + for i in range(0, level): + assert tmp[C].iters[i].range.extent == \ + tmp[C_global].iters[i].range.extent + + its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8]) + split_step1 = len(s0.transform_steps) - 1 + its = [] + for i0, i1 in zip(its0, its1): + its.append(i0) + its.append(i1) + s0.reorder(C, its) + for i in range(0, 5): + s0.fuse(C, [s0[C].iters[i], 
s0[C].iters[i + 1]]) + + for level in range(0, 4): + tmp = s0.copy() + tmp.follow_fused_split(C_global, tmp[C_global].iters[0], + [split_step0, split_step1], level, False) + assert tmp[C].iters[level + 1].range.extent == \ + tmp[C_global].iters[0].range.extent + + for level in range(0, 4): + tmp = s0.copy() + tmp.follow_fused_split(C_global, tmp[C_global].iters[0], + [split_step0, split_step1], level, True) + assert tmp[C].iters[level + 1].range.extent == \ + tmp[C_global].iters[1].range.extent + if __name__ == "__main__": test_split_fuse_reorder_annotation() test_compute_at_root_inline() test_cache_read_write() + test_follow_split_follow_fused_split() diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 5f2f87ad9baa2..39d01e0c29695 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -22,8 +22,7 @@ from tvm import te, auto_scheduler import tempfile -from test_auto_scheduler_common import get_tiled_matmul - +from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul def test_record(): if not tvm.runtime.enabled("llvm"): @@ -37,8 +36,12 @@ def test_record(): k = te.reduce_axis((0, 512), name='k') E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E') F = topi.nn.relu(E) + k = te.reduce_axis((0, 512), name='k') + G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G') + H = topi.nn.relu(G) + I = topi.nn.relu(H) - dag = auto_scheduler.ComputeDAG([A, B, F]) + dag = auto_scheduler.ComputeDAG([A, B, I]) s = dag.get_init_state() # Split @@ -71,6 +74,22 @@ def test_record(): s.compute_at(D_global, E, s[E].iters[2]) # Cache Write s.cache_write(D, "shared") + #follow_split + its2 = s.split(G, s[G].iters[0], [4, 2, 8, 4], True) + split_step0 = len(s.transform_steps) - 1 + s.follow_split(G, s[G].iters[5], split_step0, 4) + #follow_fused_split + its2 = s.split(H, s[H].iters[0], [4, 2, 8, 4], True) + split_step1 = len(s.transform_steps) - 1 + its3 = s.split(H, s[H].iters[5], [2, 4, 2, 4], True) + split_step2 = len(s.transform_steps) - 1 + its = [] + for i0, i1 in zip(its2, its3): + its.append(i0) + its.append(i1) + for i in range(0, 5): + s.fuse(H, [s[H].iters[i], s[H].iters[i + 1]]) + s.follow_fused_split(I, s[I].iters[0], [split_step1, split_step2], 0, False) target = tvm.target.create("llvm") task = auto_scheduler.SearchTask(dag, "test", target)