From d12465d5969b22302481e9fbc659df58c08fac89 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 21 Jul 2020 15:54:35 +0800 Subject: [PATCH 01/31] Add cache_read/cache_write step --- python/tvm/auto_scheduler/compute_dag.py | 9 +- python/tvm/auto_scheduler/loop_state.py | 79 +++++ src/auto_scheduler/compute_dag.cc | 39 ++- src/auto_scheduler/compute_dag.h | 9 + src/auto_scheduler/loop_state.cc | 58 ++++ src/auto_scheduler/loop_state.h | 59 +++- src/auto_scheduler/transform_step.cc | 311 +++++++++++++++++- src/auto_scheduler/transform_step.h | 171 +++++++++- .../test_auto_scheduler_loop_state.py | 276 ++++++++++++++++ .../unittest/test_auto_scheduler_measure.py | 7 +- 10 files changed, 974 insertions(+), 44 deletions(-) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index 115d28b4d478..7d8856a6b4e7 100644 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -126,11 +126,16 @@ def infer_bound_from_state(self, state): Returns ------- - state : State + updated_state : State The State with complete bound information. """ state_obj = state if isinstance(state, StateObject) else state.state_object - return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self) + updated_state = State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self) + # Copy the stage_id_map from the original state + if isinstance(state, State): + for k, v in state.stage_id_map.items(): + updated_state.stage_id_map[k] = v + return updated_state def __hash__(self): # TODO(merrymercy): Implement this more carefully and move this to c++ as a member function diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index ab041cf4a43d..fa50bfa0e1ec 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -351,6 +351,68 @@ def compute_root(self, stage): self.state_object = _ffi_api.StateComputeRoot(self.state_object, self._resolve_stage_id(stage)) + def cache_read(self, stage, scope_name, reader_stages): + """ Schedule primitive corresponds to te.schedule.cache_read. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be cache read, which can be specified by the integer index, Operation, + or output tensor of the stage. + scope_name : str + The scope name to be set for the new added read stage. + reader_stages : List[Union[int, Operation, Tensor]] + The reader stages. Each of the list can be specified by the integer index, Operation, + or output tensor of the stage. + + Returns + ------- + new_stage_op : Operator + The Operator of the new added stage. + + Notes + ----- + Cache read step will add an extra stage to the original ComputeDAG. + """ + if isinstance(reader_stages, list): + reader_stage_ids = [self._resolve_stage_id(id) for id in reader_stages] + else: + raise ValueError("reader_stages must be a list of the integer index, Operation, " + \ + "or output tensor of the stage") + + self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object, + self._resolve_stage_id(stage), + scope_name, reader_stage_ids, + self.compute_dag) + return self._insert_new_stage(int(new_stage_id)) + + def cache_write(self, stage, scope_name): + """ Schedule primitive corresponds to te.schedule.cache_write. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be cache write, which can be specified by the integer index, Operation, + or output tensor of the stage. 
+ scope_name : str + The scope name to be set for the new added write stage. + + Returns + ------- + new_stage_op : Operator + The Operator of the new added stage. + + Notes + ----- + Cache write step will add an extra stage to the original ComputeDAG, a up-to-date + ComputeDAG is stored in State's `current_compute_dag`. + This step will cache write all output tensors of the target stage. + """ + self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object, + self._resolve_stage_id(stage), + scope_name, self.compute_dag) + return self._insert_new_stage(int(new_stage_id)) + def copy(self): """ Do deep copy of this State. """ state = State(self.state_object, self.compute_dag) @@ -371,6 +433,23 @@ def _update_stage_id_map(self): for index, stage in enumerate(self.stages): self.stage_id_map[stage.op] = index + def _insert_new_stage(self, new_stage_id): + added_op = self.stages[new_stage_id].op + + # Add a new stage will change all ops. But we still want to use the old ops to index stages, + # So we keep updating them and do not remove the old ops. + + # Update stage_id_map for old ops, so we can still use the old ops to index stages. + for key, value in self.stage_id_map.items(): + if value >= new_stage_id: + self.stage_id_map[key] = value + 1 + self.stage_id_map[added_op] = new_stage_id + + # Update stage_id_map for new ops + self._update_stage_id_map() + + return added_op + def __getitem__(self, key): if isinstance(key, Tensor): key = key.op diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index d81dff66d402..0d964cb63513 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -221,24 +221,6 @@ ComputeDAG::ComputeDAG(Array tensors) { data_ = std::move(node); } -// Update the te::stage to tir::IterVar axis mapping -void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) { - if (auto pop = stage->op.as()) { - Array axes; - for (const auto& axis : pop->axis) { - axes.push_back(axis); - } - for (const auto& axis : pop->reduce_axis) { - axes.push_back(axis); - } - stage_to_axes->Set(stage, std::move(axes)); - } else if (stage->op->IsInstance()) { - {} // do nothing on Placeholder - } else { - LOG(FATAL) << "Invalid op " << stage->op; - } -} - std::pair> ComputeDAG::ApplySteps( const Array& transform_steps, Array* stages, StageToAxesMap* stage_to_axes) const { @@ -272,7 +254,7 @@ std::pair> ComputeDAG::ApplySteps( // Apply the history steps to TVM schedule // Call each step's ApplyToSchedule method for (const auto& step : transform_steps) { - StepApplyToSchedule(step, stages, stage_to_axes); + StepApplyToSchedule(step, stages, stage_to_axes, &schedule); } return std::make_pair(schedule, operator->()->tensors); @@ -316,7 +298,7 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const } // Call each step's PrintAsPythonAPI method for (const auto& step : transform_steps) { - ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes); + ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule); } return ss.str(); @@ -382,6 +364,23 @@ State ComputeDAG::InferBound(const State& state) const { return ret_state; } +ComputeDAG ComputeDAG::ReplayAndGetDAG(const Array& transform_steps) const { + te::Schedule sch; + Array old_tensors; + std::tie(sch, old_tensors) = ApplySteps(transform_steps); + + Array new_tensors; + for (auto stage : sch->stages) { + if (stage->op->IsInstance() || stage->is_output) { + for (auto i = 0; i < stage->op->num_outputs(); ++i) { + 
new_tensors.push_back(stage->op.output(i)); + } + } + } + + return ComputeDAG(new_tensors); +} + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); diff --git a/src/auto_scheduler/compute_dag.h b/src/auto_scheduler/compute_dag.h index 2417d72983b0..9b20cd36b992 100644 --- a/src/auto_scheduler/compute_dag.h +++ b/src/auto_scheduler/compute_dag.h @@ -114,6 +114,15 @@ class ComputeDAG : public ObjectRef { */ State InferBound(const State& state) const; + /*! + * \brief Some steps may change the structure of ComputeDAG(e.g. CacheRead/CacheWrite Step), this + * is to replay the transform steps and get the up-to-date ComputeDAG. + * \param steps The steps to be replaied. Usually we'll filter out the unused steps to speed up + * the replay process, for we only need to get the new ComputeDAG structure. + * \return The up-to-date ComputeDAG. + */ + ComputeDAG ReplayAndGetDAG(const Array& steps) const; + TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); }; diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index bfe547864ed1..22dab595e374 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -30,6 +30,7 @@ #include +#include "compute_dag.h" #include "transform_step.h" #include "utils.h" @@ -151,6 +152,36 @@ void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { } } +AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { + AttachMap map = AttachMap(make_object()); + auto pmap = map.CopyOnWrite(); + for (const auto& x : operator->()->stage_to_attach_iter) { + auto key = x.first; + if (key >= start_id) { + key += offset; + } + auto value = x.second; + if (value.first >= start_id) { + value.first += offset; + } + pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); + } + for (const auto& x : operator->()->iter_to_attached_stages) { + auto key = x.first; + if (key.first >= start_id) { + key.first += offset; + } + auto value = x.second; + for (auto& i : value) { + if (i >= start_id) { + i += offset; + } + } + pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); + } + return map; +} + /********** State **********/ State::State(const Array& ops) { auto node = make_object(); @@ -258,6 +289,19 @@ void State::compute_root(int stage_id) { step->ApplyToState(this); } +int State::cache_read(int stage_id, const String& scope_name, + const Array& reader_stage_ids, const ComputeDAG& dag) { + CacheReadStep step = CacheReadStep(stage_id, scope_name, reader_stage_ids); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this, dag); +} + +int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag) { + CacheWriteStep step = CacheWriteStep(stage_id, scope_name); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this, dag); +} + void State::ApplySteps(const ComputeDAG& dag) { CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages."; @@ -430,6 +474,20 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot") return state; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheRead") + .set_body_typed([](State state, int stage_id, const String& scope_name, + const Array& reader_stage_ids, const ComputeDAG& dag) { + int res = state.cache_read(stage_id, scope_name, reader_stage_ids, dag); + return Array{state, Integer(res)}; + }); + 
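+// Note: StateCacheRead (above) and StateCacheWrite (below) both return their result
+// packed as {updated_state, Integer(new_stage_id)}; the Python wrappers in
+// loop_state.py unpack this pair and convert new_stage_id back into the Operation of
+// the newly added stage.
+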
+TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite") + .set_body_typed([](State state, int stage_id, const String& scope_name, + const ComputeDAG& task_dag) { + int res = state.cache_write(stage_id, scope_name, task_dag); + return Array{state, Integer(res)}; + }); + TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) { return std::equal_to()(state1, state2); }); diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index 4d6477b92b0f..427baccbc788 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -181,11 +181,13 @@ class AttachMap : public ObjectRef { * \param target_iter_id The index of iterator in target stage that this step will compute at to. */ void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id); + /*! * \brief This is a public wrapper of `DeleteStageEntry`. To delete the entry of a specific stage. * \param stage_id The index of the stage to be compute at. */ void DeleteStage(int stage_id); + /*! * \brief Find the relations of original iterators in AttachMap, and update them with the new * iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated. @@ -195,6 +197,17 @@ class AttachMap : public ObjectRef { void UpdateIters(const std::vector& original_iters, const std::vector& new_iters); + /*! + * \brief Traverse through `stage_to_attach_iter` and `iter_to_attached_stages` map, add offset + * to stage indexes that are larger than the start_id. Used for steps that inserts net stages to + * ComputeDAG(e.g. CacheRead/CacheWrite step). + * \param start_id The index threshold, stage indexes in AttachMap which are larger than this + * will be applied the extra offset. + * \param offset The index offset to be added to the stage index. + * \return The updated AttachMap after applying stage index offset. + */ + AttachMap ApplyStageIdOfffset(int start_id, int offset) const; + TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode); @@ -225,6 +238,13 @@ class StateNode : public Object { * operation. */ AttachMap attach_map; + /*! + * \brief The up-to-date ComputeDAG of this state, used for some steps that may change the + * stage structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep which Will be added + * later). + * The default value is an empty ObjectRef. (means no modification to the original DAG) + */ + ObjectRef current_compute_dag; /*! * \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all * tile sizes of the state is filled. Only concrete state can be apply to TVM schedule. @@ -239,15 +259,6 @@ class StateNode : public Object { static constexpr const char* _type_key = "auto_scheduler.State"; TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object); - - private: - /*! - * \brief The up-to-date ComputeDAG of this state, used for some steps that may change the - * stage structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep which Will be added - * later). - * The default value is an empty ObjectRef. (means no modification to the original DAG) - */ - ObjectRef current_compute_dag; }; /*! @@ -347,7 +358,7 @@ class State : public ObjectRef { /*! * \brief Schedule primitive corresponds to te.compute_at. - * \param stage_id The index of the stage to be reordered. + * \param stage_id The index of the stage to be compute at. * \param target_stage_id The index of stage that this step will compute at to. 
* \param target_iter The iterator in target stage that this step will compute at to. * \note After compute_at, we need careful dependency analysis to compute the accurate bound @@ -358,12 +369,12 @@ class State : public ObjectRef { void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); /*! * \brief Schedule primitive corresponds to te.compute_inline. - * \param stage_id The index of the stage to be reordered. + * \param stage_id The index of the stage to be compute inlined. */ void compute_inline(int stage_id); /*! * \brief Schedule primitive corresponds to te.compute_root. - * \param stage_id The index of the stage to be reordered. + * \param stage_id The index of the stage to be compute root. * \note After compute_root, we need careful dependency analysis to compute the accurate bound * information. However, it is relatively expensive and complicated, so we just fill "None" as * bound for the newly created iterators. @@ -371,6 +382,30 @@ class State : public ObjectRef { */ void compute_root(int stage_id); + /********** Step APIs adding new stages **********/ + + /*! + * \brief Schedule primitive corresponds to te.schedule.cache_read. + * \param stage_id The index of the stage to be cache read. + * \param scope_name The scope name to be set for the new added read stage. + * \param reader_stage_ids The indexes of reader stages. + * \param dag The original ComputeDAG of this state. + * \note Cache read step will add an extra stage to the original ComputeDAG, a up-to-date + * ComputeDAG is stored in State's `current_compute_dag`. + */ + int cache_read(int stage_id, const String& scope_name, const Array& reader_stage_ids, + const ComputeDAG& dag); + /*! + * \brief Schedule primitive corresponds to te.schedule.cache_write. + * \param stage_id The index of the stage to be cache write. + * \param scope_name The scope name to be set for the new added write stage. + * \param dag The original ComputeDAG of this state. + * \note Cache write step will add an extra stage to the original ComputeDAG, a up-to-date + * ComputeDAG is stored in State's `current_compute_dag`. + * This step will cache write all output tensors of the target stage. 
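+ * For illustration (not part of this patch): if stage `stage_id` computes a tensor C,
+ * this inserts a new C.<scope_name> stage at index `stage_id`, shifts the original
+ * stage to `stage_id + 1`, and returns the index of the newly inserted stage.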
+ */ + int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag); + TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); }; diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 6c672a5215f2..bff611d8eed1 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -32,12 +32,31 @@ #include #include +#include "compute_dag.h" #include "loop_state.h" #include "utils.h" namespace tvm { namespace auto_scheduler { +// Update the te::stage to tir::IterVar axis mapping +void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) { + if (auto pop = stage->op.as()) { + Array axes; + for (const auto& axis : pop->axis) { + axes.push_back(axis); + } + for (const auto& axis : pop->reduce_axis) { + axes.push_back(axis); + } + stage_to_axes->Set(stage, std::move(axes)); + } else if (stage->op->IsInstance()) { + {} // do nothing on Placeholder + } else { + LOG(FATAL) << "Invalid op " << stage->op; + } +} + const char* IteratorAnnotationString[] = { "for", // kNone = 0 "unroll", // kUnroll = 1 @@ -73,6 +92,10 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { return ComputeInlineStep(reader); } else if (name == ComputeRootStepNode::record_prefix_str) { return ComputeRootStep(reader); + } else if (name == CacheReadStepNode::record_prefix_str) { + return CacheReadStep(reader); + } else if (name == CacheWriteStepNode::record_prefix_str) { + return CacheWriteStep(reader); } else { LOG(FATAL) << "Invalid step format: " << name; } @@ -94,13 +117,17 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state, dag); + } else if (auto ps = step.as()) { + ps->ApplyToState(state, dag); } else { LOG(FATAL) << "Invalid step: " << step; } } -void StepApplyToSchedule(const Step& step, Array* stages, - StageToAxesMap* stage_to_axes) { +void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) { if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -115,13 +142,17 @@ void StepApplyToSchedule(const Step& step, Array* stages, ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, schedule); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, schedule); } else { LOG(FATAL) << "Invalid Step: " << step; } } String StepPrintAsPythonAPI(const Step& step, Array* stages, - StageToAxesMap* stage_to_axes) { + StageToAxesMap* stage_to_axes, te::Schedule* schedule) { if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -136,6 +167,10 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); } else { LOG(FATAL) << "Invalid Step: " << step; } @@ -923,5 +958,275 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array* 
stages, return ss.str(); } +/********** Primitives adding new stages **********/ + +// Common part for steps that add new stages +// (e.g. CacheReadStep, CacheWriteStep, RfactorStep) +void AddStageModificationSteps(int step_id, const Array& transform_steps, + Array* replay_steps) { + const Step& step = transform_steps[step_id]; + if (step->IsInstance() || step->IsInstance()) { + replay_steps->push_back(step); + } + // TODO(jcf94): add rfactor support +} + +/********** Cache Read **********/ +CacheReadStep::CacheReadStep(int stage_id, String scope_name, + const Array& reader_stage_ids) { + auto node = make_object(); + node->stage_id = stage_id; + node->scope_name = std::move(scope_name); + node->reader_stage_ids = reader_stage_ids; + data_ = std::move(node); +} + +CacheReadStep::CacheReadStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + std::string string_value; + reader->Read(&string_value); + node->scope_name = std::move(string_value); + s = reader->NextArrayItem(); + CHECK(s); + std::vector int_list; + reader->Read(&int_list); + Array reader_stage_ids; + for (int i : int_list) { + reader_stage_ids.push_back(i); + } + node->reader_stage_ids = std::move(reader_stage_ids); + data_ = std::move(node); +} + +void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArraySeperator(); + writer->WriteString(scope_name); + writer->WriteArrayItem(IntArrayToVector(reader_stage_ids)); +} + +int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { + StateNode* pstate = state->CopyOnWrite(); + Array replay_steps; + for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { + AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); + if (pstate->transform_steps[i].same_as(GetRef(this))) { + break; + } + } + const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps); + + // target -> target + target_store + // Should update target's op, insert new stage, update the later stage's op + int added_stage_id = stage_id + 1; + Stage tmp_stage = pstate->stages[stage_id]; + tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[stage_id]; + pstate->stages.Set(stage_id, std::move(tmp_stage)); + pstate->stages.insert(pstate->stages.begin() + added_stage_id, + Stage(current_compute_dag->ops[added_stage_id])); + for (size_t i = added_stage_id + 1; i < pstate->stages.size(); ++i) { + tmp_stage = pstate->stages[i]; + tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i]; + pstate->stages.Set(i, std::move(tmp_stage)); + } + pstate->attach_map = pstate->attach_map.ApplyStageIdOfffset(added_stage_id, 1); + pstate->current_compute_dag = std::move(current_compute_dag); + + return added_stage_id; +} + +te::Tensor CacheReadStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const { + const te::Stage& stage = (*stages)[stage_id]; + + Array readers; + for (const auto& i : reader_stage_ids) { + readers.push_back((*stages)[i]->origin_op); + } + auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers); + + const auto& new_stage = (*schedule)[out->op]; + UpdateStageToAxesMap(new_stage, stage_to_axes); + stages->insert(stages->begin() + stage_id + 1, new_stage); + + return out; +} + +String 
CacheReadStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const { + std::stringstream ss; + // Copy stage here, for the original stage will change after apply + auto stage = (*stages)[stage_id]; + std::vector reader_stages; + for (size_t i = 0; i < reader_stage_ids.size(); ++i) { + reader_stages.push_back((*stages)[reader_stage_ids[i]]); + } + + auto out = ApplyToSchedule(stages, stage_to_axes, schedule); + + ss << CleanName(out->op->name) << " = " + << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name << "\", [" + << CleanName(reader_stages[0]->op->name); + for (size_t i = 1; i < reader_stage_ids.size(); ++i) { + ss << ", " << CleanName(reader_stages[i]->op->name); + } + ss << "])\n"; + + const auto& iters = out->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " + << "tuple(" << CleanName(out->op->name) << ".op.axis)\n"; + + return ss.str(); +} + +/********** Cache Write **********/ +CacheWriteStep::CacheWriteStep(int stage_id, String scope_name) { + auto node = make_object(); + node->stage_id = stage_id; + node->scope_name = std::move(scope_name); + data_ = std::move(node); +} + +CacheWriteStep::CacheWriteStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + std::string string_value; + reader->Read(&string_value); + node->scope_name = std::move(string_value); + data_ = std::move(node); +} + +void CacheWriteStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArraySeperator(); + writer->WriteString(scope_name); +} + +int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { + StateNode* pstate = state->CopyOnWrite(); + Array replay_steps; + for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { + AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); + if (pstate->transform_steps[i].same_as(GetRef(this))) { + break; + } + } + int last_dag_op_size = pstate->current_compute_dag.defined() + ? 
pstate->current_compute_dag.as()->ops.size() + : dag->ops.size(); + const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps); + int added_ops = current_compute_dag->ops.size() - last_dag_op_size; + CHECK_GE(added_ops, 1); + + // target -> target_compute + target + // Assume target stage has never been applied any steps before cache_write + // Should insert new stage, update target stage, update the later stage's op + pstate->stages.insert(pstate->stages.begin() + stage_id, + Stage(current_compute_dag->ops[stage_id])); + pstate->stages.Set(stage_id + 1, Stage(current_compute_dag->ops[stage_id + 1])); + int next_stage_id = stage_id + 2; + // Notice: added_ops should actually assert to be 1 + // branch of 2 here is somehow a hack to TVM's cache_write bug with multi outputs + // see `tests/python/unittest/test_auto_scheduler_loop_state.py::test_cache_read_write` test for + // more information + // TODO(jcf94): Fix the cache write bug in TVM and remove these branches here + if (added_ops == 2) { + pstate->stages.insert(pstate->stages.begin() + next_stage_id, + Stage(current_compute_dag->ops[next_stage_id])); + next_stage_id++; + } else if (added_ops > 2) { + LOG(ERROR) << "Unexpected behavior of CacheWrite."; + } + for (size_t i = next_stage_id; i < current_compute_dag->ops.size(); ++i) { + Stage tmp_stage = pstate->stages[i]; + tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i]; + pstate->stages.Set(i, std::move(tmp_stage)); + } + pstate->attach_map = pstate->attach_map.ApplyStageIdOfffset(stage_id, added_ops); + pstate->current_compute_dag = std::move(current_compute_dag); + + return stage_id; +} + +Array CacheWriteStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const { + const te::Stage& stage = (*stages)[stage_id]; + + Array tensor_array; + // If the target stage has multi outputs, TVM requires to cache_write + // all of them or schedule.cache_write will raise an error + for (auto i = 0; i < stage->op->num_outputs(); ++i) { + tensor_array.push_back(stage->origin_op.output(i)); + } + auto outs = schedule->cache_write(tensor_array, scope_name); + + UpdateStageToAxesMap(stage, stage_to_axes); + // Even if there is multi outputs, TVM schedule only generate one + // new stage + const auto& new_stage = (*schedule)[outs[0]->op]; + UpdateStageToAxesMap(new_stage, stage_to_axes); + stages->insert(stages->begin() + stage_id, new_stage); + + return outs; +} + +String CacheWriteStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const { + std::stringstream ss; + // Copy stage here, for the original stage will change after apply + te::Stage stage = (*stages)[stage_id]; + + auto outs = ApplyToSchedule(stages, stage_to_axes, schedule); + + for (size_t i = 0; i < outs.size(); ++i) { + ss << CleanName(outs[i]->op->name) << ", "; + } + ss << "= " + << "s.cache_write([" << CleanName(stage->op.output(0)->op->name); + for (auto i = 1; i < stage->op->num_outputs(); ++i) { + ss << ", " << CleanName(stage->op.output(i)->op->name); + } + ss << "], \"" << scope_name << "\")\n"; + + for (const auto& out : outs) { + const auto& iters = out->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " + << "tuple(" << CleanName(out->op->name) << ".op.axis)" + << " + " + << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n"; + } + + return ss.str(); +} + } // namespace 
auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index ce3ca50ffae6..e1746189c29e 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -58,6 +58,13 @@ namespace auto_scheduler { typedef Map, ObjectHash, ObjectEqual> StageToAxesMap; +/*! + * \brief Update the current stage IterVar information to StageToAxesMap. + * \param stage A te::Stage Object. + * \param stage_to_axes A mutable pointer to StageToAxesMap, this map will be updated. + */ +void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes); + /*! \brief The type of an iterator. */ enum class IteratorKind : int { /*! \brief Spatial iterator. */ @@ -194,20 +201,25 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag); /*! * \brief Apply the step to tvm.schedule. * \param step The step to be applied to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages A mutable pointer to a `te::Stage` Array. + * \param stage_to_axes A mutable pointer to a StageToAxesMap. + * \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g. + * CacheRead/CacheWrite step) */ -void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes); +void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule); /*! * \brief Print the step as equivalent python schedule API. * \param step The step to be applied to python API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages A mutable pointer to a `te::Stage` Array. + * \param stage_to_axes A mutable pointer to a StageToAxesMap. + * \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g. + * CacheRead/CacheWrite step) * \return Python schedule code. */ String StepPrintAsPythonAPI(const Step& step, Array* stages, - StageToAxesMap* stage_to_axes); + StageToAxesMap* stage_to_axes, te::Schedule* schedule); /********** Primitives working on single stage **********/ @@ -659,6 +671,153 @@ class ComputeRootStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode); }; +/********** Primitives adding new stages **********/ + +/*! + * \brief Cache read step that corresponds to te::Schedule::cache_read. + * \note Cache read step will add an extra stage to the original ComputeDAG, a up-to-date ComputeDAG + * is stored in State's `current_compute_dag`. + */ +class CacheReadStepNode : public StepNode { + public: + /*! \brief The scope name to be set for the new added read stage. (e.g. local, shared, global) */ + String scope_name; + /*! \brief The indexes of reader stages. */ + Array reader_stage_ids; + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to State. + * \param dag The original ComputeDAG of this state. + * \return The index of the new added stage. + */ + int ApplyToState(State* state, const ComputeDAG& dag) const; + + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages A mutable pointer to a `te::Stage` Array. + * \param stage_to_axes A mutable pointer to a StageToAxesMap. + * \param schedule A mutable pointer to a te::Schedule. + * \return The output Tensor of the new added stage. 
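+ * \note Besides returning the cached tensor, this also inserts the newly created
+ * te::Stage into `stages` right after `stage_id`, keeping the stage order consistent
+ * with what ApplyToState produces on the State side.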
+ */ + te::Tensor ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const; + + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages A mutable pointer to a `te::Stage` Array. + * \param stage_to_axes A mutable pointer to a StageToAxesMap. + * \param schedule A mutable pointer to a te::Schedule. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const; + + static constexpr const char* record_prefix_str = "CHR"; + + static constexpr const char* _type_key = "auto_scheduler.CacheReadStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object); +}; + +/*! + * \brief Managed reference to CacheReadStepNode. + * \sa CacheReadStepNode + */ +class CacheReadStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be cache read. + * \param scope_name The scope name to be set for the new added read stage. + * \param reader_stage_ids The indexes of reader stages. + */ + CacheReadStep(int stage_id, String scope_name, const Array& reader_stage_ids); + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit CacheReadStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode); +}; + +/*! + * \brief Cache write step that corresponds to te::Schedule::cache_write. + * \note Cache write step will add an extra stage to the original ComputeDAG, a up-to-date + * ComputeDAG is stored in State's `current_compute_dag`. + * This step will cache write all output tensors of the target stage. + */ +class CacheWriteStepNode : public StepNode { + public: + /*! + * \brief The scope name to be set for the new added write stage. (e.g. local, shared, + * global) + */ + String scope_name; + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to State. + * \param dag The original ComputeDAG of this state. + * \return The index of the new added stage. + */ + int ApplyToState(State* state, const ComputeDAG& dag) const; + + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages A mutable pointer to a `te::Stage` Array. + * \param stage_to_axes A mutable pointer to a StageToAxesMap. + * \param schedule A mutable pointer to a te::Schedule. + * \return The output Tensors of the new added stage. + */ + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const; + + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages A mutable pointer to a `te::Stage` Array. + * \param stage_to_axes A mutable pointer to a StageToAxesMap. + * \param schedule A mutable pointer to a te::Schedule. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const; + + static constexpr const char* record_prefix_str = "CHW"; + + static constexpr const char* _type_key = "auto_scheduler.CacheWriteStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object); +}; + +/*! + * \brief Managed reference to CacheWriteStepNode. + * \sa CacheWriteStepNode + */ +class CacheWriteStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be cache write. 
+ * \param scope_name The scope name to be set for the new added write stage. + */ + CacheWriteStep(int stage_id, String scope_name); + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit CacheWriteStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode); +}; + } // namespace auto_scheduler } // namespace tvm diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index 32ea8faa84d0..8c9d635b526c 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -143,6 +143,282 @@ def test_compute_at_root_inline(): assert s0[conv].iters[6].range.extent == 7 +def test_cache_read_write(): + N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( + 1, 1), (1, 1) + + data = te.placeholder((N, CI, H, W), name='Data') + kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data') + k0, k1 = te.compute(kernel_data.shape, + lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2), + name='Kernel_split') + kernel = te.compute(kernel_data.shape, + lambda *i: k0(*i) + k1(*i), + name='Kernel') + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1) + relu = topi.nn.relu(conv) + add = topi.add(data, relu) + + dag = auto_scheduler.ComputeDAG([data, kernel_data, add]) + s0 = dag.get_init_state() + + pad_temp = s0.stage_ops[1] + kernel_split = s0.stage_ops[3] + + # 0: init state + ori_its = s0[add].iters + its = s0.split(add, s0[add].iters[0], [2]) + s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]]) + s0.compute_inline(relu) + + # 1: simple cache_write with compute_at + conv_global = s0.cache_write(conv, "global") + s0.compute_at(conv_global, conv, s0[conv].iters[3]) + + # 2: simple cache_read with compute_at + kernel_global = s0.cache_read(kernel, "global", [conv_global]) + s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4]) + """ + Placeholder: Data, Kernel_data + for i0 (0,4) + for i1 (0,512) + for i2 (0,9) + for i3 (0,9) + pad_temp = ... + for i0 (0,512) + for i1 (0,512) + for i2 (0,3) + for i3 (0,3) + Kernel_split = ... + for i0 (0,512) + for i1 (0,512) + for i2 (0,3) + for i3 (0,3) + Kernel = ... + for nn (0,4) + for ff (0,512) + for yy (0,7) + for xx (0,7) + for nn_c (None) + for ff_c (None) + for yy_c (None) + for xx_c (None) + for rc (None) + for ax0 (None) + for ax1 (None) + for ax2 (None) + for ax3 (None) + Kernel.global = ... + for ry (None) + for rx (None) + compute.global = ... + compute = ... + for ax0.0 (0,2) + for ax1 (0,512) + for ax0.1 (0,2) + for ax2 (0,7) + for ax3 (0,7) + T_add = ... 
+ """ + s1 = dag.infer_bound_from_state(s0) + assert s1[conv].iters[0].range.extent == 4 + assert s1[conv].iters[1].range.extent == 512 + assert s1[conv].iters[2].range.extent == 7 + assert s1[conv].iters[3].range.extent == 7 + assert s1[kernel_global].iters[0].range.extent == 1 + assert s1[kernel_global].iters[1].range.extent == 1 + assert s1[kernel_global].iters[2].range.extent == 3 + assert s1[kernel_global].iters[3].range.extent == 3 + assert s1[conv_global].iters[0].range.extent == 1 + assert s1[conv_global].iters[1].range.extent == 1 + assert s1[conv_global].iters[2].range.extent == 1 + assert s1[conv_global].iters[3].range.extent == 1 + assert s1[conv_global].iters[4].range.extent == 512 + assert s1[conv_global].iters[5].range.extent == 3 + assert s1[conv_global].iters[6].range.extent == 3 + + # 3: two level cache_read with compute_at + # preparing for GPU's shared memory & local memory + pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global]) + pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global]) + s0.compute_at(pad_temp_global, conv_global, s0[conv_global].iters[2]) + s0.compute_at(pad_temp_shared, conv_global, s0[conv_global].iters[4]) + + # 4: cache_read with multi readers + # This stage cannot be compute at to its consumer + s0.cache_read(data, "global", [pad_temp, add]) + """ + Placeholder: Data, Kernel_data + for ax0 (0,4) + for ax1 (0,512) + for ax2 (0,7) + for ax3 (0,7) + Data.global = ... + for i0 (0,4) + for i1 (0,512) + for i2 (0,9) + for i3 (0,9) + pad_temp = ... + for i0 (0,512) + for i1 (0,512) + for i2 (0,3) + for i3 (0,3) + Kernel_split = ... + for i0 (0,512) + for i1 (0,512) + for i2 (0,3) + for i3 (0,3) + Kernel = ... + for nn (0,4) + for ff (0,512) + for yy (0,7) + for xx (0,7) + for nn_c (None) + for ff_c (None) + for yy_c (None) + for ax0 (None) + for ax1 (None) + for ax2 (None) + for ax3 (None) + pad_temp.global = ... + for xx_c (None) + for rc (None) + for ax0 (None) + for ax1 (None) + for ax2 (None) + for ax3 (None) + Kernel.global = ... + for ax0 (None) + for ax1 (None) + for ax2 (None) + for ax3 (None) + pad_temp.global.shared = ... + for ry (None) + for rx (None) + compute.global = ... + compute = ... + for ax0.0 (0,2) + for ax1 (0,512) + for ax0.1 (0,2) + for ax2 (0,7) + for ax3 (0,7) + T_add = ... 
+ """ + s1 = dag.infer_bound_from_state(s0) + assert s1[conv].iters[0].range.extent == 4 + assert s1[conv].iters[1].range.extent == 512 + assert s1[conv].iters[2].range.extent == 7 + assert s1[conv].iters[3].range.extent == 7 + assert s1[kernel_global].iters[0].range.extent == 1 + assert s1[kernel_global].iters[1].range.extent == 1 + assert s1[kernel_global].iters[2].range.extent == 3 + assert s1[kernel_global].iters[3].range.extent == 3 + assert s1[conv_global].iters[0].range.extent == 1 + assert s1[conv_global].iters[1].range.extent == 1 + assert s1[conv_global].iters[2].range.extent == 1 + assert s1[conv_global].iters[3].range.extent == 1 + assert s1[conv_global].iters[4].range.extent == 512 + assert s1[conv_global].iters[5].range.extent == 3 + assert s1[conv_global].iters[6].range.extent == 3 + assert s1[pad_temp_global].iters[0].range.extent == 1 + assert s1[pad_temp_global].iters[1].range.extent == 512 + assert s1[pad_temp_global].iters[2].range.extent == 3 + assert s1[pad_temp_global].iters[3].range.extent == 3 + assert s1[pad_temp_shared].iters[0].range.extent == 1 + assert s1[pad_temp_shared].iters[1].range.extent == 1 + assert s1[pad_temp_shared].iters[2].range.extent == 3 + assert s1[pad_temp_shared].iters[3].range.extent == 3 + + # 5: cache_write with multi outputs + # TVM's cache_write actually has a bug with this case: + # + # After schedule.cache_write, TVM generate one new stage: + # From: kernel_data -> kernel_split -> kernel + # To: kernel_data -> kernel_split_global -> kernel_split -> kernel + # + # But with topo sort analyse, we get: + # // kernel_data -> kernel_split_global -> kernel_split -> kernel + # \ / + # ----------------> kernel_split ----------------> + # + # Seems there's bug with the input/output tensor. Such multi outputs case + # should be unusual, so we make some hack on DoCacheWrite + # To be fixed in the future + kernel_split_global = s0.cache_write(kernel_split, "global") + """ + Placeholder: Data, Kernel_data + for ax0 (0,4) + for ax1 (0,512) + for ax2 (0,7) + for ax3 (0,7) + Data.global = ... + for i0 (0,4) + for i1 (0,512) + for i2 (0,9) + for i3 (0,9) + pad_temp = ... + for i0_c (0,512) + for i1_c (0,512) + for i2_c (0,3) + for i3_c (0,3) + Kernel_split.global = ... + for i0 (0,512) + for i1 (0,512) + for i2 (0,3) + for i3 (0,3) + Kernel_split = ... + (******* Bug here, there should not be two kernel_split stage *******) + for i0 (0,512) + for i1 (0,512) + for i2 (0,3) + for i3 (0,3) + Kernel_split = ... + (******* Bug here, there should not be two kernel_split stage *******) + for i0 (0,512) + for i1 (0,512) + for i2 (0,3) + for i3 (0,3) + Kernel = ... + for nn (0,4) + for ff (0,512) + for yy (0,7) + for xx (0,7) + for nn_c (None) + for ff_c (None) + for yy_c (None) + for ax0 (None) + for ax1 (None) + for ax2 (None) + for ax3 (None) + pad_temp.global = ... + for xx_c (None) + for rc (None) + for ax0 (None) + for ax1 (None) + for ax2 (None) + for ax3 (None) + Kernel.global = ... + for ax0 (None) + for ax1 (None) + for ax2 (None) + for ax3 (None) + pad_temp.global.shared = ... + for ry (None) + for rx (None) + compute.global = ... + compute = ... + for ax0.0 (0,2) + for ax1 (0,512) + for ax0.1 (0,2) + for ax2 (0,7) + for ax3 (0,7) + T_add = ... 
+ """ + assert len(s0[kernel_split].iters) == len(s0[kernel_split_global].iters) + for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters): + assert it0.range == it1.range + if __name__ == "__main__": test_split_fuse_reorder_annotation() test_compute_at_root_inline() + test_cache_read_write() diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 333d20e4ce9a..5f2f87ad9baa 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -35,7 +35,7 @@ def test_record(): C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C') D = topi.nn.relu(C) k = te.reduce_axis((0, 512), name='k') - E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='C') + E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E') F = topi.nn.relu(E) dag = auto_scheduler.ComputeDAG([A, B, F]) @@ -66,6 +66,11 @@ def test_record(): s.unroll(C, s[C].iters[4]) # Vectorize s.vectorize(C, s[C].iters[6]) + # Cache Read + D_global = s.cache_read(D, "global", [E]) + s.compute_at(D_global, E, s[E].iters[2]) + # Cache Write + s.cache_write(D, "shared") target = tvm.target.create("llvm") task = auto_scheduler.SearchTask(dag, "test", target) From 920f4b1d1d7a664d88462f365ecaaee74f693ac1 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 22 Jul 2020 14:41:30 +0800 Subject: [PATCH 02/31] Update --- python/tvm/auto_scheduler/loop_state.py | 44 ++++++++++--------------- src/auto_scheduler/compute_dag.h | 5 +-- src/auto_scheduler/loop_state.cc | 2 +- src/auto_scheduler/loop_state.h | 2 +- src/auto_scheduler/transform_step.cc | 4 +-- 5 files changed, 25 insertions(+), 32 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index fa50bfa0e1ec..0311ee39c37d 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -372,19 +372,19 @@ def cache_read(self, stage, scope_name, reader_stages): Notes ----- - Cache read step will add an extra stage to the original ComputeDAG. + Cache read step will insert an extra stage to the original ComputeDAG (at the back of the + target stage). """ - if isinstance(reader_stages, list): - reader_stage_ids = [self._resolve_stage_id(id) for id in reader_stages] - else: - raise ValueError("reader_stages must be a list of the integer index, Operation, " + \ - "or output tensor of the stage") - + reader_stage_ids = [self._resolve_stage_id(i) for i in reader_stages] self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object, self._resolve_stage_id(stage), scope_name, reader_stage_ids, self.compute_dag) - return self._insert_new_stage(int(new_stage_id)) + # Add a new stage will change all ops behind the added stage. But we still want to keep the + # original ops map, apply stage id offset to stage_id_map to make them work. + self._apply_stage_id_offset(int(new_stage_id)) + self._update_stage_id_map() + return self.stages[int(new_stage_id)].op def cache_write(self, stage, scope_name): """ Schedule primitive corresponds to te.schedule.cache_write. @@ -404,14 +404,18 @@ def cache_write(self, stage, scope_name): Notes ----- - Cache write step will add an extra stage to the original ComputeDAG, a up-to-date - ComputeDAG is stored in State's `current_compute_dag`. 
+ Cache write step will insert an extra stage to the original ComputeDAG (in the front of the + target stage). This step will cache write all output tensors of the target stage. """ self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object, self._resolve_stage_id(stage), scope_name, self.compute_dag) - return self._insert_new_stage(int(new_stage_id)) + # Add a new stage will change all ops behind the added stage. But we still want to keep the + # original ops map, apply stage id offset to stage_id_map to make them work. + self._apply_stage_id_offset(int(new_stage_id)) + self._update_stage_id_map() + return self.stages[int(new_stage_id)].op def copy(self): """ Do deep copy of this State. """ @@ -433,22 +437,10 @@ def _update_stage_id_map(self): for index, stage in enumerate(self.stages): self.stage_id_map[stage.op] = index - def _insert_new_stage(self, new_stage_id): - added_op = self.stages[new_stage_id].op - - # Add a new stage will change all ops. But we still want to use the old ops to index stages, - # So we keep updating them and do not remove the old ops. - - # Update stage_id_map for old ops, so we can still use the old ops to index stages. + def _apply_stage_id_offset(self, start_id, offset=1): for key, value in self.stage_id_map.items(): - if value >= new_stage_id: - self.stage_id_map[key] = value + 1 - self.stage_id_map[added_op] = new_stage_id - - # Update stage_id_map for new ops - self._update_stage_id_map() - - return added_op + if value >= start_id: + self.stage_id_map[key] = value + offset def __getitem__(self, key): if isinstance(key, Tensor): diff --git a/src/auto_scheduler/compute_dag.h b/src/auto_scheduler/compute_dag.h index 9b20cd36b992..3f4ea6f269d7 100644 --- a/src/auto_scheduler/compute_dag.h +++ b/src/auto_scheduler/compute_dag.h @@ -115,8 +115,9 @@ class ComputeDAG : public ObjectRef { State InferBound(const State& state) const; /*! - * \brief Some steps may change the structure of ComputeDAG(e.g. CacheRead/CacheWrite Step), this - * is to replay the transform steps and get the up-to-date ComputeDAG. + * \brief Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial + * ComputeDAG may not be up-to-date. This function replays the given transform steps from the + * initial state and return an up-to-date ComputeDAG. * \param steps The steps to be replaied. Usually we'll filter out the unused steps to speed up * the replay process, for we only need to get the new ComputeDAG structure. * \return The up-to-date ComputeDAG. diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 22dab595e374..18cc6c2537f3 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -152,7 +152,7 @@ void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { } } -AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { +AttachMap AttachMap::ApplyStageIdOffset(int start_id, int offset) const { AttachMap map = AttachMap(make_object()); auto pmap = map.CopyOnWrite(); for (const auto& x : operator->()->stage_to_attach_iter) { diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index 427baccbc788..fb07e5a0a32e 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -206,7 +206,7 @@ class AttachMap : public ObjectRef { * \param offset The index offset to be added to the stage index. * \return The updated AttachMap after applying stage index offset. 
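+ * For example (illustrative): with the default offset of 1, ApplyStageIdOffset(3) remaps
+ * stage index 3 to 4, 4 to 5, and so on, while stage indexes below 3 are left untouched.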
*/ - AttachMap ApplyStageIdOfffset(int start_id, int offset) const; + AttachMap ApplyStageIdOffset(int start_id, int offset = 1) const; TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode); diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index bff611d8eed1..e6450125772b 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -1037,7 +1037,7 @@ int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i]; pstate->stages.Set(i, std::move(tmp_stage)); } - pstate->attach_map = pstate->attach_map.ApplyStageIdOfffset(added_stage_id, 1); + pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(added_stage_id); pstate->current_compute_dag = std::move(current_compute_dag); return added_stage_id; @@ -1164,7 +1164,7 @@ int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) const tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i]; pstate->stages.Set(i, std::move(tmp_stage)); } - pstate->attach_map = pstate->attach_map.ApplyStageIdOfffset(stage_id, added_ops); + pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(stage_id, added_ops); pstate->current_compute_dag = std::move(current_compute_dag); return stage_id; From 90e63919ffec9a026a31c06d1426243bb8427e5d Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Wed, 22 Jul 2020 16:12:23 +0800 Subject: [PATCH 03/31] Add follow split and follow fused split Signed-off-by: jingbang.yjb Conflicts: src/auto_scheduler/compute_dag.cc src/auto_scheduler/transform_step.cc src/auto_scheduler/transform_step.h tests/python/unittest/test_auto_scheduler_loop_state.py --- python/tvm/auto_scheduler/loop_state.py | 60 ++++++ src/auto_scheduler/compute_dag.cc | 4 +- src/auto_scheduler/loop_state.cc | 42 ++++ src/auto_scheduler/loop_state.h | 4 +- src/auto_scheduler/transform_step.cc | 202 +++++++++++++++++- src/auto_scheduler/transform_step.h | 90 +++++++- .../test_auto_scheduler_loop_state.py | 39 ++++ .../unittest/test_auto_scheduler_measure.py | 11 +- 8 files changed, 441 insertions(+), 11 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 0311ee39c37d..163a2e342197 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -116,6 +116,15 @@ def stages(self): stages : List[Stage] """ return self.state_object.stages + + @property + def transform_steps(self): + """ + Returns + ------- + transform_steps : List[transform_steps] + """ + return self.state_object.transform_steps @property def stage_ops(self): @@ -293,6 +302,57 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): self._resolve_stage_id(stage), iterator, lengths, inner_to_outer) return res + + def follow_split(self, stage, iterator, src_step_id, n_split): + """ + Parameters + ---------- + iterator : Iterator + The iterator to split + src_step_id : int + The index of the split step to follow in the history + n_split : int + The number of split level + + Returns + ------- + res_its : List[Iterator] + The splitted new Iterators + """ + + self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, + self._resolve_stage_id(stage), + iterator, + src_step_id, n_split) + return res + + def follow_fused_split(self, stage, iterator, src_step_ids, level, + factor_or_nparts): + """ + Parameters + ---------- + iterator : Iterator + The iterator 
to split + src_step_ids : List[int] + The indices of the split steps to follow in the history + level : int + Use the length in this split level + factor_or_nparts : bool + True to use `factor` for split from inner to outer, + False to use `nparts` for split from outer to inner + + Returns + ------- + res_its : List[Iterator] + The splitted new Iterators + """ + + self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, + self._resolve_stage_id(stage), + iterator, + src_step_ids, level, + factor_or_nparts) + return res def compute_at(self, stage, target_stage, target_iter): """ Schedule primitive corresponds to te.compute_at. diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 0d964cb63513..7b987be49e4e 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -254,7 +254,7 @@ std::pair> ComputeDAG::ApplySteps( // Apply the history steps to TVM schedule // Call each step's ApplyToSchedule method for (const auto& step : transform_steps) { - StepApplyToSchedule(step, stages, stage_to_axes, &schedule); + StepApplyToSchedule(step, stages, stage_to_axes, &schedule, transform_steps); } return std::make_pair(schedule, operator->()->tensors); @@ -298,7 +298,7 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const } // Call each step's PrintAsPythonAPI method for (const auto& step : transform_steps) { - ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule); + ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule, transform_steps); } return ss.str(); diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 18cc6c2537f3..8db726d17f95 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -239,6 +239,28 @@ Iterator State::vectorize(int stage_id, const Iterator& it) { return step->ApplyToState(this); } +Array State::follow_split(int stage_id, const Iterator& it, + int src_step_id, int n_split) { + const Stage& stage = operator->()->stages[stage_id]; + + FollowSplitStep step = FollowSplitStep( + stage_id, GetIndex(stage->iters, it), src_step_id, n_split); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + +Array State::follow_fused_split( + int stage_id, const Iterator& it, const Array& src_step_ids, + int level, bool factor_or_nparts) { + const Stage& stage = operator->()->stages[stage_id]; + + FollowFusedSplitStep step = + FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), + src_step_ids, level, factor_or_nparts); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + Iterator State::fuse(int stage_id, const Array& iters) { const Stage& stage = operator->()->stages[stage_id]; Array indices; @@ -436,6 +458,26 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize") return Array{state, res}; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit") +.set_body_typed([](State state, int stage_id, const Iterator& it, + int src_step_id, int n_split) { + const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); + return Array{state, Array(res)}; +}); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit") +.set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts) { + Array array_src_step_ids; + for (const auto& i : src_step_ids) { + array_src_step_ids.push_back(i->value); + } + const auto& res = state.follow_fused_split( + stage_id, 
it, array_src_step_ids, level, factor_or_nparts); + return Array{state, Array(res)}; +}); + TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse") .set_body_typed([](State state, int stage_id, const Array& iters) { const auto& res = state.fuse(stage_id, iters); diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index fb07e5a0a32e..d8ea9d467753 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -353,8 +353,10 @@ class State : public ObjectRef { */ Array split(int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer = true); + + Array follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split); - /********** Step APIs working on multiple stages **********/ + Array follow_fused_split(int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts); /*! * \brief Schedule primitive corresponds to te.compute_at. diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index e6450125772b..d8fbbb221dab 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -96,7 +96,12 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { return CacheReadStep(reader); } else if (name == CacheWriteStepNode::record_prefix_str) { return CacheWriteStep(reader); - } else { + } else if (name == FollowSplitStepNode::record_prefix_str) { + return FollowSplitStep(reader); + } else if (name == FollowFusedSplitStepNode::record_prefix_str) { + return FollowFusedSplitStep(reader); + } + else { LOG(FATAL) << "Invalid step format: " << name; } return Step(); @@ -121,13 +126,17 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state, dag); } else if (auto ps = step.as()) { ps->ApplyToState(state, dag); + } else if(auto ps = step.as()){ + ps->ApplyToState(state); + } else if(auto ps = step.as()){ + ps->ApplyToState(state); } else { LOG(FATAL) << "Invalid step: " << step; } } void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule) { + te::Schedule* schedule, const Array& transform_steps) { if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -146,13 +155,17 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes ps->ApplyToSchedule(stages, stage_to_axes, schedule); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes, schedule); + } else if(auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); + } else if(auto ps = step.as()){ + ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); } else { LOG(FATAL) << "Invalid Step: " << step; } } String StepPrintAsPythonAPI(const Step& step, Array* stages, - StageToAxesMap* stage_to_axes, te::Schedule* schedule) { + StageToAxesMap* stage_to_axes, te::Schedule* schedule, const Array& transform_steps) { if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -171,6 +184,10 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); + } else if(auto ps = step.as()){ + return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); + } else if(auto ps = step.as()){ + return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); } else { 
LOG(FATAL) << "Invalid Step: " << step; } @@ -776,6 +793,185 @@ String SplitStepNode::PrintAsPythonAPI(Array* stages, return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } +/********** Follow Split **********/ +FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, + int src_step_id, int n_split) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->src_step_id = src_step_id; + node->n_split = n_split; + data_ = std::move(node); +} + +void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(iter_id); + writer->WriteArrayItem(src_step_id); + writer->WriteArrayItem(n_split); +} + +void FollowSplitStepNode::ExtractSplitLengths( + const Array& transform_steps, + Array>* lengths) const { + CHECK_LT(src_step_id, transform_steps.size()); + auto ps = transform_steps[src_step_id].as(); + CHECK(ps != nullptr); + + // get lengths from src step + lengths->reserve(n_split); + int j = 0; + for (; j < n_split - 1; ++j) { + lengths->push_back(ps->lengths[j]); + } + PrimExpr last_factor = 1; + for (; j < static_cast(ps->lengths.size()); ++j) { + if (ps->lengths[j]) { + last_factor *= ps->lengths[j].value(); + } else { + last_factor = PrimExpr(); + break; + } + } + if (last_factor.defined()) { + lengths->push_back(Downcast(last_factor)); + } else { + lengths->push_back(NullOpt); + } +} + +FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->iter_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->src_step_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->n_split); + + data_ = std::move(node); +} + +Array FollowSplitStepNode::ApplyToState(State* state) const { + Array> lengths; + ExtractSplitLengths((*state)->transform_steps, &lengths); + return ApplySplitToState(state, stage_id, iter_id, lengths, true); +} + +Array FollowSplitStepNode::ApplyToSchedule( + Array *stages, StageToAxesMap *stage_to_axes, + const Array& transform_steps) const { + Array> lengths; + ExtractSplitLengths(transform_steps, &lengths); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, + lengths, true); +} + +String FollowSplitStepNode::PrintAsPythonAPI( + Array *stages, StageToAxesMap *stage_to_axes, const Array& transform_steps) const { + Array> lengths; + ExtractSplitLengths(transform_steps, &lengths); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, + lengths, true); +} + +/********** Follow Fused Split **********/ +FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id, + const Array& src_step_ids, int level, bool factor_or_nparts) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->src_step_ids = src_step_ids;; + node->level = level; + node->factor_or_nparts = factor_or_nparts; + data_ = std::move(node); +} + +FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->iter_id); + s = reader->NextArrayItem(); + CHECK(s); + std::vector int_list; + reader->Read(&int_list); + s = 
reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->level); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->factor_or_nparts); + + ::tvm::Array<::tvm::Integer> src_step_ids; + for (const auto& i : int_list) { + src_step_ids.push_back(i); + } + node->src_step_ids = src_step_ids; + data_ = std::move(node); +} + +void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(iter_id); + writer->WriteArrayItem(IntArrayToVector(src_step_ids)); + writer->WriteArrayItem(level); + writer->WriteArrayItem(static_cast(factor_or_nparts)); +} + +Optional FollowFusedSplitStepNode::ExtractSplitLength( + const Array& transform_steps) const { + PrimExpr ret(1); + + for (int src_step_id : src_step_ids) { + CHECK_LT(src_step_id, transform_steps.size()); + auto ps = transform_steps[src_step_id].as(); + CHECK(ps != nullptr); + if (ps->lengths[level] && ret.defined()) { + ret *= ps->lengths[level].value(); + } else { + return NullOpt; + } + } + return Downcast(ret); +} + +Array FollowFusedSplitStepNode::ApplyToState(State* state) const { + const Optional& length = ExtractSplitLength((*state)->transform_steps); + return ApplySplitToState(state, stage_id, iter_id, {length}, factor_or_nparts); +} + +Array FollowFusedSplitStepNode::ApplyToSchedule( + Array *stages, StageToAxesMap *stage_to_axes, + const Array& transform_steps) const { + const Optional& length = ExtractSplitLength(transform_steps); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, + {length}, factor_or_nparts); +} + +String FollowFusedSplitStepNode::PrintAsPythonAPI( + Array *stages, StageToAxesMap *stage_to_axes, const Array& transform_steps) const { + const Optional& length = ExtractSplitLength(transform_steps); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, + {length}, factor_or_nparts); +} + + /********** Primitives working on multiple stages **********/ /********** Compute At **********/ diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index e1746189c29e..0a1044f38988 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -207,7 +207,7 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag); * CacheRead/CacheWrite step) */ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, - te::Schedule* schedule); + te::Schedule* schedule, const Array& transform_steps); /*! * \brief Print the step as equivalent python schedule API. @@ -219,8 +219,8 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes * \return Python schedule code. */ String StepPrintAsPythonAPI(const Step& step, Array* stages, - StageToAxesMap* stage_to_axes, te::Schedule* schedule); - + StageToAxesMap* stage_to_axes, te::Schedule* schedule, + const Array& transform_steps); /********** Primitives working on single stage **********/ /*! @@ -411,6 +411,90 @@ class ReorderStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); }; +/*! \brief Similar to SplitStepNode, but use split factor from another stepf + * (i.e. 
Follow another split step) */ +class FollowSplitStepNode: public StepNode { + public: + int iter_id; // The id of the iter to split + int src_step_id; // The index of the split step to follow in the history + int n_split; // The number of split level + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + void ExtractSplitLengths( const Array& transform_steps, + Array>* lengths) const; + + Array ApplyToState(State* state) const; + + Array ApplyToSchedule(Array *stages, + StageToAxesMap *stage_to_axes, + const Array& transform_steps) const; + + String PrintAsPythonAPI(Array *stages, + StageToAxesMap *stage_to_axes, + const Array& transform_steps) const; + static constexpr const char* record_prefix_str = "FSP"; + static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); +}; + +/*! + * \brief Managed reference to FollowSplitStepNode. + * \sa FollowSplitStepNode + */ +class FollowSplitStep : public Step { + public: + FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); + + explicit FollowSplitStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); +}; + +/*! \brief Similar to FollowSplitStep, but use split factors from multiple steps. + * \Note This can be used for the split in cooperative fetching + */ +class FollowFusedSplitStepNode: public StepNode { + public: + int iter_id; // The id of the iter to split + Array src_step_ids; // The indices of the split steps to follow in the history + int level; // Use the length in this split level + bool factor_or_nparts; // If this is true, use factor. Otherwise, use nparts + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + Optional ExtractSplitLength( const Array& transform_steps) const; + + Array ApplyToState(State* state) const; + + Array ApplyToSchedule(Array *stages, + StageToAxesMap *stage_to_axes, + const Array& transform_steps) const; + + String PrintAsPythonAPI(Array *stages, + StageToAxesMap *stage_to_axes, + const Array& transform_steps) const; + + static constexpr const char* record_prefix_str = "FFSP"; + static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); +}; + +/*! + * \brief Managed reference to FollowFusedSplitStepNode. + * \sa FollowFusedSplitStepNode + */ +class FollowFusedSplitStep : public Step { + public: + FollowFusedSplitStep(int stage_id, int iter_id, + const Array& src_step_ids, + int level, bool factor_or_nparts); + + explicit FollowFusedSplitStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); +}; + /*! 
* \brief Split step that corresponds to te::Stage::split with additional * support of multiple-level of factors diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index 8c9d635b526c..0243b40904df 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -142,6 +142,44 @@ def test_compute_at_root_inline(): assert s0[conv].iters[5].range.extent == 7 assert s0[conv].iters[6].range.extent == 7 +def test_follow_split_follow_fused_split(): + A, B, C = matmul_auto_scheduler_test(512, 512, 512) + dag = auto_scheduler.ComputeDAG([A, B, C]) + s0 = dag.get_init_state() + + C_global = s0.cache_write(C, "global") + its0 = s0.split(C, s0[C].iters[0], [4, 2, 8, 4], True) + split_step0 = len(s0.transform_steps) - 1 + for level in range(1, 6): + tmp = s0.copy() + tmp.follow_split(C_global, tmp[C_global].iters[0], split_step0, level) + for i in range(0, level): + assert tmp[C].iters[i].range.extent == \ + tmp[C_global].iters[i].range.extent + + its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8]) + split_step1 = len(s0.transform_steps) - 1 + its = [] + for i0, i1 in zip(its0, its1): + its.append(i0) + its.append(i1) + s0.reorder(C, its) + for i in range(0, 5): + s0.fuse(C, [s0[C].iters[i], s0[C].iters[i + 1]]) + + for level in range(0, 4): + tmp = s0.copy() + tmp.follow_fused_split(C_global, tmp[C_global].iters[0], + [split_step0, split_step1], level, False) + assert tmp[C].iters[level + 1].range.extent == \ + tmp[C_global].iters[0].range.extent + + for level in range(0, 4): + tmp = s0.copy() + tmp.follow_fused_split(C_global, tmp[C_global].iters[0], + [split_step0, split_step1], level, True) + assert tmp[C].iters[level + 1].range.extent == \ + tmp[C_global].iters[1].range.extent def test_cache_read_write(): N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( @@ -422,3 +460,4 @@ def test_cache_read_write(): test_split_fuse_reorder_annotation() test_compute_at_root_inline() test_cache_read_write() + test_follow_split_follow_fused_split() diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 5f2f87ad9baa..039dff921144 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -37,10 +37,17 @@ def test_record(): k = te.reduce_axis((0, 512), name='k') E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E') F = topi.nn.relu(E) - dag = auto_scheduler.ComputeDAG([A, B, F]) s = dag.get_init_state() - + + # Follow split + C_global = s0.cache_write(C, "global") + split_step0 = len(s0.transform_steps) - 1 + split_step1 = len(s0.transform_steps) - 1 + s.follow_split(C_global, tmp[C_global].iters[0], split_step0, 0) + # Follow fused split + s.follow_fused_split(C_global, tmp[C_global].iters[0], + [split_step0, split_step1], 0, False) # Split its0 = s.split(C, s[C].iters[0], [4, 8, 8]) its1 = s.split(C, s[C].iters[4], [8, 4, 4]) From e144082c1db5b9c62aebbe10909989b8dd069a02 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Wed, 22 Jul 2020 20:01:29 +0800 Subject: [PATCH 04/31] add loop_state.py Signed-off-by: jingbang.yjb --- python/tvm/auto_scheduler/loop_state.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 163a2e342197..0bdd35aed40f 100644 --- 
a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -116,7 +116,7 @@ def stages(self): stages : List[Stage] """ return self.state_object.stages - + @property def transform_steps(self): """ @@ -302,7 +302,7 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): self._resolve_stage_id(stage), iterator, lengths, inner_to_outer) return res - + def follow_split(self, stage, iterator, src_step_id, n_split): """ Parameters @@ -320,12 +320,12 @@ def follow_split(self, stage, iterator, src_step_id, n_split): The splitted new Iterators """ - self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, - self._resolve_stage_id(stage), + self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, + self._resolve_stage_id(stage), iterator, src_step_id, n_split) return res - + def follow_fused_split(self, stage, iterator, src_step_ids, level, factor_or_nparts): """ @@ -347,9 +347,9 @@ def follow_fused_split(self, stage, iterator, src_step_ids, level, The splitted new Iterators """ - self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, + self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, self._resolve_stage_id(stage), - iterator, + iterator, src_step_ids, level, factor_or_nparts) return res From 86c36709e37879a652b89a1af92a6afaaf23d5a7 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 22 Jul 2020 15:07:31 +0800 Subject: [PATCH 05/31] Update --- src/auto_scheduler/loop_state.h | 8 +++---- src/auto_scheduler/transform_step.cc | 32 +++++++++++++++++----------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index fb07e5a0a32e..a40b9973d373 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -390,8 +390,8 @@ class State : public ObjectRef { * \param scope_name The scope name to be set for the new added read stage. * \param reader_stage_ids The indexes of reader stages. * \param dag The original ComputeDAG of this state. - * \note Cache read step will add an extra stage to the original ComputeDAG, a up-to-date - * ComputeDAG is stored in State's `current_compute_dag`. + * \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the + * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`. */ int cache_read(int stage_id, const String& scope_name, const Array& reader_stage_ids, const ComputeDAG& dag); @@ -400,8 +400,8 @@ class State : public ObjectRef { * \param stage_id The index of the stage to be cache write. * \param scope_name The scope name to be set for the new added write stage. * \param dag The original ComputeDAG of this state. - * \note Cache write step will add an extra stage to the original ComputeDAG, a up-to-date - * ComputeDAG is stored in State's `current_compute_dag`. + * \note Cache write step will add an extra stage to the original ComputeDAG (in the front of the + * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`. * This step will cache write all output tensors of the target stage. 
*/ int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag); diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index e6450125772b..34c415cfd849 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -1024,8 +1024,9 @@ int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { } const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps); - // target -> target + target_store - // Should update target's op, insert new stage, update the later stage's op + // target_stage -> target_stage + target_store + // Update the op of the target stage, insert a new cache read stage behind, update the op of + // later stages, then update the stage_id mapping in AttachMap int added_stage_id = stage_id + 1; Stage tmp_stage = pstate->stages[stage_id]; tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[stage_id]; @@ -1064,9 +1065,10 @@ te::Tensor CacheReadStepNode::ApplyToSchedule(Array* stages, String CacheReadStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, te::Schedule* schedule) const { std::stringstream ss; - // Copy stage here, for the original stage will change after apply + // Since the original stage will be changed after schedule apply, keep a copy here + // These information will be used to print Python API string later auto stage = (*stages)[stage_id]; - std::vector reader_stages; + Array reader_stages; for (size_t i = 0; i < reader_stage_ids.size(); ++i) { reader_stages.push_back((*stages)[reader_stage_ids[i]]); } @@ -1081,6 +1083,7 @@ String CacheReadStepNode::PrintAsPythonAPI(Array* stages, StageToAxes } ss << "])\n"; + // Print the iterators of the new added stage const auto& iters = out->op->root_iter_vars(); for (size_t i = 0; i < iters.size(); ++i) { ss << CleanName(iters[i]->var->name_hint); @@ -1138,20 +1141,21 @@ int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) const : dag->ops.size(); const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps); int added_ops = current_compute_dag->ops.size() - last_dag_op_size; + // TODO(jcf94): Update this check to equal after fixing the cache write bug in TVM CHECK_GE(added_ops, 1); - // target -> target_compute + target - // Assume target stage has never been applied any steps before cache_write - // Should insert new stage, update target stage, update the later stage's op + // target_stage -> cache_write_stage + target_stage + // Assume no step has been applied to the target stage before cache write. + // Insert a new cache write stage ahead, update the op of the target stage and later stages, then + // update the stage_id mapping in AttachMap pstate->stages.insert(pstate->stages.begin() + stage_id, Stage(current_compute_dag->ops[stage_id])); pstate->stages.Set(stage_id + 1, Stage(current_compute_dag->ops[stage_id + 1])); int next_stage_id = stage_id + 2; - // Notice: added_ops should actually assert to be 1 - // branch of 2 here is somehow a hack to TVM's cache_write bug with multi outputs - // see `tests/python/unittest/test_auto_scheduler_loop_state.py::test_cache_read_write` test for - // more information - // TODO(jcf94): Fix the cache write bug in TVM and remove these branches here + // TODO(jc94): Fix the cache write bug in TVM and remove added_op == 2 support. + // TVM's cache_write has a bug with multi outputs. 
See + // `tests/python/unittest/test_auto_scheduler_loop_state.py::test_cache_read_write` test + // for more details if (added_ops == 2) { pstate->stages.insert(pstate->stages.begin() + next_stage_id, Stage(current_compute_dag->ops[next_stage_id])); @@ -1196,7 +1200,8 @@ Array CacheWriteStepNode::ApplyToSchedule(Array* stages, String CacheWriteStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, te::Schedule* schedule) const { std::stringstream ss; - // Copy stage here, for the original stage will change after apply + // Since the original stage will be changed after schedule apply, keep a copy here + // These information will be used to print Python API string later te::Stage stage = (*stages)[stage_id]; auto outs = ApplyToSchedule(stages, stage_to_axes, schedule); @@ -1211,6 +1216,7 @@ String CacheWriteStepNode::PrintAsPythonAPI(Array* stages, StageToAxe } ss << "], \"" << scope_name << "\")\n"; + // Print the iterators of the new added stage for (const auto& out : outs) { const auto& iters = out->op->root_iter_vars(); for (size_t i = 0; i < iters.size(); ++i) { From abfb150ea1d7a57874adce655addb36c958e96f8 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 23 Jul 2020 10:11:42 +0800 Subject: [PATCH 06/31] Update --- src/auto_scheduler/transform_step.cc | 47 +++++++++++----------------- 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 34c415cfd849..ec29a93231cb 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -960,15 +960,22 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array* stages, /********** Primitives adding new stages **********/ -// Common part for steps that add new stages -// (e.g. CacheReadStep, CacheWriteStep, RfactorStep) -void AddStageModificationSteps(int step_id, const Array& transform_steps, - Array* replay_steps) { - const Step& step = transform_steps[step_id]; - if (step->IsInstance() || step->IsInstance()) { - replay_steps->push_back(step); +/*! + * \brief Common part for steps that add new stages(e.g. CacheReadStep, CacheWriteStep, + * RfactorStep). This will filter out all steps that can change the stages of ComputeDAG. 
+ */ +Array GetStageModifiableSteps(Step current_step, const Array& transform_steps) { + Array ret_steps; + for (const Step& step : transform_steps) { + if (step->IsInstance() || step->IsInstance()) { + ret_steps.push_back(step); + } + // TODO(jcf94): add rfactor support + if (step.same_as(current_step)) { + break; + } } - // TODO(jcf94): add rfactor support + return ret_steps; } /********** Cache Read **********/ @@ -1015,14 +1022,8 @@ void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { StateNode* pstate = state->CopyOnWrite(); - Array replay_steps; - for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { - AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); - if (pstate->transform_steps[i].same_as(GetRef(this))) { - break; - } - } - const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps); + const ComputeDAG& current_compute_dag = + dag.ReplayAndGetDAG(GetStageModifiableSteps(GetRef(this), (*state)->transform_steps)); // target_stage -> target_stage + target_store // Update the op of the target stage, insert a new cache read stage behind, update the op of @@ -1048,7 +1049,6 @@ te::Tensor CacheReadStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, te::Schedule* schedule) const { const te::Stage& stage = (*stages)[stage_id]; - Array readers; for (const auto& i : reader_stage_ids) { readers.push_back((*stages)[i]->origin_op); @@ -1072,7 +1072,6 @@ String CacheReadStepNode::PrintAsPythonAPI(Array* stages, StageToAxes for (size_t i = 0; i < reader_stage_ids.size(); ++i) { reader_stages.push_back((*stages)[reader_stage_ids[i]]); } - auto out = ApplyToSchedule(stages, stage_to_axes, schedule); ss << CleanName(out->op->name) << " = " @@ -1129,17 +1128,11 @@ void CacheWriteStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { StateNode* pstate = state->CopyOnWrite(); - Array replay_steps; - for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { - AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); - if (pstate->transform_steps[i].same_as(GetRef(this))) { - break; - } - } int last_dag_op_size = pstate->current_compute_dag.defined() ? 
pstate->current_compute_dag.as()->ops.size() : dag->ops.size(); - const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps); + const ComputeDAG& current_compute_dag = + dag.ReplayAndGetDAG(GetStageModifiableSteps(GetRef(this), (*state)->transform_steps)); int added_ops = current_compute_dag->ops.size() - last_dag_op_size; // TODO(jcf94): Update this check to equal after fixing the cache write bug in TVM CHECK_GE(added_ops, 1); @@ -1178,7 +1171,6 @@ Array CacheWriteStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, te::Schedule* schedule) const { const te::Stage& stage = (*stages)[stage_id]; - Array tensor_array; // If the target stage has multi outputs, TVM requires to cache_write // all of them or schedule.cache_write will raise an error @@ -1203,7 +1195,6 @@ String CacheWriteStepNode::PrintAsPythonAPI(Array* stages, StageToAxe // Since the original stage will be changed after schedule apply, keep a copy here // These information will be used to print Python API string later te::Stage stage = (*stages)[stage_id]; - auto outs = ApplyToSchedule(stages, stage_to_axes, schedule); for (size_t i = 0; i < outs.size(); ++i) { From 3c1da648c699db5e750885ba8101fcc043f71fc5 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 23 Jul 2020 10:26:41 +0800 Subject: [PATCH 07/31] Update state->current_compute_dag to Optional --- src/auto_scheduler/loop_state.h | 4 ++-- src/auto_scheduler/transform_step.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index a40b9973d373..3cab133f3c25 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -242,9 +242,9 @@ class StateNode : public Object { * \brief The up-to-date ComputeDAG of this state, used for some steps that may change the * stage structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep which Will be added * later). - * The default value is an empty ObjectRef. (means no modification to the original DAG) + * The default value is an empty NullOpt. (means no modification to the original DAG) */ - ObjectRef current_compute_dag; + Optional current_compute_dag; /*! * \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all * tile sizes of the state is filled. Only concrete state can be apply to TVM schedule. diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index ec29a93231cb..e18154940271 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -1128,8 +1128,8 @@ void CacheWriteStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { StateNode* pstate = state->CopyOnWrite(); - int last_dag_op_size = pstate->current_compute_dag.defined() - ? pstate->current_compute_dag.as()->ops.size() + int last_dag_op_size = pstate->current_compute_dag + ? 
pstate->current_compute_dag.value().as()->ops.size() : dag->ops.size(); const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(GetStageModifiableSteps(GetRef(this), (*state)->transform_steps)); From c4a344cbf524ffc383488988d01ddf8087963745 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Thu, 23 Jul 2020 11:57:08 +0800 Subject: [PATCH 08/31] Add some doc strings for Follow_Split and Follow_fused_split Signed-off-by: jingbang.yjb --- python/tvm/auto_scheduler/loop_state.py | 31 ++++++++++++--------- src/auto_scheduler/loop_state.h | 22 +++++++++++++-- src/auto_scheduler/transform_step.cc | 37 ++++++++++++------------- 3 files changed, 56 insertions(+), 34 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 0bdd35aed40f..5904ed9cd588 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -305,19 +305,23 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): def follow_split(self, stage, iterator, src_step_id, n_split): """ + Schedule primitive corresponds to te.follow_split. Parameters ---------- + stage : Union[int, Operation, Tensor] + The Stage to be split, which can be specified by the integer index, Operation, + or output tensor of the stage. iterator : Iterator - The iterator to split - src_step_id : int - The index of the split step to follow in the history - n_split : int - The number of split level + The iterator to split. + src_step_id : Int + The index of the split step to follow in the history. + n_split : Int + The number of split level. Returns ------- res_its : List[Iterator] - The splitted new Iterators + The splitted new Iterators. """ self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, @@ -329,22 +333,23 @@ def follow_split(self, stage, iterator, src_step_id, n_split): def follow_fused_split(self, stage, iterator, src_step_ids, level, factor_or_nparts): """ + Schedule primitive corresponds to te.follow_fused_split. Parameters ---------- iterator : Iterator - The iterator to split + The iterator to split. src_step_ids : List[int] - The indices of the split steps to follow in the history - level : int - Use the length in this split level - factor_or_nparts : bool + The indices of the split steps to follow in the history. + level : Int + Use the length in this split level. + factor_or_nparts : Bool True to use `factor` for split from inner to outer, - False to use `nparts` for split from outer to inner + False to use `nparts` for split from outer to inner. Returns ------- res_its : List[Iterator] - The splitted new Iterators + The splitted new Iterators. """ self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index d8ea9d467753..cc2fd2d698cb 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -353,9 +353,27 @@ class State : public ObjectRef { */ Array split(int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer = true); - - Array follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split); + /********** Step APIs working on multiple stages **********/ + /*! + * \brief Schedule primitive corresponds to te.follow_split. + * \param stage_id The index of the stage to be split. + * \param it The iterator to be split. + * \param src_step_id The index of the split step to follow in the history. + * \param n_split The number of split level. 
+ * \return The splitted new Iterators. + */ + Array follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split); + /*! + * \brief Schedule primitive corresponds to te.follow_split. + * \param stage_id The index of the stage to be split. + * \param it The iterator to be split. + * \param src_step_ids The indices of the split steps to follow in the history. + * \param level Use the length in this split level. + * \param factor_or_nparts True to use `factor` for split from inner to outer, + False to use `nparts` for split from outer to inner. + * \return The splitted new Iterators. + */ Array follow_fused_split(int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts); /*! diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index d8fbbb221dab..715bbe97eb09 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -86,6 +86,10 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { return ReorderStep(reader); } else if (name == SplitStepNode::record_prefix_str) { return SplitStep(reader); + } else if (name == FollowSplitStepNode::record_prefix_str) { + return FollowSplitStep(reader); + } else if (name == FollowFusedSplitStepNode::record_prefix_str) { + return FollowFusedSplitStep(reader); } else if (name == ComputeAtStepNode::record_prefix_str) { return ComputeAtStep(reader); } else if (name == ComputeInlineStepNode::record_prefix_str) { @@ -96,12 +100,7 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { return CacheReadStep(reader); } else if (name == CacheWriteStepNode::record_prefix_str) { return CacheWriteStep(reader); - } else if (name == FollowSplitStepNode::record_prefix_str) { - return FollowSplitStep(reader); - } else if (name == FollowFusedSplitStepNode::record_prefix_str) { - return FollowFusedSplitStep(reader); - } - else { + } else { LOG(FATAL) << "Invalid step format: " << name; } return Step(); @@ -116,6 +115,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); + } else if(auto ps = step.as()){ + ps->ApplyToState(state); + } else if(auto ps = step.as()){ + ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); } else if (auto ps = step.as()) { @@ -126,10 +129,6 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state, dag); } else if (auto ps = step.as()) { ps->ApplyToState(state, dag); - } else if(auto ps = step.as()){ - ps->ApplyToState(state); - } else if(auto ps = step.as()){ - ps->ApplyToState(state); } else { LOG(FATAL) << "Invalid step: " << step; } @@ -145,6 +144,10 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); + } else if(auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); + } else if(auto ps = step.as()){ + ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -155,10 +158,6 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes ps->ApplyToSchedule(stages, stage_to_axes, schedule); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes, schedule); - } else if(auto ps = step.as()) { - ps->ApplyToSchedule(stages, 
stage_to_axes, transform_steps); - } else if(auto ps = step.as()){ - ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); } else { LOG(FATAL) << "Invalid Step: " << step; } @@ -174,6 +173,10 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else if(auto ps = step.as()){ + return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); + } else if(auto ps = step.as()){ + return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -184,11 +187,7 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); - } else if(auto ps = step.as()){ - return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); - } else if(auto ps = step.as()){ - return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); - } else { + } else { LOG(FATAL) << "Invalid Step: " << step; } return ""; From d3969b8d5d8152c2644771e537d7796858c9df42 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Thu, 23 Jul 2020 13:54:01 +0800 Subject: [PATCH 09/31] Check code using c-lint Signed-off-by: jingbang.yjb --- src/auto_scheduler/loop_state.cc | 50 ++++++++++----------- src/auto_scheduler/loop_state.h | 4 +- src/auto_scheduler/transform_step.cc | 66 ++++++++++++++-------------- src/auto_scheduler/transform_step.h | 47 +++++++++----------- 4 files changed, 80 insertions(+), 87 deletions(-) diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 8db726d17f95..7aaddf129f6d 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -239,24 +239,23 @@ Iterator State::vectorize(int stage_id, const Iterator& it) { return step->ApplyToState(this); } -Array State::follow_split(int stage_id, const Iterator& it, - int src_step_id, int n_split) { +Array State::follow_split(int stage_id, const Iterator& it, int src_step_id, + int n_split) { const Stage& stage = operator->()->stages[stage_id]; - FollowSplitStep step = FollowSplitStep( - stage_id, GetIndex(stage->iters, it), src_step_id, n_split); + FollowSplitStep step = + FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } -Array State::follow_fused_split( - int stage_id, const Iterator& it, const Array& src_step_ids, - int level, bool factor_or_nparts) { +Array State::follow_fused_split(int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts) { const Stage& stage = operator->()->stages[stage_id]; - FollowFusedSplitStep step = - FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), - src_step_ids, level, factor_or_nparts); + FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), + src_step_ids, level, factor_or_nparts); CopyOnWrite()->transform_steps.push_back(step); return step->ApplyToState(this); } @@ -459,24 +458,23 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize") }); TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit") -.set_body_typed([](State state, int stage_id, const Iterator& it, - int src_step_id, int n_split) { - const auto& res = state.follow_split(stage_id, 
it, src_step_id, n_split); - return Array{state, Array(res)}; -}); + .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id, + int n_split) { + const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); + return Array{state, Array(res)}; + }); TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit") -.set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& src_step_ids, int level, - bool factor_or_nparts) { - Array array_src_step_ids; - for (const auto& i : src_step_ids) { - array_src_step_ids.push_back(i->value); - } - const auto& res = state.follow_fused_split( - stage_id, it, array_src_step_ids, level, factor_or_nparts); - return Array{state, Array(res)}; -}); + .set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& src_step_ids, int level, bool factor_or_nparts) { + Array array_src_step_ids; + for (const auto& i : src_step_ids) { + array_src_step_ids.push_back(i->value); + } + const auto& res = + state.follow_fused_split(stage_id, it, array_src_step_ids, level, factor_or_nparts); + return Array{state, Array(res)}; + }); TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse") .set_body_typed([](State state, int stage_id, const Array& iters) { diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index cc2fd2d698cb..0d8823bb3ef6 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -374,7 +374,9 @@ class State : public ObjectRef { False to use `nparts` for split from outer to inner. * \return The splitted new Iterators. */ - Array follow_fused_split(int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts); + Array follow_fused_split(int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts); /*! * \brief Schedule primitive corresponds to te.compute_at. 
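
For quick reference, the follow_split / follow_fused_split primitives reformatted above can be driven from the Python side roughly as follows. This is a minimal sketch condensed from the test_follow_split_follow_fused_split unit test added earlier in this series; the matmul_auto_scheduler_test helper (and its import path), the tensor names, and the tiling factors are taken from that test and are illustrative only, not part of this patch.

    from tvm import auto_scheduler
    # Assumption: the matmul workload helper is importable the same way the unit
    # tests in this series import it.
    from test_auto_scheduler_common import matmul_auto_scheduler_test

    A, B, C = matmul_auto_scheduler_test(512, 512, 512)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    s = dag.get_init_state()
    C_global = s.cache_write(C, "global")

    # Record two SplitSteps on C; their indices in the transform history are the
    # src_step ids that follow_split / follow_fused_split refer back to.
    s.split(C, s[C].iters[0], [4, 2, 8, 4])
    split_step0 = len(s.transform_steps) - 1
    s.split(C, s[C].iters[5], [2, 2, 4, 8])
    split_step1 = len(s.transform_steps) - 1

    # follow_split: reuse the two outer factors of split_step0, so the two
    # outermost iterators of C_global get the same extents as those of C.
    s.follow_split(C_global, s[C_global].iters[0], split_step0, 2)

    # follow_fused_split: split with the product of the level-0 factors of both
    # recorded splits, i.e. the extent a fused, thread-bound iterator would have.
    s.follow_fused_split(C_global, s[C_global].iters[3],
                         [split_step0, split_step1], 0, False)

Replaying such a state through the ComputeDAG exercises the ApplyToSchedule / PrintAsPythonAPI code paths that the transform_step.cc hunks below reformat.
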
diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 715bbe97eb09..5b26c5e98e3a 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -115,9 +115,9 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); - } else if(auto ps = step.as()){ + } else if (auto ps = step.as()) { ps->ApplyToState(state); - } else if(auto ps = step.as()){ + } else if (auto ps = step.as()) { ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); @@ -144,9 +144,9 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); - } else if(auto ps = step.as()) { + } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); - } else if(auto ps = step.as()){ + } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); @@ -164,7 +164,8 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes } String StepPrintAsPythonAPI(const Step& step, Array* stages, - StageToAxesMap* stage_to_axes, te::Schedule* schedule, const Array& transform_steps) { + StageToAxesMap* stage_to_axes, te::Schedule* schedule, + const Array& transform_steps) { if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -173,9 +174,9 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); - } else if(auto ps = step.as()){ + } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); - } else if(auto ps = step.as()){ + } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); @@ -187,7 +188,7 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); - } else { + } else { LOG(FATAL) << "Invalid Step: " << step; } return ""; @@ -793,8 +794,7 @@ String SplitStepNode::PrintAsPythonAPI(Array* stages, } /********** Follow Split **********/ -FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, - int src_step_id, int n_split) { +FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; @@ -812,9 +812,8 @@ void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { writer->WriteArrayItem(n_split); } -void FollowSplitStepNode::ExtractSplitLengths( - const Array& transform_steps, - Array>* lengths) const { +void FollowSplitStepNode::ExtractSplitLengths(const Array& transform_steps, + Array>* lengths) const { CHECK_LT(src_step_id, transform_steps.size()); auto ps = transform_steps[src_step_id].as(); CHECK(ps != nullptr); @@ -866,30 +865,30 @@ Array FollowSplitStepNode::ApplyToState(State* state) const { return ApplySplitToState(state, stage_id, iter_id, lengths, true); } -Array 
FollowSplitStepNode::ApplyToSchedule( - Array *stages, StageToAxesMap *stage_to_axes, - const Array& transform_steps) const { +Array FollowSplitStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes, + const Array& transform_steps) const { Array> lengths; ExtractSplitLengths(transform_steps, &lengths); - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - lengths, true); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true); } -String FollowSplitStepNode::PrintAsPythonAPI( - Array *stages, StageToAxesMap *stage_to_axes, const Array& transform_steps) const { +String FollowSplitStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes, + const Array& transform_steps) const { Array> lengths; ExtractSplitLengths(transform_steps, &lengths); - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - lengths, true); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true); } /********** Follow Fused Split **********/ FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id, - const Array& src_step_ids, int level, bool factor_or_nparts) { + const Array& src_step_ids, int level, + bool factor_or_nparts) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; - node->src_step_ids = src_step_ids;; + node->src_step_ids = src_step_ids; node->level = level; node->factor_or_nparts = factor_or_nparts; data_ = std::move(node); @@ -955,22 +954,21 @@ Array FollowFusedSplitStepNode::ApplyToState(State* state) const { return ApplySplitToState(state, stage_id, iter_id, {length}, factor_or_nparts); } -Array FollowFusedSplitStepNode::ApplyToSchedule( - Array *stages, StageToAxesMap *stage_to_axes, - const Array& transform_steps) const { +Array FollowFusedSplitStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes, + const Array& transform_steps) const { const Optional& length = ExtractSplitLength(transform_steps); - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - {length}, factor_or_nparts); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, {length}, factor_or_nparts); } -String FollowFusedSplitStepNode::PrintAsPythonAPI( - Array *stages, StageToAxesMap *stage_to_axes, const Array& transform_steps) const { +String FollowFusedSplitStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes, + const Array& transform_steps) const { const Optional& length = ExtractSplitLength(transform_steps); - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - {length}, factor_or_nparts); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, {length}, + factor_or_nparts); } - /********** Primitives working on multiple stages **********/ /********** Compute At **********/ diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 0a1044f38988..3d1456eeb088 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -413,25 +413,23 @@ class ReorderStep : public Step { /*! \brief Similar to SplitStepNode, but use split factor from another stepf * (i.e. 
Follow another split step) */ -class FollowSplitStepNode: public StepNode { +class FollowSplitStepNode : public StepNode { public: int iter_id; // The id of the iter to split int src_step_id; // The index of the split step to follow in the history int n_split; // The number of split level - + void WriteToRecord(dmlc::JSONWriter* writer) const final; - void ExtractSplitLengths( const Array& transform_steps, - Array>* lengths) const; + void ExtractSplitLengths(const Array& transform_steps, + Array>* lengths) const; Array ApplyToState(State* state) const; - Array ApplyToSchedule(Array *stages, - StageToAxesMap *stage_to_axes, - const Array& transform_steps) const; + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; - String PrintAsPythonAPI(Array *stages, - StageToAxesMap *stage_to_axes, + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const; static constexpr const char* record_prefix_str = "FSP"; static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep"; @@ -445,7 +443,7 @@ class FollowSplitStepNode: public StepNode { class FollowSplitStep : public Step { public: FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); - + explicit FollowSplitStep(dmlc::JSONReader* reader); TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); @@ -454,26 +452,24 @@ class FollowSplitStep : public Step { /*! \brief Similar to FollowSplitStep, but use split factors from multiple steps. * \Note This can be used for the split in cooperative fetching */ -class FollowFusedSplitStepNode: public StepNode { +class FollowFusedSplitStepNode : public StepNode { public: - int iter_id; // The id of the iter to split + int iter_id; // The id of the iter to split Array src_step_ids; // The indices of the split steps to follow in the history - int level; // Use the length in this split level - bool factor_or_nparts; // If this is true, use factor. Otherwise, use nparts - + int level; // Use the length in this split level + bool factor_or_nparts; // If this is true, use factor. 
Otherwise, use nparts + void WriteToRecord(dmlc::JSONWriter* writer) const final; - Optional ExtractSplitLength( const Array& transform_steps) const; + Optional ExtractSplitLength(const Array& transform_steps) const; Array ApplyToState(State* state) const; - Array ApplyToSchedule(Array *stages, - StageToAxesMap *stage_to_axes, - const Array& transform_steps) const; + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; - String PrintAsPythonAPI(Array *stages, - StageToAxesMap *stage_to_axes, - const Array& transform_steps) const; + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; static constexpr const char* record_prefix_str = "FFSP"; static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep"; @@ -486,10 +482,9 @@ class FollowFusedSplitStepNode: public StepNode { */ class FollowFusedSplitStep : public Step { public: - FollowFusedSplitStep(int stage_id, int iter_id, - const Array& src_step_ids, - int level, bool factor_or_nparts); - + FollowFusedSplitStep(int stage_id, int iter_id, const Array& src_step_ids, int level, + bool factor_or_nparts); + explicit FollowFusedSplitStep(dmlc::JSONReader* reader); TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); From f20952538ae3b3eb209772d092bf800abeea255b Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Thu, 23 Jul 2020 17:34:32 +0800 Subject: [PATCH 10/31] Add more doc strings and change the order for follow split. Signed-off-by: jingbang.yjb --- python/tvm/auto_scheduler/loop_state.py | 43 ++++-- src/auto_scheduler/loop_state.cc | 80 +++++------ src/auto_scheduler/loop_state.h | 1 - src/auto_scheduler/transform_step.h | 174 +++++++++++++----------- 4 files changed, 170 insertions(+), 128 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 5904ed9cd588..9072c46a7dab 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -304,8 +304,17 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): return res def follow_split(self, stage, iterator, src_step_id, n_split): - """ - Schedule primitive corresponds to te.follow_split. + """ Schedule primitive extends to split step. + + This step is used to follow a former SplitStep, keeps their iterator structures to be same. + + Example cases: + With subgraph: Dense -> Relu + Some tiling structures are used in Relu stage and we intend to compute the Dense + stage at Relu. + The follow_split is used here to keep their outer most few iterators the same for + applying compute at. + Parameters ---------- stage : Union[int, Operation, Tensor] @@ -313,9 +322,9 @@ def follow_split(self, stage, iterator, src_step_id, n_split): or output tensor of the stage. iterator : Iterator The iterator to split. - src_step_id : Int + src_step_id : int The index of the split step to follow in the history. - n_split : Int + n_split : int The number of split level. Returns @@ -332,17 +341,35 @@ def follow_split(self, stage, iterator, src_step_id, n_split): def follow_fused_split(self, stage, iterator, src_step_ids, level, factor_or_nparts): - """ - Schedule primitive corresponds to te.follow_fused_split. + """ Schedule primitive extends to split step. + + This step is used to follow several former SplitSteps and FuseSteps. + + Example cases: + With subgraph in GPU schedule: Input -> Dense + for i.0@j.0 = ... : Bind to blockIdx.x + for i.1@j.1 = ... 
: Bind to threadIdx.x + for i.2@j.2 = ... + Input_shared = Input ... + for k = ... + Dense = ... + We intend to apply cooperative fetching with the Input stage, while the threadIdx.x + axis is binded to a iterator generated by split & fuse step. + The follow_fused_step is used here to figure out the final extent of the threadIdx.x + binded iterator. + Parameters ---------- + stage : Union[int, Operation, Tensor] + The Stage to be split, which can be specified by the integer index, Operation, + or output tensor of the stage. iterator : Iterator The iterator to split. src_step_ids : List[int] The indices of the split steps to follow in the history. - level : Int + level : int Use the length in this split level. - factor_or_nparts : Bool + factor_or_nparts : bool True to use `factor` for split from inner to outer, False to use `nparts` for split from outer to inner. diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 7aaddf129f6d..804256917dec 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -239,27 +239,6 @@ Iterator State::vectorize(int stage_id, const Iterator& it) { return step->ApplyToState(this); } -Array State::follow_split(int stage_id, const Iterator& it, int src_step_id, - int n_split) { - const Stage& stage = operator->()->stages[stage_id]; - - FollowSplitStep step = - FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split); - CopyOnWrite()->transform_steps.push_back(step); - return step->ApplyToState(this); -} - -Array State::follow_fused_split(int stage_id, const Iterator& it, - const Array& src_step_ids, int level, - bool factor_or_nparts) { - const Stage& stage = operator->()->stages[stage_id]; - - FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), - src_step_ids, level, factor_or_nparts); - CopyOnWrite()->transform_steps.push_back(step); - return step->ApplyToState(this); -} - Iterator State::fuse(int stage_id, const Array& iters) { const Stage& stage = operator->()->stages[stage_id]; Array indices; @@ -290,6 +269,27 @@ Array State::split(int stage_id, const Iterator& it, return step->ApplyToState(this); } +Array State::follow_split(int stage_id, const Iterator& it, int src_step_id, + int n_split) { + const Stage& stage = operator->()->stages[stage_id]; + + FollowSplitStep step = + FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + +Array State::follow_fused_split(int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts) { + const Stage& stage = operator->()->stages[stage_id]; + + FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), + src_step_ids, level, factor_or_nparts); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { const Stage& target_stage = operator->()->stages[target_stage_id]; ComputeAtStep step = @@ -457,25 +457,6 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize") return Array{state, res}; }); -TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit") - .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id, - int n_split) { - const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); - return Array{state, Array(res)}; - }); - 
-TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit") - .set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& src_step_ids, int level, bool factor_or_nparts) { - Array array_src_step_ids; - for (const auto& i : src_step_ids) { - array_src_step_ids.push_back(i->value); - } - const auto& res = - state.follow_fused_split(stage_id, it, array_src_step_ids, level, factor_or_nparts); - return Array{state, Array(res)}; - }); - TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse") .set_body_typed([](State state, int stage_id, const Array& iters) { const auto& res = state.fuse(stage_id, iters); @@ -495,6 +476,25 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit") return Array{state, res}; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit") + .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id, + int n_split) { + const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); + return Array{state, Array(res)}; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit") + .set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& src_step_ids, int level, bool factor_or_nparts) { + Array array_src_step_ids; + for (const auto& i : src_step_ids) { + array_src_step_ids.push_back(i->value); + } + const auto& res = + state.follow_fused_split(stage_id, it, array_src_step_ids, level, factor_or_nparts); + return Array{state, Array(res)}; + }); + TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt") .set_body_typed([](State state, int stage_id, int target_stage_id, const Iterator& target_iter) { diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index 0d8823bb3ef6..da63b8a11d37 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -377,7 +377,6 @@ class State : public ObjectRef { Array follow_fused_split(int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts); - /*! * \brief Schedule primitive corresponds to te.compute_at. * \param stage_id The index of the stage to be compute at. diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 3d1456eeb088..0ddeb9b20c3f 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -411,85 +411,6 @@ class ReorderStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); }; -/*! \brief Similar to SplitStepNode, but use split factor from another stepf - * (i.e. Follow another split step) */ -class FollowSplitStepNode : public StepNode { - public: - int iter_id; // The id of the iter to split - int src_step_id; // The index of the split step to follow in the history - int n_split; // The number of split level - - void WriteToRecord(dmlc::JSONWriter* writer) const final; - - void ExtractSplitLengths(const Array& transform_steps, - Array>* lengths) const; - - Array ApplyToState(State* state) const; - - Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, - const Array& transform_steps) const; - - String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, - const Array& transform_steps) const; - static constexpr const char* record_prefix_str = "FSP"; - static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); -}; - -/*! - * \brief Managed reference to FollowSplitStepNode. 
- * \sa FollowSplitStepNode - */ -class FollowSplitStep : public Step { - public: - FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); - - explicit FollowSplitStep(dmlc::JSONReader* reader); - - TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); -}; - -/*! \brief Similar to FollowSplitStep, but use split factors from multiple steps. - * \Note This can be used for the split in cooperative fetching - */ -class FollowFusedSplitStepNode : public StepNode { - public: - int iter_id; // The id of the iter to split - Array src_step_ids; // The indices of the split steps to follow in the history - int level; // Use the length in this split level - bool factor_or_nparts; // If this is true, use factor. Otherwise, use nparts - - void WriteToRecord(dmlc::JSONWriter* writer) const final; - - Optional ExtractSplitLength(const Array& transform_steps) const; - - Array ApplyToState(State* state) const; - - Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, - const Array& transform_steps) const; - - String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, - const Array& transform_steps) const; - - static constexpr const char* record_prefix_str = "FFSP"; - static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); -}; - -/*! - * \brief Managed reference to FollowFusedSplitStepNode. - * \sa FollowFusedSplitStepNode - */ -class FollowFusedSplitStep : public Step { - public: - FollowFusedSplitStep(int stage_id, int iter_id, const Array& src_step_ids, int level, - bool factor_or_nparts); - - explicit FollowFusedSplitStep(dmlc::JSONReader* reader); - - TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); -}; - /*! * \brief Split step that corresponds to te::Stage::split with additional * support of multiple-level of factors @@ -569,6 +490,101 @@ class SplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; +/*! \brief Similar to SplitStepNode, but use split factor from another stepf + * (i.e. Follow another split step) */ +class FollowSplitStepNode : public StepNode { + public: + /*! \brief The id of the iter to split. */ + int iter_id; + /*! \brief The index of the split step to follow in the history. */ + int src_step_id; + /*! \brief The number of split level. */ + int n_split; + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Extract split lengths. + * \param transform_steps An array record all transform steps. + * \param lengths The multiple split factors. Can be None to be filled by search policy. + */ + void ExtractSplitLengths(const Array& transform_steps, + Array>* lengths) const; + + Array ApplyToState(State* state) const; + + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; + + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; + + static constexpr const char* record_prefix_str = "FSP"; + + static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); +}; + +/*! + * \brief Managed reference to FollowSplitStepNode. 
+ * \sa FollowSplitStepNode + */ +class FollowSplitStep : public Step { + public: + FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); + + explicit FollowSplitStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); +}; + +/*! \brief Similar to FollowSplitStep, but use split factors from multiple steps. + * \Note This can be used for the split in cooperative fetching + */ +class FollowFusedSplitStepNode : public StepNode { + public: + int iter_id; // The id of the iter to split + Array src_step_ids; // The indices of the split steps to follow in the history + int level; // Use the length in this split level + bool factor_or_nparts; // If this is true, use factor. Otherwise, use nparts + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Extract split length. + * \param transform_steps An array record all transform steps. + * \return Split factor. + */ + Optional ExtractSplitLength(const Array& transform_steps) const; + + Array ApplyToState(State* state) const; + + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; + + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + const Array& transform_steps) const; + + static constexpr const char* record_prefix_str = "FFSP"; + + static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); +}; + +/*! + * \brief Managed reference to FollowFusedSplitStepNode. + * \sa FollowFusedSplitStepNode + */ +class FollowFusedSplitStep : public Step { + public: + FollowFusedSplitStep(int stage_id, int iter_id, const Array& src_step_ids, int level, + bool factor_or_nparts); + + explicit FollowFusedSplitStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); +}; + /********** Primitives working on multiple stages **********/ /*! 
\brief Compute at step that corresponds to te::Stage::compute_at */ From 50f7c4a50d91c95814790c8ffc9f4a8299e26e6c Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Thu, 23 Jul 2020 19:11:18 +0800 Subject: [PATCH 11/31] Add record test for follow_split and follow_fused_split Signed-off-by: jingbang.yjb --- .../unittest/test_auto_scheduler_measure.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 039dff921144..4071bd7f42ac 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -22,8 +22,7 @@ from tvm import te, auto_scheduler import tempfile -from test_auto_scheduler_common import get_tiled_matmul - +from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul def test_record(): if not tvm.runtime.enabled("llvm"): @@ -37,20 +36,28 @@ def test_record(): k = te.reduce_axis((0, 512), name='k') E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E') F = topi.nn.relu(E) + I, H, G = matmul_auto_scheduler_test(512, 512, 512) dag = auto_scheduler.ComputeDAG([A, B, F]) s = dag.get_init_state() - - # Follow split - C_global = s0.cache_write(C, "global") - split_step0 = len(s0.transform_steps) - 1 - split_step1 = len(s0.transform_steps) - 1 - s.follow_split(C_global, tmp[C_global].iters[0], split_step0, 0) - # Follow fused split - s.follow_fused_split(C_global, tmp[C_global].iters[0], - [split_step0, split_step1], 0, False) + dag_follow = auto_scheduler.ComputeDAG([G, H, I]) + s_follow = dag_follow.get_init_state() + # Split its0 = s.split(C, s[C].iters[0], [4, 8, 8]) its1 = s.split(C, s[C].iters[4], [8, 4, 4]) + # Follow split + G_global = s_follow.cache_write(G, "global") + its2 = s_follow.split(G, s_follow[G].iters[0], [4, 2, 8, 4], True) + split_step0 = len(s_follow.transform_steps) - 1 + tmp = s_follow.copy() + tmp.follow_split(G_global, tmp[G_global].iters[0], split_step0, 1) + its3 = s_follow.split(G, s_follow[G].iters[5], [2, 2, 4, 8]) + split_step1 = len(s_follow.transform_steps) - 1 + + # Follow fused split + tmp = s_follow.copy() + tmp.follow_fused_split(G_global, tmp[G_global].iters[0], + [split_step0, split_step1], 0, False) # Reorder s.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8], its1[3]]) From 7bf8dd5190890212809b78f4bb5cf96f4f24e720 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Fri, 24 Jul 2020 07:58:26 +0800 Subject: [PATCH 12/31] Add record test for follow_split Signed-off-by: jingbang.yjb --- src/auto_scheduler/loop_state.h | 5 ++-- .../unittest/test_auto_scheduler_measure.py | 25 +++++++------------ 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index da63b8a11d37..cf0f35dce899 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -353,8 +353,6 @@ class State : public ObjectRef { */ Array split(int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer = true); - /********** Step APIs working on multiple stages **********/ - /*! * \brief Schedule primitive corresponds to te.follow_split. * \param stage_id The index of the stage to be split. 
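For reference, a minimal Python-side sketch of the follow_split pattern described in the loop_state.py docstring above. `s` is assumed to be a State obtained from dag.get_init_state(); the stage names, split factors, and split level are illustrative assumptions, not code from this patch:

    dense_cache = s.cache_write(dense, "global")     # stage that will follow the split
    its = s.split(dense, s[dense].iters[0], [4, 8, 8])
    src_step = len(s.transform_steps) - 1            # index of the SplitStep just recorded
    # Re-split the cache stage's iterator following the recorded step, keeping the
    # outer-most levels identical so the two stages can later be lined up with compute_at.
    s.follow_split(dense_cache, s[dense_cache].iters[0], src_step, 2)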
@@ -377,6 +375,9 @@ class State : public ObjectRef { Array follow_fused_split(int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts); + + /********** Step APIs working on multiple stages **********/ + /*! * \brief Schedule primitive corresponds to te.compute_at. * \param stage_id The index of the stage to be compute at. diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 4071bd7f42ac..72b8804f0b4d 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -36,28 +36,16 @@ def test_record(): k = te.reduce_axis((0, 512), name='k') E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E') F = topi.nn.relu(E) - I, H, G = matmul_auto_scheduler_test(512, 512, 512) + k = te.reduce_axis((0, 512), name='k') + G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G') + H = topi.nn.relu(G) + dag = auto_scheduler.ComputeDAG([A, B, F]) s = dag.get_init_state() - dag_follow = auto_scheduler.ComputeDAG([G, H, I]) - s_follow = dag_follow.get_init_state() # Split its0 = s.split(C, s[C].iters[0], [4, 8, 8]) its1 = s.split(C, s[C].iters[4], [8, 4, 4]) - # Follow split - G_global = s_follow.cache_write(G, "global") - its2 = s_follow.split(G, s_follow[G].iters[0], [4, 2, 8, 4], True) - split_step0 = len(s_follow.transform_steps) - 1 - tmp = s_follow.copy() - tmp.follow_split(G_global, tmp[G_global].iters[0], split_step0, 1) - its3 = s_follow.split(G, s_follow[G].iters[5], [2, 2, 4, 8]) - split_step1 = len(s_follow.transform_steps) - 1 - - # Follow fused split - tmp = s_follow.copy() - tmp.follow_fused_split(G_global, tmp[G_global].iters[0], - [split_step0, split_step1], 0, False) # Reorder s.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8], its1[3]]) @@ -85,6 +73,11 @@ def test_record(): s.compute_at(D_global, E, s[E].iters[2]) # Cache Write s.cache_write(D, "shared") + #follow_split + s.split(F, s[F].iters[0], [2]) + split_step0 = len(s.transform_steps) - 1 + s.follow_split(F, s[F].iters[0], split_step0, 1) + #follow_fused_split target = tvm.target.create("llvm") task = auto_scheduler.SearchTask(dag, "test", target) From 98d943b8e3332a197f1b501725d2fd6e7a5016c8 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Fri, 24 Jul 2020 08:21:59 +0800 Subject: [PATCH 13/31] Add record test for follow_fused_split. 
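The pattern this test exercises, as a rough sketch (stage names and factors here are illustrative, not the exact ones in the test): two SplitSteps are recorded on one stage, and another stage (here an assumed cache stage C_cache) then splits its iterator using the length stored at a given level of both recorded steps.

    its0 = s.split(C, s[C].iters[0], [4, 2, 8, 4])
    step0 = len(s.transform_steps) - 1
    its1 = s.split(C, s[C].iters[5], [2, 4, 2, 4])
    step1 = len(s.transform_steps) - 1
    # Split the cache stage's iterator using level 0 of both recorded splits;
    # False selects the `nparts` variant (split from outer to inner).
    s.follow_fused_split(C_cache, s[C_cache].iters[0], [step0, step1], 0, False)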
Signed-off-by: jingbang.yjb --- tests/python/unittest/test_auto_scheduler_measure.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 72b8804f0b4d..6ec88b714c67 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -40,7 +40,8 @@ def test_record(): G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G') H = topi.nn.relu(G) - dag = auto_scheduler.ComputeDAG([A, B, F]) + #dag = auto_scheduler.ComputeDAG([A, B, F]) + dag = auto_scheduler.ComputeDAG([A, B, G]) s = dag.get_init_state() # Split @@ -74,10 +75,13 @@ def test_record(): # Cache Write s.cache_write(D, "shared") #follow_split - s.split(F, s[F].iters[0], [2]) + its2 = s.split(G, s[G].iters[0], [4, 2, 8, 4], True) split_step0 = len(s.transform_steps) - 1 - s.follow_split(F, s[F].iters[0], split_step0, 1) + s.follow_split(G, s[G].iters[0], split_step0, 1) #follow_fused_split + its1 = s.split(G, s[G].iters[5], [2, 2, 4, 8]) + split_step1 = len(s.transform_steps) - 1 + s.follow_fused_split(G, s[G].iters[0], [split_step0, split_step1], 0, False) target = tvm.target.create("llvm") task = auto_scheduler.SearchTask(dag, "test", target) From 296cb367f0b54e08cf4634cd4ceefc89a12f23b3 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Fri, 24 Jul 2020 08:39:08 +0800 Subject: [PATCH 14/31] Add test record for follow_fused_split 1. delete a comment 2. add "fuse" between follow_split and follow_fused_split Signed-off-by: jingbang.yjb --- tests/python/unittest/test_auto_scheduler_measure.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 6ec88b714c67..e3b31ed9d936 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -40,7 +40,6 @@ def test_record(): G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G') H = topi.nn.relu(G) - #dag = auto_scheduler.ComputeDAG([A, B, F]) dag = auto_scheduler.ComputeDAG([A, B, G]) s = dag.get_init_state() @@ -79,8 +78,14 @@ def test_record(): split_step0 = len(s.transform_steps) - 1 s.follow_split(G, s[G].iters[0], split_step0, 1) #follow_fused_split - its1 = s.split(G, s[G].iters[5], [2, 2, 4, 8]) + its3 = s.split(G, s[G].iters[5], [2, 2, 4, 8]) split_step1 = len(s.transform_steps) - 1 + its = [] + for i0, i1 in zip(its2, its3): + its.append(i0) + its.append(i1) + for i in range(0, 5): + s.fuse(G, [s[G].iters[i], s[G].iters[i + 1]]) s.follow_fused_split(G, s[G].iters[0], [split_step0, split_step1], 0, False) target = tvm.target.create("llvm") From a7b12946961f0bec88678346c1ce54781f7055af Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Fri, 24 Jul 2020 09:06:15 +0800 Subject: [PATCH 15/31] Add doc strings for some functions and variables Signed-off-by: jingbang.yjb --- src/auto_scheduler/transform_step.h | 73 +++++++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 5 deletions(-) diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 0ddeb9b20c3f..21a3d71b7322 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -511,11 +511,28 @@ class FollowSplitStepNode : public StepNode { void ExtractSplitLengths(const Array& transform_steps, Array>* lengths) 
const; + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to State. + */ Array ApplyToState(State* state) const; + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \param transform_steps An array record all transform steps. + */ Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const; + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \param transform_steps An array record all transform steps. + * \return Python schedule code. + */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const; @@ -531,8 +548,20 @@ class FollowSplitStepNode : public StepNode { */ class FollowSplitStep : public Step { public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be split. + * \param iter_id The index of the iterator to be split. + * \param src_step_id The index of the split step to follow in the history. + * \param n_split The number of split level. + */ FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); - + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ explicit FollowSplitStep(dmlc::JSONReader* reader); TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); @@ -543,10 +572,14 @@ class FollowSplitStep : public Step { */ class FollowFusedSplitStepNode : public StepNode { public: - int iter_id; // The id of the iter to split - Array src_step_ids; // The indices of the split steps to follow in the history - int level; // Use the length in this split level - bool factor_or_nparts; // If this is true, use factor. Otherwise, use nparts + /*! \brief The id of the iter to split. */ + int iter_id; + /*! \brief The indices of the split steps to follow in the history. */ + Array src_step_ids; + /*! \brief Use the length in this split level. */ + int level; + /*! \brief If this is true, use factor. Otherwise, use nparts. */ + bool factor_or_nparts; void WriteToRecord(dmlc::JSONWriter* writer) const final; @@ -557,11 +590,28 @@ class FollowFusedSplitStepNode : public StepNode { */ Optional ExtractSplitLength(const Array& transform_steps) const; + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to State. + */ Array ApplyToState(State* state) const; + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \param transform_steps An array record all transform steps. + */ Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const; + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages A pointer to a `te::Stage` Array. + * \param stage_to_axes A pointer to a StageToAxesMap. + * \param transform_steps An array record all transform steps. + * \return Python schedule code. + */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const; @@ -577,9 +627,22 @@ class FollowFusedSplitStepNode : public StepNode { */ class FollowFusedSplitStep : public Step { public: + /*! 
+ * \brief The constructor. + * \param stage_id The index of the stage to be split. + * \param iter_id The index of the iterator to be split. + * \param src_step_ids An array of index for split step to follow in the history. + * \param level Use the length in this split level. + * \param factor_or_nparts If this is true, use factor. Otherwise, use nparts. + */ FollowFusedSplitStep(int stage_id, int iter_id, const Array& src_step_ids, int level, bool factor_or_nparts); + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ explicit FollowFusedSplitStep(dmlc::JSONReader* reader); TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); From 5220a681402289e42af2aae338ea8092bbe62ceb Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Fri, 24 Jul 2020 09:15:49 +0800 Subject: [PATCH 16/31] Fix the code format in src/auto_scheduler/transform_step.h Signed-off-by: jingbang.yjb --- src/auto_scheduler/transform_step.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 21a3d71b7322..23c6321334a5 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -556,7 +556,7 @@ class FollowSplitStep : public Step { * \param n_split The number of split level. */ FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); - + /*! * \brief The constructor used to read a step record from JSONReader and create the * corresponding step. @@ -573,7 +573,7 @@ class FollowSplitStep : public Step { class FollowFusedSplitStepNode : public StepNode { public: /*! \brief The id of the iter to split. */ - int iter_id; + int iter_id; /*! \brief The indices of the split steps to follow in the history. */ Array src_step_ids; /*! \brief Use the length in this split level. */ From 2a113d3c408a3cbe90bc7f672d675c462c8c7289 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 24 Jul 2020 10:02:02 +0800 Subject: [PATCH 17/31] Update --- python/tvm/auto_scheduler/loop_state.py | 13 +++++++---- src/auto_scheduler/loop_state.h | 23 +++++++++++-------- src/auto_scheduler/transform_step.cc | 15 +++++++----- src/auto_scheduler/transform_step.h | 15 +++++------- .../test_auto_scheduler_loop_state.py | 5 ++-- 5 files changed, 39 insertions(+), 32 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 0311ee39c37d..d41a95851889 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -352,7 +352,10 @@ def compute_root(self, stage): self._resolve_stage_id(stage)) def cache_read(self, stage, scope_name, reader_stages): - """ Schedule primitive corresponds to te.schedule.cache_read. + """ Schedule primitive corresponds to `te.schedule.cache_read`. + + See also `te.schedule.cache_read` for more details. + Parameters ---------- @@ -360,7 +363,7 @@ def cache_read(self, stage, scope_name, reader_stages): The Stage to be cache read, which can be specified by the integer index, Operation, or output tensor of the stage. scope_name : str - The scope name to be set for the new added read stage. + The scope name of the newly added read stage. reader_stages : List[Union[int, Operation, Tensor]] The reader stages. Each of the list can be specified by the integer index, Operation, or output tensor of the stage. 
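A rough usage sketch for the API documented in this hunk (names are illustrative assumptions: `s` is a State and the `conv` stage reads from `data`):

    # Insert a shared-memory cache stage for `data`, read by the `conv` stage.
    data_shared = s.cache_read(data, "shared", [conv])   # returns the new stage's Operation
    # The returned op indexes the new stage like any other op.
    s.compute_at(data_shared, conv, s[conv].iters[2])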
@@ -387,7 +390,9 @@ def cache_read(self, stage, scope_name, reader_stages): return self.stages[int(new_stage_id)].op def cache_write(self, stage, scope_name): - """ Schedule primitive corresponds to te.schedule.cache_write. + """ Schedule primitive corresponds to `te.schedule.cache_write`. + + See also `te.schedule.cache_write` for more details. Parameters ---------- @@ -395,7 +400,7 @@ def cache_write(self, stage, scope_name): The Stage to be cache write, which can be specified by the integer index, Operation, or output tensor of the stage. scope_name : str - The scope name to be set for the new added write stage. + The scope name of the newly added compute stage. Returns ------- diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index 3cab133f3c25..6b91bc9c7516 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -199,7 +199,7 @@ class AttachMap : public ObjectRef { /*! * \brief Traverse through `stage_to_attach_iter` and `iter_to_attached_stages` map, add offset - * to stage indexes that are larger than the start_id. Used for steps that inserts net stages to + * to stage indexes that are larger than the start_id. Used for steps that inserts new stages to * ComputeDAG(e.g. CacheRead/CacheWrite step). * \param start_id The index threshold, stage indexes in AttachMap which are larger than this * will be applied the extra offset. @@ -240,9 +240,12 @@ class StateNode : public Object { AttachMap attach_map; /*! * \brief The up-to-date ComputeDAG of this state, used for some steps that may change the - * stage structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep which Will be added - * later). - * The default value is an empty NullOpt. (means no modification to the original DAG) + * stage structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep). This will alway be kept + * up-to-date, while the original ComputeDAG may not be up-to-date. + * The default value is an empty NullOpt, means no modification to the original DAG. + * Typical usage for this is when acquiring information from ComputeDAG (e.g. check for its + * AccessAnalyzer), use the `current_compute_dag` first, if it's Null, use the original + * ComputeDAG. */ Optional current_compute_dag; /*! @@ -358,7 +361,7 @@ class State : public ObjectRef { /*! * \brief Schedule primitive corresponds to te.compute_at. - * \param stage_id The index of the stage to be compute at. + * \param stage_id The index of the stage to be computed at. * \param target_stage_id The index of stage that this step will compute at to. * \param target_iter The iterator in target stage that this step will compute at to. * \note After compute_at, we need careful dependency analysis to compute the accurate bound @@ -369,12 +372,12 @@ class State : public ObjectRef { void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); /*! * \brief Schedule primitive corresponds to te.compute_inline. - * \param stage_id The index of the stage to be compute inlined. + * \param stage_id The index of the stage to be marked compute inlined. */ void compute_inline(int stage_id); /*! * \brief Schedule primitive corresponds to te.compute_root. - * \param stage_id The index of the stage to be compute root. + * \param stage_id The index of the stage to be the compute root. * \note After compute_root, we need careful dependency analysis to compute the accurate bound * information. 
However, it is relatively expensive and complicated, so we just fill "None" as * bound for the newly created iterators. @@ -387,8 +390,8 @@ class State : public ObjectRef { /*! * \brief Schedule primitive corresponds to te.schedule.cache_read. * \param stage_id The index of the stage to be cache read. - * \param scope_name The scope name to be set for the new added read stage. - * \param reader_stage_ids The indexes of reader stages. + * \param scope_name The scope name of the newly added read stage. + * \param reader_stage_ids The indices of read stages. * \param dag The original ComputeDAG of this state. * \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`. @@ -398,7 +401,7 @@ class State : public ObjectRef { /*! * \brief Schedule primitive corresponds to te.schedule.cache_write. * \param stage_id The index of the stage to be cache write. - * \param scope_name The scope name to be set for the new added write stage. + * \param scope_name The scope name of the newly added compute stage. * \param dag The original ComputeDAG of this state. * \note Cache write step will add an extra stage to the original ComputeDAG (in the front of the * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`. diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index e18154940271..e63591d1be36 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -962,15 +962,18 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array* stages, /*! * \brief Common part for steps that add new stages(e.g. CacheReadStep, CacheWriteStep, - * RfactorStep). This will filter out all steps that can change the stages of ComputeDAG. + * RfactorStep). This will filter out all steps that can change the number of stages in a + * ComputeDAG, and stop by the current step. */ -Array GetStageModifiableSteps(Step current_step, const Array& transform_steps) { +Array GetFormerStageModifiableSteps(Step current_step, const Array& transform_steps) { Array ret_steps; for (const Step& step : transform_steps) { if (step->IsInstance() || step->IsInstance()) { ret_steps.push_back(step); } // TODO(jcf94): add rfactor support + // A state may have multiple stage modifiable steps, stop by the current step to avoid + // replaying excess steps if (step.same_as(current_step)) { break; } @@ -1022,8 +1025,8 @@ void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { StateNode* pstate = state->CopyOnWrite(); - const ComputeDAG& current_compute_dag = - dag.ReplayAndGetDAG(GetStageModifiableSteps(GetRef(this), (*state)->transform_steps)); + const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG( + GetFormerStageModifiableSteps(GetRef(this), (*state)->transform_steps)); // target_stage -> target_stage + target_store // Update the op of the target stage, insert a new cache read stage behind, update the op of @@ -1131,8 +1134,8 @@ int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) const int last_dag_op_size = pstate->current_compute_dag ? 
pstate->current_compute_dag.value().as()->ops.size() : dag->ops.size(); - const ComputeDAG& current_compute_dag = - dag.ReplayAndGetDAG(GetStageModifiableSteps(GetRef(this), (*state)->transform_steps)); + const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG( + GetFormerStageModifiableSteps(GetRef(this), (*state)->transform_steps)); int added_ops = current_compute_dag->ops.size() - last_dag_op_size; // TODO(jcf94): Update this check to equal after fixing the cache write bug in TVM CHECK_GE(added_ops, 1); diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index e1746189c29e..3f2b14e5b71a 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -680,9 +680,9 @@ class ComputeRootStep : public Step { */ class CacheReadStepNode : public StepNode { public: - /*! \brief The scope name to be set for the new added read stage. (e.g. local, shared, global) */ + /*! \brief The scope name of the newly added read stage. (e.g. local, shared, global) */ String scope_name; - /*! \brief The indexes of reader stages. */ + /*! \brief The indices of read stages. */ Array reader_stage_ids; void WriteToRecord(dmlc::JSONWriter* writer) const final; @@ -730,8 +730,8 @@ class CacheReadStep : public Step { /*! * \brief The constructor. * \param stage_id The index of the stage to be cache read. - * \param scope_name The scope name to be set for the new added read stage. - * \param reader_stage_ids The indexes of reader stages. + * \param scope_name The scope name of the newly added read stage. + * \param reader_stage_ids The indices of read stages. */ CacheReadStep(int stage_id, String scope_name, const Array& reader_stage_ids); @@ -753,10 +753,7 @@ class CacheReadStep : public Step { */ class CacheWriteStepNode : public StepNode { public: - /*! - * \brief The scope name to be set for the new added write stage. (e.g. local, shared, - * global) - */ + /*! \brief The scope name of the newly added compute stage. (e.g. local, shared, global) */ String scope_name; void WriteToRecord(dmlc::JSONWriter* writer) const final; @@ -804,7 +801,7 @@ class CacheWriteStep : public Step { /*! * \brief The constructor. * \param stage_id The index of the stage to be cache write. - * \param scope_name The scope name to be set for the new added write stage. + * \param scope_name The scope name of the newly added compute stage. */ CacheWriteStep(int stage_id, String scope_name); diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index 8c9d635b526c..8282d4a40e5e 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -341,9 +341,8 @@ def test_cache_read_write(): # \ / # ----------------> kernel_split ----------------> # - # Seems there's bug with the input/output tensor. Such multi outputs case - # should be unusual, so we make some hack on DoCacheWrite - # To be fixed in the future + # TODO(jcf94): Seems there's bug with the input/output tensor. Such multi outputs case + # should be unusual, so we make some hack on DoCacheWrite. This should be fixed later. 
kernel_split_global = s0.cache_write(kernel_split, "global") """ Placeholder: Data, Kernel_data From bf660a8ce143de4dc846c741993165d3af000db5 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 24 Jul 2020 10:39:42 +0800 Subject: [PATCH 18/31] Update doc --- python/tvm/auto_scheduler/loop_state.py | 39 +++++++++++++++---------- src/auto_scheduler/loop_state.h | 24 +++++++-------- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index d41a95851889..f2e53244566a 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -127,7 +127,8 @@ def stage_ops(self): return [stage.op for stage in self.stages] def bind(self, stage, iterator, thread_name): - """ Schedule primitive corresponds to te.bind. + """ Schedule primitive corresponds to `te.Stage.bind`, see also the `te.Stage` for more + details. Parameters ---------- @@ -160,7 +161,8 @@ def bind(self, stage, iterator, thread_name): return res def parallel(self, stage, iterator): - """ Schedule primitive corresponds to te.parallel. + """ Schedule primitive corresponds to `te.Stage.parallel`, see also the `te.Stage` for more + details. Parameters ---------- @@ -180,7 +182,8 @@ def parallel(self, stage, iterator): return res def unroll(self, stage, iterator, max_unroll=None): - """ Schedule primitive corresponds to te.unroll. + """ Schedule primitive corresponds to `te.Stage.unroll`, see also the `te.Stage` for more + details. Parameters ---------- @@ -203,7 +206,8 @@ def unroll(self, stage, iterator, max_unroll=None): return res def vectorize(self, stage, iterator): - """ Schedule primitive corresponds to te.vectorize. + """ Schedule primitive corresponds to `te.Stage.vectorize`, see also the `te.Stage` for + more details. Parameters ---------- @@ -223,7 +227,8 @@ def vectorize(self, stage, iterator): return res def fuse(self, stage, iters): - """ Schedule primitive corresponds to te.fuse. + """ Schedule primitive corresponds to `te.Stage.fuse`, see also the `te.Stage` for more + details. Parameters ---------- @@ -248,7 +253,8 @@ def fuse(self, stage, iters): return res def reorder(self, stage, order): - """ Schedule primitive corresponds to te.reorder. + """ Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more + details. Parameters ---------- @@ -262,7 +268,8 @@ def reorder(self, stage, order): order) def split(self, stage, iterator, lengths, inner_to_outer=True): - """ Schedule primitive corresponds to te.split. + """ Schedule primitive corresponds to `te.Stage.split`, see also the `te.Stage` for more + details. This API supports multiple split factors. (e.g. with 2 split factors, the original iterator will be split to 3 parts, use `inner_to_outer` to control the split order) @@ -295,7 +302,8 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): return res def compute_at(self, stage, target_stage, target_iter): - """ Schedule primitive corresponds to te.compute_at. + """ Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for + more details. Parameters ---------- @@ -321,7 +329,8 @@ def compute_at(self, stage, target_stage, target_iter): target_iter) def compute_inline(self, stage): - """ Schedule primitive corresponds to te.compute_inline. + """ Schedule primitive corresponds to `te.Stage.compute_inline`, see also the `te.Stage` + for more details. 
Parameters ---------- @@ -333,7 +342,8 @@ def compute_inline(self, stage): self._resolve_stage_id(stage)) def compute_root(self, stage): - """ Schedule primitive corresponds to te.compute_root. + """ Schedule primitive corresponds to `te.Stage.compute_root`, see also the `te.Stage` for + more details. Parameters ---------- @@ -352,11 +362,11 @@ def compute_root(self, stage): self._resolve_stage_id(stage)) def cache_read(self, stage, scope_name, reader_stages): - """ Schedule primitive corresponds to `te.schedule.cache_read`. + """ Schedule primitive corresponds to `te.Schedule.cache_read`, see also the `te.Schedule` + for more details. See also `te.schedule.cache_read` for more details. - Parameters ---------- stage : Union[int, Operation, Tensor] @@ -390,9 +400,8 @@ def cache_read(self, stage, scope_name, reader_stages): return self.stages[int(new_stage_id)].op def cache_write(self, stage, scope_name): - """ Schedule primitive corresponds to `te.schedule.cache_write`. - - See also `te.schedule.cache_write` for more details. + """ Schedule primitive corresponds to `te.Schedule.cache_write`, see also the `te.Schedule` + for more details. Parameters ---------- diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index 6b91bc9c7516..bb9485331591 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -298,7 +298,7 @@ class State : public ObjectRef { /********** Step APIs working on single stage **********/ /*! - * \brief Schedule primitive corresponds to te.bind. + * \brief Schedule primitive corresponds to `te::Stage::bind`. * \param stage_id The index of the stage to be binded. * \param it The iterator to be binded. * \param thread_type The thread type to be binded. We dirctly use the IteratorAnnotation as @@ -307,14 +307,14 @@ class State : public ObjectRef { */ Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type); /*! - * \brief Schedule primitive corresponds to te.parallel. + * \brief Schedule primitive corresponds to `te::Stage::parallel`. * \param stage_id The index of the stage to be paralleled. * \param it The iterator to be paralleled. * \return The iterator result after parallel. */ Iterator parallel(int stage_id, const Iterator& it); /*! - * \brief Schedule primitive corresponds to te.unroll. + * \brief Schedule primitive corresponds to `te::Stage::unroll`. * \param stage_id The index of the stage to be unrolled. * \param it The iterator to be unrolled. * \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be @@ -323,14 +323,14 @@ class State : public ObjectRef { */ Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); /*! - * \brief Schedule primitive corresponds to te.vectorize. + * \brief Schedule primitive corresponds to `te::Stage::vectorize`. * \param stage_id The index of the stage to be vectorized. * \param it The iterator to be vectorized. * \return The iterator result after vectorize. */ Iterator vectorize(int stage_id, const Iterator& it); /*! - * \brief Schedule primitive corresponds to te.fuse. + * \brief Schedule primitive corresponds to `te::Stage::fuse`. * \param stage_id The index of the stage to be fused. * \param iters The iterators to be fused. * \return The iterator result after fuse. @@ -339,13 +339,13 @@ class State : public ObjectRef { */ Iterator fuse(int stage_id, const Array& iters); /*! - * \brief Schedule primitive corresponds to te.reorder. + * \brief Schedule primitive corresponds to `te::Stage::reorder`. 
* \param stage_id The index of the stage to be reordered. * \param order The expected iterator order. */ void reorder(int stage_id, const Array& order); /*! - * \brief Schedule primitive corresponds to te.split. + * \brief Schedule primitive corresponds to `te::Stage::split`. * \param stage_id The index of the stage to be split. * \param it The iterator to be split. * \param lengths The multiple split factors. Can be None to be filled by search policy. @@ -360,7 +360,7 @@ class State : public ObjectRef { /********** Step APIs working on multiple stages **********/ /*! - * \brief Schedule primitive corresponds to te.compute_at. + * \brief Schedule primitive corresponds to `te::Stage::compute_at`. * \param stage_id The index of the stage to be computed at. * \param target_stage_id The index of stage that this step will compute at to. * \param target_iter The iterator in target stage that this step will compute at to. @@ -371,12 +371,12 @@ class State : public ObjectRef { */ void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); /*! - * \brief Schedule primitive corresponds to te.compute_inline. + * \brief Schedule primitive corresponds to `te::Stage::compute_inline`. * \param stage_id The index of the stage to be marked compute inlined. */ void compute_inline(int stage_id); /*! - * \brief Schedule primitive corresponds to te.compute_root. + * \brief Schedule primitive corresponds to `te::Stage::compute_root`. * \param stage_id The index of the stage to be the compute root. * \note After compute_root, we need careful dependency analysis to compute the accurate bound * information. However, it is relatively expensive and complicated, so we just fill "None" as @@ -388,7 +388,7 @@ class State : public ObjectRef { /********** Step APIs adding new stages **********/ /*! - * \brief Schedule primitive corresponds to te.schedule.cache_read. + * \brief Schedule primitive corresponds to `te::Schedule::cache_read`. * \param stage_id The index of the stage to be cache read. * \param scope_name The scope name of the newly added read stage. * \param reader_stage_ids The indices of read stages. @@ -399,7 +399,7 @@ class State : public ObjectRef { int cache_read(int stage_id, const String& scope_name, const Array& reader_stage_ids, const ComputeDAG& dag); /*! - * \brief Schedule primitive corresponds to te.schedule.cache_write. + * \brief Schedule primitive corresponds to `te::Schedule::cache_write`. * \param stage_id The index of the stage to be cache write. * \param scope_name The scope name of the newly added compute stage. * \param dag The original ComputeDAG of this state. From 3649e26b8cd27ba61529962a4ea03bd921ce15f4 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 24 Jul 2020 10:43:46 +0800 Subject: [PATCH 19/31] Update --- python/tvm/auto_scheduler/loop_state.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index f2e53244566a..66090b7343c9 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -365,8 +365,6 @@ def cache_read(self, stage, scope_name, reader_stages): """ Schedule primitive corresponds to `te.Schedule.cache_read`, see also the `te.Schedule` for more details. - See also `te.schedule.cache_read` for more details. 
- Parameters ---------- stage : Union[int, Operation, Tensor] From 85da7e0769cf9ff87667e347ca475b3ed5ff8942 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 24 Jul 2020 11:20:50 +0800 Subject: [PATCH 20/31] Update --- python/tvm/auto_scheduler/compute_dag.py | 3 +- src/auto_scheduler/loop_state.h | 12 +-- src/auto_scheduler/transform_step.cc | 4 +- src/auto_scheduler/transform_step.h | 104 +++++++++++------------ 4 files changed, 60 insertions(+), 63 deletions(-) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index 7d8856a6b4e7..e08454fb1d09 100644 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -131,7 +131,8 @@ def infer_bound_from_state(self, state): """ state_obj = state if isinstance(state, StateObject) else state.state_object updated_state = State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self) - # Copy the stage_id_map from the original state + # Copy the stage_id_map from the original state to make sure the old indices are still + # valid if isinstance(state, State): for k, v in state.stage_id_map.items(): updated_state.stage_id_map[k] = v diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index bb9485331591..225fd91b9429 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -238,14 +238,10 @@ class StateNode : public Object { * operation. */ AttachMap attach_map; - /*! - * \brief The up-to-date ComputeDAG of this state, used for some steps that may change the - * stage structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep). This will alway be kept - * up-to-date, while the original ComputeDAG may not be up-to-date. - * The default value is an empty NullOpt, means no modification to the original DAG. - * Typical usage for this is when acquiring information from ComputeDAG (e.g. check for its - * AccessAnalyzer), use the `current_compute_dag` first, if it's Null, use the original - * ComputeDAG. + /*! \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt, means + * no modification to the original ComputeDAG. + * Otherwise, it means some steps (e.g., CacheReadStep/CacheWriteStep) have modified the + * ComputeDAG, the stored value is the up-to-date ComputeDAG for this state. */ Optional current_compute_dag; /*! diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index e63591d1be36..42926fbbbc75 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -962,8 +962,8 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array* stages, /*! * \brief Common part for steps that add new stages(e.g. CacheReadStep, CacheWriteStep, - * RfactorStep). This will filter out all steps that can change the number of stages in a - * ComputeDAG, and stop by the current step. + * RfactorStep). This will return all steps that can change the number of stages in a ComputeDAG, + * and stop by the current step. */ Array GetFormerStageModifiableSteps(Step current_step, const Array& transform_steps) { Array ret_steps; diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 3f2b14e5b71a..3dc1ffb88588 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -192,7 +192,7 @@ Step StepReadFromRecord(dmlc::JSONReader* reader); /*! * \brief Apply the step to State. * \param step The step to be applied to State. - * \param state A mutable pointer to State. 
+ * \param state A mutable pointer to state, which will be updated. * \param dag The original ComputeDAG of this state. * \return The iterator result after annotate. */ @@ -201,10 +201,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag); /*! * \brief Apply the step to tvm.schedule. * \param step The step to be applied to tvm.schedule. - * \param stages A mutable pointer to a `te::Stage` Array. - * \param stage_to_axes A mutable pointer to a StageToAxesMap. - * \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g. - * CacheRead/CacheWrite step) + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param schedule A mutable pointer to a `te::Schedule`. This is required by some steps which need + * `te::Schedule` API. (e.g. CacheRead/CacheWrite step) */ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, te::Schedule* schedule); @@ -212,8 +212,8 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes /*! * \brief Print the step as equivalent python schedule API. * \param step The step to be applied to python API. - * \param stages A mutable pointer to a `te::Stage` Array. - * \param stage_to_axes A mutable pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g. * CacheRead/CacheWrite step) * \return Python schedule code. @@ -238,22 +238,22 @@ class AnnotationStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. * \return The iterator result after annotate. */ Iterator ApplyToState(State* state) const; /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -298,7 +298,7 @@ class FuseStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. * \return The iterator result after fuse. * \note If the iterators to be fused have stages attached at them(by compute_at), the fused * result will become the new attach point. @@ -307,16 +307,16 @@ class FuseStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return The iterator result after fuse. 
*/ tir::IterVar ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -363,21 +363,21 @@ class ReorderStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. */ void ApplyToState(State* state) const; /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -433,7 +433,7 @@ class SplitStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. * \return The iterator results after split. * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner * most iterator of split results will become the new attach point. @@ -442,8 +442,8 @@ class SplitStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return The iterator results after split. */ Array ApplyToSchedule(Array* stages, @@ -451,8 +451,8 @@ class SplitStepNode : public StepNode { /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -504,7 +504,7 @@ class ComputeAtStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. * \note After compute_at, we need careful dependency analysis to compute the accurate bound * information. However, it is relatively expensive and complicated, so we just fill "None" as * bound for the newly created iterators. @@ -514,15 +514,15 @@ class ComputeAtStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. 
- * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -564,22 +564,22 @@ class ComputeInlineStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. */ void ApplyToState(State* state) const; /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return The iterator result after fuse. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -619,7 +619,7 @@ class ComputeRootStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. * \note After compute_at, we need careful dependency analysis to compute the accurate bound * information. However, it is relatively expensive and complicated, so we just fill "None" as * bound for the newly created iterators. @@ -629,16 +629,16 @@ class ComputeRootStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return The iterator result after fuse. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -689,7 +689,7 @@ class CacheReadStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. * \param dag The original ComputeDAG of this state. * \return The index of the new added stage. */ @@ -697,8 +697,8 @@ class CacheReadStepNode : public StepNode { /*! 
* \brief Apply the current step to tvm.schedule. - * \param stages A mutable pointer to a `te::Stage` Array. - * \param stage_to_axes A mutable pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \param schedule A mutable pointer to a te::Schedule. * \return The output Tensor of the new added stage. */ @@ -707,8 +707,8 @@ class CacheReadStepNode : public StepNode { /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A mutable pointer to a `te::Stage` Array. - * \param stage_to_axes A mutable pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \param schedule A mutable pointer to a te::Schedule. * \return Python schedule code. */ @@ -760,7 +760,7 @@ class CacheWriteStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. * \param dag The original ComputeDAG of this state. * \return The index of the new added stage. */ @@ -768,8 +768,8 @@ class CacheWriteStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages A mutable pointer to a `te::Stage` Array. - * \param stage_to_axes A mutable pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \param schedule A mutable pointer to a te::Schedule. * \return The output Tensors of the new added stage. */ @@ -778,8 +778,8 @@ class CacheWriteStepNode : public StepNode { /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A mutable pointer to a `te::Stage` Array. - * \param stage_to_axes A mutable pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \param schedule A mutable pointer to a te::Schedule. * \return Python schedule code. */ From 1a87244b3e750414058477c8aa000ae6ac29715e Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Fri, 24 Jul 2020 12:35:13 +0800 Subject: [PATCH 21/31] Fix follow_split and follow_fused_split record test. 
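For reference, a minimal sketch of how the two primitives under test are used
(matmul_auto_scheduler_test is the existing test-suite helper; the stages,
iterators and split factors below are illustrative only, not the exact ones
used in this test):

    A, B, C = matmul_auto_scheduler_test(512, 512, 512)
    dag = auto_scheduler.ComputeDAG([A, B, C])
    s = dag.get_init_state()
    C_global = s.cache_write(C, "global")

    # Record two ordinary split steps on the consumer stage C.
    s.split(C, s[C].iters[0], [4, 2, 8, 4])
    split_step0 = len(s.transform_steps) - 1
    s.split(C, s[C].iters[5], [2, 2, 4, 8])
    split_step1 = len(s.transform_steps) - 1

    # follow_split re-uses the first n_split factors of the followed split step.
    s.follow_split(C_global, s[C_global].iters[0], split_step0, 4)

    # follow_fused_split splits by the product of the level-th lengths of the
    # followed split steps (here level 0), the pattern used for cooperative fetching.
    s.follow_fused_split(C_global, s[C_global].iters[4],
                         [split_step0, split_step1], 0, False)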
Signed-off-by: jingbang.yjb --- .../unittest/test_auto_scheduler_measure.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index e3b31ed9d936..6d6fce79d58c 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -27,7 +27,7 @@ def test_record(): if not tvm.runtime.enabled("llvm"): return - + #pdb.set_trace() A = te.placeholder((512, 512), name='A') B = te.placeholder((512, 512), name='B') k = te.reduce_axis((0, 512), name='k') @@ -39,8 +39,9 @@ def test_record(): k = te.reduce_axis((0, 512), name='k') G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G') H = topi.nn.relu(G) + I = topi.nn.relu(H) - dag = auto_scheduler.ComputeDAG([A, B, G]) + dag = auto_scheduler.ComputeDAG([A, B, I]) s = dag.get_init_state() # Split @@ -76,17 +77,19 @@ def test_record(): #follow_split its2 = s.split(G, s[G].iters[0], [4, 2, 8, 4], True) split_step0 = len(s.transform_steps) - 1 - s.follow_split(G, s[G].iters[0], split_step0, 1) + s.follow_split(G, s[G].iters[5], split_step0, 4) #follow_fused_split - its3 = s.split(G, s[G].iters[5], [2, 2, 4, 8]) + its2 = s.split(H, s[H].iters[0], [4, 2, 8, 4], True) split_step1 = len(s.transform_steps) - 1 + its3 = s.split(H, s[H].iters[5], [2, 4, 2, 4], True) + split_step2 = len(s.transform_steps) - 1 its = [] for i0, i1 in zip(its2, its3): its.append(i0) its.append(i1) for i in range(0, 5): - s.fuse(G, [s[G].iters[i], s[G].iters[i + 1]]) - s.follow_fused_split(G, s[G].iters[0], [split_step0, split_step1], 0, False) + s.fuse(H, [s[H].iters[i], s[H].iters[i + 1]]) + s.follow_fused_split(I, s[I].iters[0], [split_step1, split_step2], 0, False) target = tvm.target.create("llvm") task = auto_scheduler.SearchTask(dag, "test", target) From 87e703a350c30e1db5ecb1652e425e8a49b89c6b Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 25 Jul 2020 18:12:59 +0800 Subject: [PATCH 22/31] Doc update --- python/tvm/auto_scheduler/loop_state.py | 10 +++++----- src/auto_scheduler/compute_dag.h | 5 +++-- src/auto_scheduler/loop_state.h | 8 ++++---- src/auto_scheduler/transform_step.h | 8 ++++---- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 66090b7343c9..8c3a936ccf0c 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -308,7 +308,7 @@ def compute_at(self, stage, target_stage, target_iter): Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be compute at, which can be specified by the integer index, Operation, + The Stage to be computed at, which can be specified by the integer index, Operation, or output tensor of the stage. target_stage : Union[int, Operation, Tensor] The target stage of compute_at, which can be specified by the integer index, Operation, @@ -335,8 +335,8 @@ def compute_inline(self, stage): Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be compute inlined, which can be specified by the integer index, Operation, - or output tensor of the stage. + The Stage to be marked compute inlined, which can be specified by the integer index, + Operation, or output tensor of the stage. 
""" self.state_object = _ffi_api.StateComputeInline(self.state_object, self._resolve_stage_id(stage)) @@ -348,8 +348,8 @@ def compute_root(self, stage): Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be compute root, which can be specified by the integer index, Operation, - or output tensor of the stage. + The Stage to be marked compute at root, which can be specified by the integer index, + Operation, or output tensor of the stage. Notes ----- diff --git a/src/auto_scheduler/compute_dag.h b/src/auto_scheduler/compute_dag.h index 3f4ea6f269d7..0924363d71a8 100644 --- a/src/auto_scheduler/compute_dag.h +++ b/src/auto_scheduler/compute_dag.h @@ -117,9 +117,10 @@ class ComputeDAG : public ObjectRef { /*! * \brief Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial * ComputeDAG may not be up-to-date. This function replays the given transform steps from the - * initial state and return an up-to-date ComputeDAG. + * initial state and returns an up-to-date ComputeDAG. * \param steps The steps to be replaied. Usually we'll filter out the unused steps to speed up - * the replay process, for we only need to get the new ComputeDAG structure. + * the replay process, since we only intend to get a ComputeDAG with the up-to-date op stage + * structure. * \return The up-to-date ComputeDAG. */ ComputeDAG ReplayAndGetDAG(const Array& steps) const; diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index 225fd91b9429..a3a0d1949647 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -176,7 +176,7 @@ class AttachMap : public ObjectRef { public: /*! * \brief Process the stage/iterator mapping after compute at. - * \param stage_id The index of the stage to be compute at. + * \param stage_id The index of the stage to be computed at. * \param target_stage_id The index of stage that this step will compute at to. * \param target_iter_id The index of iterator in target stage that this step will compute at to. */ @@ -184,7 +184,7 @@ class AttachMap : public ObjectRef { /*! * \brief This is a public wrapper of `DeleteStageEntry`. To delete the entry of a specific stage. - * \param stage_id The index of the stage to be compute at. + * \param stage_id The index of the stage to be computed at. */ void DeleteStage(int stage_id); @@ -199,7 +199,7 @@ class AttachMap : public ObjectRef { /*! * \brief Traverse through `stage_to_attach_iter` and `iter_to_attached_stages` map, add offset - * to stage indexes that are larger than the start_id. Used for steps that inserts new stages to + * to stage indexes that are larger than the start_id. Used for steps that insert new stages to * ComputeDAG(e.g. CacheRead/CacheWrite step). * \param start_id The index threshold, stage indexes in AttachMap which are larger than this * will be applied the extra offset. @@ -373,7 +373,7 @@ class State : public ObjectRef { void compute_inline(int stage_id); /*! * \brief Schedule primitive corresponds to `te::Stage::compute_root`. - * \param stage_id The index of the stage to be the compute root. + * \param stage_id The index of the stage to be marked compute at root. * \note After compute_root, we need careful dependency analysis to compute the accurate bound * information. However, it is relatively expensive and complicated, so we just fill "None" as * bound for the newly created iterators. 
diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 3dc1ffb88588..cf35f4052e23 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -541,7 +541,7 @@ class ComputeAtStep : public Step { public: /*! * \brief The constructor. - * \param stage_id The index of the stage to be compute at. + * \param stage_id The index of the stage to be computed at. * \param target_stage_id The index of stage that this step will compute at to. * \param target_iter_id The index of iterator in target stage that this step will compute at to. */ @@ -598,7 +598,7 @@ class ComputeInlineStep : public Step { public: /*! * \brief The constructor. - * \param stage_id The index of the stage to be compute inline. + * \param stage_id The index of the stage to be marked compute inlined. */ explicit ComputeInlineStep(int stage_id); @@ -620,7 +620,7 @@ class ComputeRootStepNode : public StepNode { /*! * \brief Apply the current step to State. * \param state A mutable pointer to state, which will be updated. - * \note After compute_at, we need careful dependency analysis to compute the accurate bound + * \note After compute_root, we need careful dependency analysis to compute the accurate bound * information. However, it is relatively expensive and complicated, so we just fill "None" as * bound for the newly created iterators. * Call ComputeDAG::InferBound on the updated state to get the complete bound information. @@ -657,7 +657,7 @@ class ComputeRootStep : public Step { public: /*! * \brief The constructor. - * \param stage_id The index of the stage to be compute root + * \param stage_id The index of the stage to be marked compute at root. */ explicit ComputeRootStep(int stage_id); From 704c2cc41a05e52bf74f76f43ea047eb77fe5d52 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Mon, 27 Jul 2020 13:44:06 +0800 Subject: [PATCH 23/31] Update some doc strings Signed-off-by: jingbang.yjb --- src/auto_scheduler/loop_state.h | 4 ++-- src/auto_scheduler/transform_step.h | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h index a6e1ef2a8bcd..45462cf9dc2e 100644 --- a/src/auto_scheduler/loop_state.h +++ b/src/auto_scheduler/loop_state.h @@ -353,7 +353,7 @@ class State : public ObjectRef { Array split(int stage_id, const Iterator& it, const Array>& lengths, bool inner_to_outer = true); /*! - * \brief Schedule primitive corresponds to te.follow_split. + * \brief Schedule primitive extends to split step. * \param stage_id The index of the stage to be split. * \param it The iterator to be split. * \param src_step_id The index of the split step to follow in the history. @@ -362,7 +362,7 @@ class State : public ObjectRef { */ Array follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split); /*! - * \brief Schedule primitive corresponds to te.follow_split. + * \brief Schedule primitive extends to split step. * \param stage_id The index of the stage to be split. * \param it The iterator to be split. * \param src_step_ids The indices of the split steps to follow in the history. 
diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h index 0e5325314847..a3951980cede 100644 --- a/src/auto_scheduler/transform_step.h +++ b/src/auto_scheduler/transform_step.h @@ -221,6 +221,7 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes String StepPrintAsPythonAPI(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, te::Schedule* schedule, const Array& transform_steps); + /********** Primitives working on single stage **********/ /*! @@ -490,11 +491,11 @@ class SplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; -/*! \brief Similar to SplitStepNode, but use split factor from another stepf +/*! \brief Similar to SplitStepNode, but uses split factors from another step * (i.e. Follow another split step) */ class FollowSplitStepNode : public StepNode { public: - /*! \brief The id of the iter to split. */ + /*! \brief The id of the iter to be split. */ int iter_id; /*! \brief The index of the split step to follow in the history. */ int src_step_id; From 40f9638ce1e0729cd493a060ab895b7b7e6b79f8 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Mon, 27 Jul 2020 15:44:11 +0800 Subject: [PATCH 24/31] Fix code style and some function definitions. Signed-off-by: jingbang.yjb --- include/tvm/auto_scheduler/loop_state.h | 14 +- include/tvm/auto_scheduler/transform_step.h | 2 +- src/auto_scheduler/loop_state.cc | 8 +- src/auto_scheduler/transform_step.cc | 2 +- .../test_auto_scheduler_loop_state.py | 301 ++---------------- .../unittest/test_auto_scheduler_measure.py | 2 +- 6 files changed, 45 insertions(+), 284 deletions(-) diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h index 5568ecf5c17d..aade1e1520c5 100644 --- a/include/tvm/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -356,8 +356,9 @@ class State : public ObjectRef { * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner * most iterator of split results will become the new attach point. */ - Array split(int stage_id, const Iterator& it, const Array>& lengths, - bool inner_to_outer = true); + TVM_DLL Array split(int stage_id, const Iterator& it, + const Array>& lengths, + bool inner_to_outer = true); /*! * \brief Schedule primitive extends to split step. * \param stage_id The index of the stage to be split. @@ -366,7 +367,8 @@ class State : public ObjectRef { * \param n_split The number of split level. * \return The splitted new Iterators. */ - Array follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split); + TVM_DLL Array follow_split(int stage_id, const Iterator& it, int src_step_id, + int n_split); /*! * \brief Schedule primitive extends to split step. * \param stage_id The index of the stage to be split. @@ -377,9 +379,9 @@ class State : public ObjectRef { False to use `nparts` for split from outer to inner. * \return The splitted new Iterators. 
*/ - Array follow_fused_split(int stage_id, const Iterator& it, - const Array& src_step_ids, int level, - bool factor_or_nparts); + TVM_DLL Array follow_fused_split(int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts); /********** Step APIs working on multiple stages **********/ diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h index 4e88ced98cf3..dea389e5b4c7 100644 --- a/include/tvm/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -648,7 +648,7 @@ class FollowFusedSplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); }; -/********** Primitives working on multiple stages **********/ +/********** Steps working on multiple stages **********/ /*! \brief Compute at step that corresponds to te::Stage::compute_at */ class ComputeAtStepNode : public StepNode { diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 9dd04955de97..92591c34fe54 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -271,7 +271,6 @@ Array State::split(int stage_id, const Iterator& it, Array State::follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split) { const Stage& stage = operator->()->stages[stage_id]; - FollowSplitStep step = FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split); CopyOnWrite()->transform_steps.push_back(step); @@ -282,7 +281,6 @@ Array State::follow_fused_split(int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts) { const Stage& stage = operator->()->stages[stage_id]; - FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), src_step_ids, level, factor_or_nparts); CopyOnWrite()->transform_steps.push_back(step); @@ -485,12 +483,8 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit") TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit") .set_body_typed([](State state, int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts) { - Array array_src_step_ids; - for (const auto& i : src_step_ids) { - array_src_step_ids.push_back(i->value); - } const auto& res = - state.follow_fused_split(stage_id, it, array_src_step_ids, level, factor_or_nparts); + state.follow_fused_split(stage_id, it, src_step_ids, level, factor_or_nparts); return Array{state, Array(res)}; }); diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 4584c076a4bb..a9575ddb0687 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -971,7 +971,7 @@ String FollowFusedSplitStepNode::PrintAsPythonAPI(Array* stages, factor_or_nparts); } -/********** Primitives working on multiple stages **********/ +/********** Steps working on multiple stages **********/ /********** Compute At **********/ ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) { diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index 28535c206a2c..06661202403e 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -455,279 +455,44 @@ def test_cache_read_write(): for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters): assert it0.range == it1.range -def test_cache_read_write(): - N, H, 
W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( - 1, 1), (1, 1) - - data = te.placeholder((N, CI, H, W), name='Data') - kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data') - k0, k1 = te.compute(kernel_data.shape, - lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2), - name='Kernel_split') - kernel = te.compute(kernel_data.shape, - lambda *i: k0(*i) + k1(*i), - name='Kernel') - conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1) - relu = topi.nn.relu(conv) - add = topi.add(data, relu) - - dag = auto_scheduler.ComputeDAG([data, kernel_data, add]) +def test_follow_split_follow_fused_split(): + A, B, C = matmul_auto_scheduler_test(512, 512, 512) + dag = auto_scheduler.ComputeDAG([A, B, C]) s0 = dag.get_init_state() - pad_temp = s0.stage_ops[1] - kernel_split = s0.stage_ops[3] - - # 0: init state - ori_its = s0[add].iters - its = s0.split(add, s0[add].iters[0], [2]) - s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]]) - s0.compute_inline(relu) - - # 1: simple cache_write with compute_at - conv_global = s0.cache_write(conv, "global") - s0.compute_at(conv_global, conv, s0[conv].iters[3]) - - # 2: simple cache_read with compute_at - kernel_global = s0.cache_read(kernel, "global", [conv_global]) - s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4]) - """ - Placeholder: Data, Kernel_data - for i0 (0,4) - for i1 (0,512) - for i2 (0,9) - for i3 (0,9) - pad_temp = ... - for i0 (0,512) - for i1 (0,512) - for i2 (0,3) - for i3 (0,3) - Kernel_split = ... - for i0 (0,512) - for i1 (0,512) - for i2 (0,3) - for i3 (0,3) - Kernel = ... - for nn (0,4) - for ff (0,512) - for yy (0,7) - for xx (0,7) - for nn_c (None) - for ff_c (None) - for yy_c (None) - for xx_c (None) - for rc (None) - for ax0 (None) - for ax1 (None) - for ax2 (None) - for ax3 (None) - Kernel.global = ... - for ry (None) - for rx (None) - compute.global = ... - compute = ... - for ax0.0 (0,2) - for ax1 (0,512) - for ax0.1 (0,2) - for ax2 (0,7) - for ax3 (0,7) - T_add = ... 
- """ - s1 = dag.infer_bound_from_state(s0) - assert s1[conv].iters[0].range.extent == 4 - assert s1[conv].iters[1].range.extent == 512 - assert s1[conv].iters[2].range.extent == 7 - assert s1[conv].iters[3].range.extent == 7 - assert s1[kernel_global].iters[0].range.extent == 1 - assert s1[kernel_global].iters[1].range.extent == 1 - assert s1[kernel_global].iters[2].range.extent == 3 - assert s1[kernel_global].iters[3].range.extent == 3 - assert s1[conv_global].iters[0].range.extent == 1 - assert s1[conv_global].iters[1].range.extent == 1 - assert s1[conv_global].iters[2].range.extent == 1 - assert s1[conv_global].iters[3].range.extent == 1 - assert s1[conv_global].iters[4].range.extent == 512 - assert s1[conv_global].iters[5].range.extent == 3 - assert s1[conv_global].iters[6].range.extent == 3 + C_global = s0.cache_write(C, "global") + its0 = s0.split(C, s0[C].iters[0], [4, 2, 8, 4], True) + split_step0 = len(s0.transform_steps) - 1 + for level in range(1, 6): + tmp = s0.copy() + tmp.follow_split(C_global, tmp[C_global].iters[0], split_step0, level) + for i in range(0, level): + assert tmp[C].iters[i].range.extent == \ + tmp[C_global].iters[i].range.extent - # 3: two level cache_read with compute_at - # preparing for GPU's shared memory & local memory - pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global]) - pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global]) - s0.compute_at(pad_temp_global, conv_global, s0[conv_global].iters[2]) - s0.compute_at(pad_temp_shared, conv_global, s0[conv_global].iters[4]) + its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8]) + split_step1 = len(s0.transform_steps) - 1 + its = [] + for i0, i1 in zip(its0, its1): + its.append(i0) + its.append(i1) + s0.reorder(C, its) + for i in range(0, 5): + s0.fuse(C, [s0[C].iters[i], s0[C].iters[i + 1]]) - # 4: cache_read with multi readers - # This stage cannot be compute at to its consumer - s0.cache_read(data, "global", [pad_temp, add]) - """ - Placeholder: Data, Kernel_data - for ax0 (0,4) - for ax1 (0,512) - for ax2 (0,7) - for ax3 (0,7) - Data.global = ... - for i0 (0,4) - for i1 (0,512) - for i2 (0,9) - for i3 (0,9) - pad_temp = ... - for i0 (0,512) - for i1 (0,512) - for i2 (0,3) - for i3 (0,3) - Kernel_split = ... - for i0 (0,512) - for i1 (0,512) - for i2 (0,3) - for i3 (0,3) - Kernel = ... - for nn (0,4) - for ff (0,512) - for yy (0,7) - for xx (0,7) - for nn_c (None) - for ff_c (None) - for yy_c (None) - for ax0 (None) - for ax1 (None) - for ax2 (None) - for ax3 (None) - pad_temp.global = ... - for xx_c (None) - for rc (None) - for ax0 (None) - for ax1 (None) - for ax2 (None) - for ax3 (None) - Kernel.global = ... - for ax0 (None) - for ax1 (None) - for ax2 (None) - for ax3 (None) - pad_temp.global.shared = ... - for ry (None) - for rx (None) - compute.global = ... - compute = ... - for ax0.0 (0,2) - for ax1 (0,512) - for ax0.1 (0,2) - for ax2 (0,7) - for ax3 (0,7) - T_add = ... 
- """ - s1 = dag.infer_bound_from_state(s0) - assert s1[conv].iters[0].range.extent == 4 - assert s1[conv].iters[1].range.extent == 512 - assert s1[conv].iters[2].range.extent == 7 - assert s1[conv].iters[3].range.extent == 7 - assert s1[kernel_global].iters[0].range.extent == 1 - assert s1[kernel_global].iters[1].range.extent == 1 - assert s1[kernel_global].iters[2].range.extent == 3 - assert s1[kernel_global].iters[3].range.extent == 3 - assert s1[conv_global].iters[0].range.extent == 1 - assert s1[conv_global].iters[1].range.extent == 1 - assert s1[conv_global].iters[2].range.extent == 1 - assert s1[conv_global].iters[3].range.extent == 1 - assert s1[conv_global].iters[4].range.extent == 512 - assert s1[conv_global].iters[5].range.extent == 3 - assert s1[conv_global].iters[6].range.extent == 3 - assert s1[pad_temp_global].iters[0].range.extent == 1 - assert s1[pad_temp_global].iters[1].range.extent == 512 - assert s1[pad_temp_global].iters[2].range.extent == 3 - assert s1[pad_temp_global].iters[3].range.extent == 3 - assert s1[pad_temp_shared].iters[0].range.extent == 1 - assert s1[pad_temp_shared].iters[1].range.extent == 1 - assert s1[pad_temp_shared].iters[2].range.extent == 3 - assert s1[pad_temp_shared].iters[3].range.extent == 3 + for level in range(0, 4): + tmp = s0.copy() + tmp.follow_fused_split(C_global, tmp[C_global].iters[0], + [split_step0, split_step1], level, False) + assert tmp[C].iters[level + 1].range.extent == \ + tmp[C_global].iters[0].range.extent - # 5: cache_write with multi outputs - # TVM's cache_write actually has a bug with this case: - # - # After schedule.cache_write, TVM generate one new stage: - # From: kernel_data -> kernel_split -> kernel - # To: kernel_data -> kernel_split_global -> kernel_split -> kernel - # - # But with topo sort analyse, we get: - # // kernel_data -> kernel_split_global -> kernel_split -> kernel - # \ / - # ----------------> kernel_split ----------------> - # - # TODO(jcf94): Seems there's bug with the input/output tensor. Such multi outputs case - # should be unusual, so we make some hack on DoCacheWrite. This should be fixed later. - kernel_split_global = s0.cache_write(kernel_split, "global") - """ - Placeholder: Data, Kernel_data - for ax0 (0,4) - for ax1 (0,512) - for ax2 (0,7) - for ax3 (0,7) - Data.global = ... - for i0 (0,4) - for i1 (0,512) - for i2 (0,9) - for i3 (0,9) - pad_temp = ... - for i0_c (0,512) - for i1_c (0,512) - for i2_c (0,3) - for i3_c (0,3) - Kernel_split.global = ... - for i0 (0,512) - for i1 (0,512) - for i2 (0,3) - for i3 (0,3) - Kernel_split = ... - (******* Bug here, there should not be two kernel_split stage *******) - for i0 (0,512) - for i1 (0,512) - for i2 (0,3) - for i3 (0,3) - Kernel_split = ... - (******* Bug here, there should not be two kernel_split stage *******) - for i0 (0,512) - for i1 (0,512) - for i2 (0,3) - for i3 (0,3) - Kernel = ... - for nn (0,4) - for ff (0,512) - for yy (0,7) - for xx (0,7) - for nn_c (None) - for ff_c (None) - for yy_c (None) - for ax0 (None) - for ax1 (None) - for ax2 (None) - for ax3 (None) - pad_temp.global = ... - for xx_c (None) - for rc (None) - for ax0 (None) - for ax1 (None) - for ax2 (None) - for ax3 (None) - Kernel.global = ... - for ax0 (None) - for ax1 (None) - for ax2 (None) - for ax3 (None) - pad_temp.global.shared = ... - for ry (None) - for rx (None) - compute.global = ... - compute = ... - for ax0.0 (0,2) - for ax1 (0,512) - for ax0.1 (0,2) - for ax2 (0,7) - for ax3 (0,7) - T_add = ... 
- """ - assert len(s0[kernel_split].iters) == len(s0[kernel_split_global].iters) - for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters): - assert it0.range == it1.range + for level in range(0, 4): + tmp = s0.copy() + tmp.follow_fused_split(C_global, tmp[C_global].iters[0], + [split_step0, split_step1], level, True) + assert tmp[C].iters[level + 1].range.extent == \ + tmp[C_global].iters[1].range.extent if __name__ == "__main__": test_split_fuse_reorder_annotation() diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 6d6fce79d58c..39d01e0c2969 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -27,7 +27,7 @@ def test_record(): if not tvm.runtime.enabled("llvm"): return - #pdb.set_trace() + A = te.placeholder((512, 512), name='A') B = te.placeholder((512, 512), name='B') k = te.reduce_axis((0, 512), name='k') From 22b1f3bbe81214b861b8371733b354aec27f1b8b Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Mon, 27 Jul 2020 15:56:11 +0800 Subject: [PATCH 25/31] Update Signed-off-by: jingbang.yjb --- src/auto_scheduler/loop_state.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 92591c34fe54..636066ac957f 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -482,7 +482,7 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit") TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit") .set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& src_step_ids, int level, bool factor_or_nparts) { + const Array& src_step_ids, int level, bool factor_or_nparts) { const auto& res = state.follow_fused_split(stage_id, it, src_step_ids, level, factor_or_nparts); return Array{state, Array(res)}; From d2e3da6991f7e1c33c4fe1d2646e3209d8ccf90a Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Mon, 27 Jul 2020 16:12:03 +0800 Subject: [PATCH 26/31] Add comments on parameters. Signed-off-by: jingbang.yjb --- include/tvm/auto_scheduler/transform_step.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h index dea389e5b4c7..5e9a06d56434 100644 --- a/include/tvm/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -202,6 +202,7 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag); * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \param schedule A mutable pointer to a `te::Schedule`. This is required by some steps which need * `te::Schedule` API. (e.g. CacheRead/CacheWrite step) + * \param transform_steps An array record all transform steps. */ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, te::Schedule* schedule, const Array& transform_steps); @@ -213,6 +214,7 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g. * CacheRead/CacheWrite step) + * \param transform_steps An array record all transform steps. * \return Python schedule code. */ String StepPrintAsPythonAPI(const Step& step, Array* stages, @@ -568,7 +570,7 @@ class FollowSplitStep : public Step { }; /*! 
\brief Similar to FollowSplitStep, but use split factors from multiple steps. - * \Note This can be used for the split in cooperative fetching + * \note This can be used for the split in cooperative fetching */ class FollowFusedSplitStepNode : public StepNode { public: From 072c868a41278a4e033a55a45b24395ee28d7bd6 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Tue, 28 Jul 2020 10:13:43 +0800 Subject: [PATCH 27/31] Add more doc strings and fix some. Signed-off-by: jingbang.yjb --- include/tvm/auto_scheduler/loop_state.h | 4 +- include/tvm/auto_scheduler/transform_step.h | 6 +-- python/tvm/auto_scheduler/loop_state.py | 32 ++++++++------- src/auto_scheduler/transform_step.cc | 13 +++++- .../test_auto_scheduler_loop_state.py | 40 ------------------- 5 files changed, 34 insertions(+), 61 deletions(-) diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h index aade1e1520c5..9850620d484d 100644 --- a/include/tvm/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -363,7 +363,7 @@ class State : public ObjectRef { * \brief Schedule primitive extends to split step. * \param stage_id The index of the stage to be split. * \param it The iterator to be split. - * \param src_step_id The index of the split step to follow in the history. + * \param src_step_id The index of the split step to be followed in the history. * \param n_split The number of split level. * \return The splitted new Iterators. */ @@ -373,7 +373,7 @@ class State : public ObjectRef { * \brief Schedule primitive extends to split step. * \param stage_id The index of the stage to be split. * \param it The iterator to be split. - * \param src_step_ids The indices of the split steps to follow in the history. + * \param src_step_ids The indices of the split steps to be followed in the history. * \param level Use the length in this split level. * \param factor_or_nparts True to use `factor` for split from inner to outer, False to use `nparts` for split from outer to inner. diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h index 5e9a06d56434..bcfef9947ab2 100644 --- a/include/tvm/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -490,8 +490,6 @@ class SplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; -/********** Steps working on multiple stages **********/ - /*! \brief Similar to SplitStepNode, but uses split factors from another step * (i.e. Follow another split step) */ class FollowSplitStepNode : public StepNode { @@ -569,8 +567,8 @@ class FollowSplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); }; -/*! \brief Similar to FollowSplitStep, but use split factors from multiple steps. - * \note This can be used for the split in cooperative fetching +/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps. + * \note This can be used for the split in cooperative fetching. */ class FollowFusedSplitStepNode : public StepNode { public: diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index f5ce6a6fdafc..9ec26c34a6e6 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -313,14 +313,15 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): def follow_split(self, stage, iterator, src_step_id, n_split): """ Schedule primitive extends to split step. 
- This step is used to follow a former SplitStep, keeps their iterator structures to be same. + This step splits the iterator by the same factors as the given SplitStep. - Example cases: - With subgraph: Dense -> Relu - Some tiling structures are used in Relu stage and we intend to compute the Dense - stage at Relu. - The follow_split is used here to keep their outer most few iterators the same for - applying compute at. + Notes + ------ + This step is useful in a scenario that we have subgraph Dense -> Relu, + and we want to compute the Dense stage at ReLU. In this case, we need them to have + the same tiling structure of common outer loops. + The follow_split step could be used here to split the Dense stage and makes sure its + splitting factors are the same as the given split step for the ReLU stage. Parameters ---------- @@ -350,20 +351,23 @@ def follow_fused_split(self, stage, iterator, src_step_ids, level, factor_or_nparts): """ Schedule primitive extends to split step. - This step is used to follow several former SplitSteps and FuseSteps. + This step is used to split an iterator by the same factors + as the given list of SplitSteps and FuseSteps. - Example cases: - With subgraph in GPU schedule: Input -> Dense + Notes + ------ + This step is useful in a scenario that we have a subgraph + in GPU schedule: Input -> Dense for i.0@j.0 = ... : Bind to blockIdx.x for i.1@j.1 = ... : Bind to threadIdx.x for i.2@j.2 = ... Input_shared = Input ... for k = ... Dense = ... - We intend to apply cooperative fetching with the Input stage, while the threadIdx.x - axis is binded to a iterator generated by split & fuse step. - The follow_fused_step is used here to figure out the final extent of the threadIdx.x - binded iterator. + We intend to apply cooperative fetching with the input stage, while the threadIdx.x + axis is bound to an iterator generated by split & fuse step. + The follow_fused_step is used split the iterator to 2 parts, while the split factor + matches the final extent of the threadIdx.x bound iterator. Parameters ---------- diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index a9575ddb0687..17ae16d45759 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -816,16 +816,25 @@ void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { void FollowSplitStepNode::ExtractSplitLengths(const Array& transform_steps, Array>* lengths) const { + // Make sure src_step_id is within the range of transform_steps. CHECK_LT(src_step_id, transform_steps.size()); auto ps = transform_steps[src_step_id].as(); CHECK(ps != nullptr); - // get lengths from src step + // Make sure the size of ps->lengths is not smaller than n_split. + // Note that the number of actual spliting factors of src_step is ps->lengths.size()+1. + CHECK_LE(n_split, ps->lengths.size() + 1); + CHECK(ps != nullptr); + lengths->reserve(n_split); int j = 0; + // Get the first (n_split-1) split factors of followed src_step. for (; j < n_split - 1; ++j) { lengths->push_back(ps->lengths[j]); } + + // Get the last split factor of src_step for spliting level if n_split is smaller than + // ps->lengths.size()+1. PrimExpr last_factor = 1; for (; j < static_cast(ps->lengths.size()); ++j) { if (ps->lengths[j]) { @@ -939,9 +948,11 @@ Optional FollowFusedSplitStepNode::ExtractSplitLength( PrimExpr ret(1); for (int src_step_id : src_step_ids) { + // Make sure the src_step_id is within the range of transform_steps. 
CHECK_LT(src_step_id, transform_steps.size()); auto ps = transform_steps[src_step_id].as(); CHECK(ps != nullptr); + // Multiple the spliting factor on corresponding spliting level of src_steps. if (ps->lengths[level] && ret.defined()) { ret *= ps->lengths[level].value(); } else { diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index 06661202403e..e35dfe3b0d53 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -85,7 +85,6 @@ def test_split_fuse_reorder_annotation(): assert res == s1[C].iters[5] assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vectorize"] - def test_compute_at_root_inline(): dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(N=1, H=224, W=224, CI=3, CO=64, kernel_size=7, strides=2, padding=3)) @@ -142,45 +141,6 @@ def test_compute_at_root_inline(): assert s0[conv].iters[5].range.extent == 7 assert s0[conv].iters[6].range.extent == 7 -def test_follow_split_follow_fused_split(): - A, B, C = matmul_auto_scheduler_test(512, 512, 512) - dag = auto_scheduler.ComputeDAG([A, B, C]) - s0 = dag.get_init_state() - - C_global = s0.cache_write(C, "global") - its0 = s0.split(C, s0[C].iters[0], [4, 2, 8, 4], True) - split_step0 = len(s0.transform_steps) - 1 - for level in range(1, 6): - tmp = s0.copy() - tmp.follow_split(C_global, tmp[C_global].iters[0], split_step0, level) - for i in range(0, level): - assert tmp[C].iters[i].range.extent == \ - tmp[C_global].iters[i].range.extent - - its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8]) - split_step1 = len(s0.transform_steps) - 1 - its = [] - for i0, i1 in zip(its0, its1): - its.append(i0) - its.append(i1) - s0.reorder(C, its) - for i in range(0, 5): - s0.fuse(C, [s0[C].iters[i], s0[C].iters[i + 1]]) - - for level in range(0, 4): - tmp = s0.copy() - tmp.follow_fused_split(C_global, tmp[C_global].iters[0], - [split_step0, split_step1], level, False) - assert tmp[C].iters[level + 1].range.extent == \ - tmp[C_global].iters[0].range.extent - - for level in range(0, 4): - tmp = s0.copy() - tmp.follow_fused_split(C_global, tmp[C_global].iters[0], - [split_step0, split_step1], level, True) - assert tmp[C].iters[level + 1].range.extent == \ - tmp[C_global].iters[1].range.extent - def test_cache_read_write(): N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( 1, 1), (1, 1) From be751231b119db18b054bde4329d2b382e0a6723 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Tue, 28 Jul 2020 10:25:14 +0800 Subject: [PATCH 28/31] Update Signed-off-by: jingbang.yjb --- src/auto_scheduler/transform_step.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 17ae16d45759..4702770eecae 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -822,7 +822,7 @@ void FollowSplitStepNode::ExtractSplitLengths(const Array& transform_steps CHECK(ps != nullptr); // Make sure the size of ps->lengths is not smaller than n_split. - // Note that the number of actual spliting factors of src_step is ps->lengths.size()+1. + // Note that the number of actual splitting factors of src_step is ps->lengths.size()+1. 
CHECK_LE(n_split, ps->lengths.size() + 1); CHECK(ps != nullptr); @@ -833,7 +833,7 @@ void FollowSplitStepNode::ExtractSplitLengths(const Array& transform_steps lengths->push_back(ps->lengths[j]); } - // Get the last split factor of src_step for spliting level if n_split is smaller than + // Get the last split factor of src_step for splitting level if n_split is smaller than // ps->lengths.size()+1. PrimExpr last_factor = 1; for (; j < static_cast(ps->lengths.size()); ++j) { @@ -952,7 +952,7 @@ Optional FollowFusedSplitStepNode::ExtractSplitLength( CHECK_LT(src_step_id, transform_steps.size()); auto ps = transform_steps[src_step_id].as(); CHECK(ps != nullptr); - // Multiple the spliting factor on corresponding spliting level of src_steps. + // Multiple the splitting factor on corresponding splitting level of src_steps. if (ps->lengths[level] && ret.defined()) { ret *= ps->lengths[level].value(); } else { From f4771f91404231ea9abd595e8a07c59fe77945f0 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Tue, 28 Jul 2020 10:39:13 +0800 Subject: [PATCH 29/31] Update Signed-off-by: jingbang.yjb --- src/auto_scheduler/transform_step.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 4702770eecae..d25bace7cd2d 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -821,7 +821,7 @@ void FollowSplitStepNode::ExtractSplitLengths(const Array& transform_steps auto ps = transform_steps[src_step_id].as(); CHECK(ps != nullptr); - // Make sure the size of ps->lengths is not smaller than n_split. + // Make sure the size of ps->lengths is not smaller than n_split-1. // Note that the number of actual splitting factors of src_step is ps->lengths.size()+1. CHECK_LE(n_split, ps->lengths.size() + 1); CHECK(ps != nullptr); From 07807570ff0351fd29a1d0469c19ae76bcc6f951 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Tue, 28 Jul 2020 14:38:04 +0800 Subject: [PATCH 30/31] Update Signed-off-by: jingbang.yjb --- include/tvm/auto_scheduler/transform_step.h | 23 ++++++++++++--------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h index bcfef9947ab2..f91505ce05e4 100644 --- a/include/tvm/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -513,23 +513,24 @@ class FollowSplitStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. */ Array ApplyToState(State* state) const; /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \param transform_steps An array record all transform steps. + * \return The iterator results after split. */ Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. 
* \param transform_steps An array record all transform steps. * \return Python schedule code. */ @@ -592,23 +593,25 @@ class FollowFusedSplitStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. + * \return The iterator results after split. */ Array ApplyToState(State* state) const; /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \param transform_steps An array record all transform steps. + * \return The iterator results after split. */ Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \param transform_steps An array record all transform steps. * \return Python schedule code. */ From 4d3a426169c24d84d0ebc6a0ddcc5e3b9646fcdb Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Tue, 28 Jul 2020 15:11:12 +0800 Subject: [PATCH 31/31] Update. Signed-off-by: jingbang.yjb --- src/auto_scheduler/transform_step.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index d25bace7cd2d..d43d0af14499 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -793,8 +793,6 @@ String SplitStepNode::PrintAsPythonAPI(Array* stages, return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } -/********** Steps working on multiple stages **********/ - /********** Follow Split **********/ FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) { auto node = make_object();
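Taken together, the series adds cache_read/cache_write, follow_split and
follow_fused_split to the Python State API. Below is a minimal end-to-end
sketch of how the cache steps combine with compute_at and bound inference,
mirroring test_cache_read_write above; the import lines and concrete shapes
are assumptions and may differ between TVM revisions.

import tvm
from tvm import te, auto_scheduler
from tvm import topi  # plain "import topi" on older revisions

# A small conv + relu workload, similar to the one in the unit test.
data = te.placeholder((4, 512, 7, 7), name='Data')
kernel = te.placeholder((512, 512, 3, 3), name='Kernel')
conv = topi.nn.conv2d_nchw(data, kernel, (1, 1), (1, 1), dilation=1)
relu = topi.nn.relu(conv)

dag = auto_scheduler.ComputeDAG([data, kernel, relu])
s0 = dag.get_init_state()

# cache_write/cache_read insert new stages into the state; the returned ops
# can be used to index the new stages.
conv_global = s0.cache_write(conv, "global")
s0.compute_at(conv_global, conv, s0[conv].iters[3])
kernel_global = s0.cache_read(kernel, "global", [conv_global])
s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4])

# compute_at leaves the attached iterators with "None" bounds; bound inference
# on the original ComputeDAG fills in the complete bound information.
s1 = dag.infer_bound_from_state(s0)
print(s1)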