diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h index 1c8ea770e2f84..059184b721585 100644 --- a/include/tvm/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -340,6 +340,13 @@ class State : public ObjectRef { * result will become the new attach point. */ TVM_DLL Iterator fuse(int stage_id, const Array& iters); + /*! + * \brief Schedule primitive corresponds to `te.Stage.pragma`. + * \param stage_id The index of the stage to add pragma. + * \param it The iterator to add pragma. + * \param pragma_type The pragma string. + */ + TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type); /*! * \brief Schedule primitive corresponds to `te::Stage::reorder`. * \param stage_id The index of the stage to be reordered. @@ -359,6 +366,14 @@ class State : public ObjectRef { TVM_DLL Array split(int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer = true); + /*! + * \brief Schedule primitive corresponds to `te.Stage.storage_align`. + * \param stage_id The index of the stage to be aligned. + * \param it The iterator to be aligned. + * \param factor The factor in alignment specification. + * \param offset The offset in the alignment specification. + */ + TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset); /********** Step APIs working on multiple stages **********/ @@ -399,8 +414,8 @@ class State : public ObjectRef { * \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`. */ - int cache_read(int stage_id, const String& scope_name, const Array& reader_stage_ids, - const ComputeDAG& dag); + TVM_DLL int cache_read(int stage_id, const String& scope_name, + const Array& reader_stage_ids, const ComputeDAG& dag); /*! * \brief Schedule primitive corresponds to `te::Schedule::cache_write`. * \param stage_id The index of the stage to be cache write. @@ -410,7 +425,17 @@ class State : public ObjectRef { * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`. * This step will cache write all output tensors of the target stage. */ - int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag); + TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag); + /*! + * \brief Schedule primitive corresponds to `te::Schedule::rfactor`. + * \param stage_id The index of the iterator to be factored. + * \param iter_id The iterator to be factored. + * \param factor_iter_id The position where the new iterator is placed. + * \param dag The original ComputeDAG of this state. + * \note Rfactor step will add an extra stage to the original ComputeDAG, a up-to-date + * ComputeDAG is stored in State's `current_compute_dag`. + */ + TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag); TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h index 83d6e298a7d7d..57dc6ac6dcf77 100644 --- a/include/tvm/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -347,6 +347,67 @@ class FuseStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); }; +/*! \brief Pragma step that corresponds to te::Stage::pragma */ +class PragmaStepNode : public StepNode { + public: + /*! \brief The index of the iterator to add pragma. */ + int iter_id; + /*! \brief The pragma string. */ + String pragma_type; + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to State. + */ + void ApplyToState(State* state) const; + + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + */ + void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; + + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + + static constexpr const char* record_prefix_str = "PR"; + + static constexpr const char* _type_key = "auto_scheduler.PragmaStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object); +}; + +/*! + * \brief Managed reference to PragmaStepNode. + * \sa PragmaStepNode + */ +class PragmaStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be fused. + * \param iter_id The index of the iterator to add pragma. + * \param pragma_type The pragma string. + */ + PragmaStep(int stage_id, int iter_id, String pragma_type); + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit PragmaStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode); +}; + /*! \brief Reorder step that corresponds to te::Stage::reorder */ class ReorderStepNode : public StepNode { public: @@ -487,6 +548,70 @@ class SplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; +/*! \brief Storage align step that corresponds to te::Stage::storage_align */ +class StorageAlignStepNode : public StepNode { + public: + /*! \brief The iterator to be aligned. */ + int iter_id; + /*! \brief The factor in alignment specification. */ + int factor; + /*! \brief The offset in the alignment specification. */ + int offset; + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to State. + */ + void ApplyToState(State* state) const; + + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + */ + void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; + + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + + static constexpr const char* record_prefix_str = "SA"; + + static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object); +}; + +/*! + * \brief Managed reference to StorageAlignStepNode. + * \sa StorageAlignStepNode + */ +class StorageAlignStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be aligned. + * \param iter_id The index of the iterator to be aligned. + * \param factor The factor in alignment specification. + * \param offset The offset in the alignment specification. + */ + StorageAlignStep(int stage_id, int iter_id, int factor, int offset); + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit StorageAlignStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode); +}; + /********** Steps working on multiple stages **********/ /*! \brief Compute at step that corresponds to te::Stage::compute_at */ @@ -668,7 +793,7 @@ class ComputeRootStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode); }; -/********** Primitives adding new stages **********/ +/********** Steps adding new stages **********/ /*! * \brief Cache read step that corresponds to te::Schedule::cache_read. @@ -812,6 +937,74 @@ class CacheWriteStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode); }; +/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */ +class RfactorStepNode : public StepNode { + public: + /*! \brief The index of the iterator to be factored. */ + int iter_id; + /*! \brief The position where the new iterator is placed. */ + int factor_iter_id; + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to State. + * \param dag The original ComputeDAG of this state. + * \return The index of the new added stage. + */ + int ApplyToState(State* state, const ComputeDAG& dag) const; + + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages A mutable pointer to a `te::Stage` Array. + * \param stage_to_axes A mutable pointer to a StageToAxesMap. + * \param schedule A mutable pointer to a te::Schedule. + * \return The output Tensors of the new added stage. + */ + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const; + + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages A mutable pointer to a `te::Stage` Array. + * \param stage_to_axes A mutable pointer to a StageToAxesMap. + * \param schedule A mutable pointer to a te::Schedule. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const; + + static constexpr const char* record_prefix_str = "RF"; + + static constexpr const char* _type_key = "auto_scheduler.RfactorStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object); +}; + +/*! + * \brief Managed reference to RfactorStepNode. + * \sa RfactorStepNode + */ +class RfactorStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the iterator to be factored. + * \param iter_id The index of the iterator to be factored. + * \param factor_iter_id The position where the new iterator is placed. + */ + RfactorStep(int stage_id, int iter_id, int factor_iter_id); + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit RfactorStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode); +}; + } // namespace auto_scheduler } // namespace tvm diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 8c3a936ccf0cc..c5f512b721f25 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -252,6 +252,22 @@ def fuse(self, stage, iters): self._resolve_stage_id(stage), iters) return res + def pragma(self, stage, iterator, pragma_type): + """ Schedule primitive corresponds to te.pragma. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to add pragma, which can be specified by the integer index, Operation, + or output tensor of the stage. + iterator : Iterator + The iterator to add pragma. + pragma_type : str + The pragma string. + """ + self.state_object = _ffi_api.StatePragma(self.state_object, self._resolve_stage_id(stage), + iterator, pragma_type) + def reorder(self, stage, order): """ Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more details. @@ -301,6 +317,27 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): iterator, lengths, inner_to_outer) return res + def storage_align(self, stage, iterator, factor, offset): + """ Schedule primitive corresponds to te.storage_align. + + See `te.schedule.Stage.storage_align` for more information. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be storage aligned, which can be specified by the integer index, + Operation, or output tensor of the stage. + iterator : Iterator + The iterator to be aligned. + factor : int + The factor in alignment specification. + offset : int + The offset in the alignment specification. + """ + self.state_object = _ffi_api.StateStorageAlign(self.state_object, + self._resolve_stage_id(stage), iterator, + factor, offset) + def compute_at(self, stage, target_stage, target_iter): """ Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for more details. @@ -429,6 +466,37 @@ def cache_write(self, stage, scope_name): self._update_stage_id_map() return self.stages[int(new_stage_id)].op + def rfactor(self, stage, iterator, factor_iter_id): + """ Schedule primitive corresponds to te.schedule.rfactor. + + See `te.schedule.Schedule.rfactor` for more information. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be factored, which can be specified by the integer index, Operation, + or output tensor of the stage. + iterator : Iterator + The reduction iterator to be factored. + factor_iter_id : int + The position where the new iterator is placed. + + Returns + ------- + new_stage_op : Operator + The Operator of the new added stage. + + Notes + ----- + Rfactor step will insert an extra stage to the original ComputeDAG (in the front of the + target stage). + """ + self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, + self._resolve_stage_id(stage), + iterator, factor_iter_id, + self.compute_dag) + return self._insert_new_stage(int(new_stage_id)) + def copy(self): """ Do deep copy of this State. """ state = State(self.state_object, self.compute_dag) diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 67c6b38845c32..481ca0f762419 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -247,6 +247,13 @@ Iterator State::fuse(int stage_id, const Array& iters) { return step->ApplyToState(this); } +void State::pragma(int stage_id, const Iterator& it, const String& pragma_type) { + const Stage& stage = operator->()->stages[stage_id]; + PragmaStep step = PragmaStep(stage_id, GetIndex(stage->iters, it), pragma_type); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + void State::reorder(int stage_id, const Array& order) { const Stage& stage = operator->()->stages[stage_id]; CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " @@ -268,6 +275,13 @@ Array State::split(int stage_id, const Iterator& it, return step->ApplyToState(this); } +void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) { + const Stage& stage = operator->()->stages[stage_id]; + StorageAlignStep step = StorageAlignStep(stage_id, GetIndex(stage->iters, it), factor, offset); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { const Stage& target_stage = operator->()->stages[target_stage_id]; ComputeAtStep step = @@ -301,6 +315,13 @@ int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG& return step->ApplyToState(this, dag); } +int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag) { + const Stage& stage = operator->()->stages[stage_id]; + RfactorStep step = RfactorStep(stage_id, GetIndex(stage->iters, it), factor_iter_id); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this, dag); +} + void State::ApplySteps(const ComputeDAG& dag) { CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages."; @@ -441,6 +462,12 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse") return Array{state, res}; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StatePragma") + .set_body_typed([](State state, int stage_id, const Iterator& it, const String& pragma_type) { + state.pragma(stage_id, it, pragma_type); + return state; + }); + TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder") .set_body_typed([](State state, int stage_id, const Array& order) { state.reorder(stage_id, order); @@ -454,6 +481,12 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit") return Array{state, res}; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StateStorageAlign") + .set_body_typed([](State state, int stage_id, const Iterator& it, int factor, int offset) { + state.storage_align(stage_id, it, factor, offset); + return state; + }); + TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt") .set_body_typed([](State state, int stage_id, int target_stage_id, const Iterator& target_iter) { @@ -487,6 +520,13 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite") return Array{state, Integer(res)}; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StateRfactor") + .set_body_typed([](State state, int stage_id, const Iterator& it, int factor_iter_id, + const ComputeDAG& dag) { + int res = state.rfactor(stage_id, it, factor_iter_id, dag); + return Array{state, Integer(res)}; + }); + TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) { return std::equal_to()(state1, state2); }); diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 5c5cc4b2e760f..2eae04e828633 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -81,10 +81,14 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { return AnnotationStep(reader); } else if (name == FuseStepNode::record_prefix_str) { return FuseStep(reader); + } else if (name == PragmaStepNode::record_prefix_str) { + return PragmaStep(reader); } else if (name == ReorderStepNode::record_prefix_str) { return ReorderStep(reader); } else if (name == SplitStepNode::record_prefix_str) { return SplitStep(reader); + } else if (name == StorageAlignStepNode::record_prefix_str) { + return StorageAlignStep(reader); } else if (name == ComputeAtStepNode::record_prefix_str) { return ComputeAtStep(reader); } else if (name == ComputeInlineStepNode::record_prefix_str) { @@ -95,6 +99,8 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { return CacheReadStep(reader); } else if (name == CacheWriteStepNode::record_prefix_str) { return CacheWriteStep(reader); + } else if (name == RfactorStepNode::record_prefix_str) { + return RfactorStep(reader); } else { LOG(FATAL) << "Invalid step format: " << name; } @@ -107,10 +113,14 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); } else if (auto ps = step.as()) { @@ -121,6 +131,8 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state, dag); } else if (auto ps = step.as()) { ps->ApplyToState(state, dag); + } else if (auto ps = step.as()) { + ps->ApplyToState(state, dag); } else { LOG(FATAL) << "Invalid step: " << step; } @@ -132,10 +144,14 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -146,6 +162,8 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes ps->ApplyToSchedule(stages, stage_to_axes, schedule); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes, schedule); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, schedule); } else { LOG(FATAL) << "Invalid Step: " << step; } @@ -157,10 +175,14 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -171,6 +193,8 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); } else { LOG(FATAL) << "Invalid Step: " << step; } @@ -471,6 +495,115 @@ String FuseStepNode::PrintAsPythonAPI(Array* stages, return ss.str(); } +/********** Pragma **********/ +PragmaStep::PragmaStep(int stage_id, int iter_id, String pragma_type) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->pragma_type = std::move(pragma_type); + data_ = std::move(node); +} + +PragmaStep::PragmaStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->iter_id); + s = reader->NextArrayItem(); + CHECK(s); + std::string string_value; + reader->Read(&string_value); + node->pragma_type = std::move(string_value); + data_ = std::move(node); +} + +void PragmaStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(iter_id); + writer->WriteArraySeperator(); + writer->WriteString(pragma_type); +} + +void PragmaStepNode::ApplyToState(State* state) const { + if (pragma_type == "debug_skip_region") { + StateNode* pstate = state->CopyOnWrite(); + pstate->attach_map.DeleteStage(stage_id); + } else if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { + StateNode* pstate = state->CopyOnWrite(); + Stage stage = pstate->stages[stage_id]; + size_t pos = 0; + for (; pos < pragma_type.size(); ++pos) { + if ((*(pragma_type.c_str() + pos)) == '$') { + break; + } + } + CHECK_LT(pos, pragma_type.size()) << "max step value not found."; + stage.CopyOnWrite()->attrs.auto_unroll_max_step = atoi(pragma_type.c_str() + pos + 1); + pstate->stages.Set(stage_id, std::move(stage)); + } else if (pragma_type == "tensor_core") { + // Nothing needs to be done here + } else { + LOG(FATAL) << "Invalid pragma: " << pragma_type; + } +} + +void PragmaStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { + te::Stage stage = (*stages)[stage_id]; + const Array& axes = (*stage_to_axes)[stage]; + if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { + size_t pos = 0; + for (; pos < pragma_type.size(); ++pos) { + if ((*(pragma_type.c_str() + pos)) == '$') { + break; + } + } + CHECK_LT(pos, pragma_type.size()) << "max step value not found."; + int value = atoi(pragma_type.c_str() + pos + 1); + stage.pragma(axes[iter_id], "auto_unroll_max_step", value); + stage.pragma(axes[iter_id], "unroll_explicit", true); + } else { + stage.pragma(axes[iter_id], pragma_type); + } + stages->Set(stage_id, std::move(stage)); +} + +String PragmaStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { + size_t pos = 0; + for (; pos < pragma_type.size(); ++pos) { + if ((*(pragma_type.c_str() + pos)) == '$') { + break; + } + } + CHECK_LT(pos, pragma_type.size()) << "max step value not found."; + int value = atoi(pragma_type.c_str() + pos + 1); + ss << "s[" << CleanName(stage->op->name) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) + << ", \"auto_unroll_max_step\", " << value << ")\n"; + ss << "s[" << CleanName(stage->op->name) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) + << ", \"unroll_explicit\", True)\n"; + } else { + ss << "s[" << CleanName(stage->op->name) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" << pragma_type + << "\")\n"; + } + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + /********** Reorder **********/ ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { auto node = make_object(); @@ -776,6 +909,70 @@ String SplitStepNode::PrintAsPythonAPI(Array* stages, return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } +/********** Storage Align **********/ +StorageAlignStep::StorageAlignStep(int stage_id, int iter_id, int factor, int offset) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->factor = factor; + node->offset = offset; + data_ = std::move(node); +} + +StorageAlignStep::StorageAlignStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->iter_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->factor); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->offset); + data_ = std::move(node); +} + +void StorageAlignStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(iter_id); + writer->WriteArrayItem(factor); + writer->WriteArrayItem(offset); +} + +void StorageAlignStepNode::ApplyToState(State* state) const { + StateNode* pstate = state->CopyOnWrite(); + Stage stage = pstate->stages[stage_id]; + stage.CopyOnWrite()->attrs.storage_offset = offset; + pstate->stages.Set(stage_id, std::move(stage)); +} + +void StorageAlignStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { + te::Stage stage = (*stages)[stage_id]; + const Array& axes = (*stage_to_axes)[stage]; + stage.storage_align(axes[iter_id], factor, offset); + stages->Set(stage_id, std::move(stage)); +} + +String StorageAlignStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + ss << "s[" << CleanName(stage->op->name) << "].storage_align(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " << factor << ", " + << offset << ")\n"; + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + /********** Steps working on multiple stages **********/ /********** Compute At **********/ @@ -958,7 +1155,7 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array* stages, return ss.str(); } -/********** Primitives adding new stages **********/ +/********** Steps adding new stages **********/ /*! * \brief Common part for steps that add new stages(e.g. CacheReadStep, CacheWriteStep, @@ -967,11 +1164,27 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array* stages, */ Array GetFormerStageModifiableSteps(Step current_step, const Array& transform_steps) { Array ret_steps; - for (const Step& step : transform_steps) { + for (size_t i = 0; i < transform_steps.size(); ++i) { + const Step& step = transform_steps[i]; if (step->IsInstance() || step->IsInstance()) { ret_steps.push_back(step); + } else if (step->IsInstance()) { + // add FuseStepNode required by rfactor + if (i >= 2 && transform_steps[i - 2]->IsInstance()) { + const Step& fuse_step = transform_steps[i - 2]; + if (fuse_step->stage_id == step->stage_id) { + ret_steps.push_back(fuse_step); + } + } + // add SplitStepNode required by rfactor + CHECK_GE(i, 1); + CHECK(transform_steps[i - 1]->IsInstance()); + const Step& split_step = transform_steps[i - 1]; + CHECK_EQ(split_step->stage_id, step->stage_id); + ret_steps.push_back(split_step); + // add RfactorStepNode + ret_steps.push_back(step); } - // TODO(jcf94): add rfactor support // A state may have multiple stage modifiable steps, stop by the current step to avoid // replaying excess steps if (step.same_as(current_step)) { @@ -1228,5 +1441,136 @@ String CacheWriteStepNode::PrintAsPythonAPI(Array* stages, StageToAxe return ss.str(); } +/********** Rfactor **********/ +RfactorStep::RfactorStep(int stage_id, int iter_id, int factor_iter_id) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->factor_iter_id = factor_iter_id; + data_ = std::move(node); +} + +RfactorStep::RfactorStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->iter_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->factor_iter_id); + data_ = std::move(node); +} + +void RfactorStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(iter_id); + writer->WriteArrayItem(factor_iter_id); +} + +int RfactorStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { + StateNode* pstate = state->CopyOnWrite(); + const auto& compute_at_type = pstate->stages[stage_id]->compute_at; + Array replay_steps; + for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { + AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); + if (pstate->transform_steps[i].same_as(GetRef(this))) { + break; + } + } + const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps); + + // target -> target_compute + target + // Should insert new stage, update target stage, update the later stage's op + pstate->stages.insert(pstate->stages.begin() + stage_id, + Stage(current_compute_dag->ops[stage_id])); + // maintain the compute_at type of target stage + Stage target_stage = Stage(current_compute_dag->ops[stage_id + 1]); + target_stage.CopyOnWrite()->compute_at = compute_at_type; + pstate->stages.Set(stage_id + 1, std::move(target_stage)); + + for (size_t i = stage_id + 2; i < pstate->stages.size(); ++i) { + Stage stage = pstate->stages[i]; + stage.CopyOnWrite()->op = current_compute_dag->ops[i]; + pstate->stages.Set(i, std::move(stage)); + } + pstate->attach_map = pstate->attach_map.ApplyStageIdOfffset(stage_id, 1); + + return stage_id; +} + +Array RfactorStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const { + const auto& stage = (*stages)[stage_id]; + const Array& axes = (*stage_to_axes)[stage]; + + const te::Tensor& tensor = stage->origin_op.output(0); + const IterVar& axis = axes[iter_id]; + auto outs = schedule->rfactor(tensor, axis, factor_iter_id); + + UpdateStageToAxesMap(stage, stage_to_axes); + + const auto& new_stage = (*schedule)[outs[0]->op]; + UpdateStageToAxesMap(new_stage, stage_to_axes); + stages->insert(stages->begin() + stage_id, new_stage); + + return outs; +} + +String RfactorStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + const auto& tensor_name = CleanName(stage->origin_op.output(0)->op->name); + const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint); + + const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule); + + for (size_t i = 0; i < outs.size(); ++i) { + ss << CleanName(outs[i]->op->name); + if (i != outs.size() - 1) { + ss << ", "; + } + } + ss << " = " + << "s.rfactor(" << tensor_name << ", " << axis_name << ", " << factor_iter_id << ")\n"; + + for (const auto& out : outs) { + const auto& iters = out->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " + << "tuple(" << CleanName(out->op->name) << ".op.axis)" + << " + " + << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n"; + } + + const auto& output = (*stages)[stage_id + 1]->op.output(0); + const auto& iters = output->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " + << "tuple(s[" << CleanName(output->op->name) << "].op.axis)" + << " + " + << "tuple(s[" << CleanName(output->op->name) << "].op.reduce_axis)\n"; + + return ss.str(); +} + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h index da5032e11c97a..aacdcf4265f9e 100644 --- a/src/auto_scheduler/utils.h +++ b/src/auto_scheduler/utils.h @@ -162,6 +162,12 @@ inline double FloatArrayMean(const Array& float_array) { return sum / float_array.size(); } +/*! \brief Return whether a string starts with another substring */ +inline bool StrStartsWith(const String& a, const String& b) { + if (b.size() > a.size()) return false; + return std::equal(a.c_str(), a.c_str() + b.size(), b.c_str()); +} + /********** Other Utilities **********/ /*! \brief Get an int value from an Expr */ inline int64_t GetIntImm(const PrimExpr& expr) { diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index 8282d4a40e5ef..255acf1dcc68b 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -417,7 +417,61 @@ def test_cache_read_write(): for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters): assert it0.range == it1.range + +def test_rfactor(): + A, B, C = matmul_auto_scheduler_test(8, 8, 512) + dag = auto_scheduler.ComputeDAG([A, B, C]) + s0 = dag.get_init_state() + + ko, ki = s0.split(C, s0[C].iters[2], [16]) + + s1 = s0.copy() + C_r = s1.rfactor(C, ko, 2) + """ + Placeholder: A, B + for i (0,8) + for j (0,8) + for k_o (0,32) + for k_i (0,16) + C.rf = ... + for ax0 (0,8) + for ax1 (0,8) + for k_o_v (0,32) + C.repl = ... + """ + assert s1[C_r].iters[0].range.extent == 8 + assert s1[C_r].iters[1].range.extent == 8 + assert s1[C_r].iters[2].range.extent == 32 + assert s1[C_r].iters[3].range.extent == 16 + assert s1[C].iters[0].range.extent == 8 + assert s1[C].iters[1].range.extent == 8 + assert s1[C].iters[2].range.extent == 32 + + s2 = s0.copy() + C_r = s2.rfactor(C, ki, 2) + """ + Placeholder: A, B + for i (0,8) + for j (0,8) + for k_i (0,16) + for k_o (0,32) + C.rf = ... + for ax0 (0,8) + for ax1 (0,8) + for k_i_v (0,16) + C.repl = ... + """ + assert s2[C_r].iters[0].range.extent == 8 + assert s2[C_r].iters[1].range.extent == 8 + assert s2[C_r].iters[2].range.extent == 16 + assert s2[C_r].iters[3].range.extent == 32 + assert s2[C].iters[0].range.extent == 8 + assert s2[C].iters[1].range.extent == 8 + assert s2[C].iters[2].range.extent == 16 + + if __name__ == "__main__": test_split_fuse_reorder_annotation() test_compute_at_root_inline() test_cache_read_write() + test_rfactor() diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 5f2f87ad9baa2..a963fa2b16add 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -37,8 +37,10 @@ def test_record(): k = te.reduce_axis((0, 512), name='k') E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E') F = topi.nn.relu(E) + k = te.reduce_axis((0, 512), name='k') + G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G') - dag = auto_scheduler.ComputeDAG([A, B, F]) + dag = auto_scheduler.ComputeDAG([A, B, G]) s = dag.get_init_state() # Split @@ -71,6 +73,13 @@ def test_record(): s.compute_at(D_global, E, s[E].iters[2]) # Cache Write s.cache_write(D, "shared") + # Pragma + s.pragma(C, s[C].iters[0], "auto_unroll_max_step$64") + # StorageAlign + s.storage_align(E, s[E].iters[-1], 8, 4) + # Rfactor + ko, _ = s.split(G, s[G].iters[2], [16]) + s.rfactor(G, ko, 2) target = tvm.target.create("llvm") task = auto_scheduler.SearchTask(dag, "test", target)