From b8afb5a125452cb97eae800912d0de205fae822c Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 30 Jul 2020 09:35:29 -0700 Subject: [PATCH] [AutoScheduler] Improve doc string (#6176) --- include/tvm/auto_scheduler/auto_schedule.h | 4 +- include/tvm/auto_scheduler/compute_dag.h | 37 ++-- include/tvm/auto_scheduler/loop_state.h | 134 ++++++-------- include/tvm/auto_scheduler/transform_step.h | 195 ++++++++++---------- python/tvm/auto_scheduler/compute_dag.py | 20 +- python/tvm/auto_scheduler/loop_state.py | 83 +++++---- src/auto_scheduler/compute_dag.cc | 5 +- src/auto_scheduler/loop_state.cc | 9 - 8 files changed, 232 insertions(+), 255 deletions(-) diff --git a/include/tvm/auto_scheduler/auto_schedule.h b/include/tvm/auto_scheduler/auto_schedule.h index 8477966c0247..8d458f1864ad 100644 --- a/include/tvm/auto_scheduler/auto_schedule.h +++ b/include/tvm/auto_scheduler/auto_schedule.h @@ -100,9 +100,9 @@ class TuningOptions : public ObjectRef { /*! * \brief Run schedule search for a given compute declaration. * \param task The search task of the compute declaration. - * \param search_policy The search policy to be used. + * \param search_policy The search policy. * \param tuning_options Tuning and measurement options. - * \return A `te::schedule` and the an Array of `te::Tensor` to be used in `tvm.lower` or + * \return A `te::schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or * `tvm.build`. */ TVM_DLL std::pair> AutoSchedule(SearchTask task, diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h index 69b74bfa35de..16bc7292f889 100644 --- a/include/tvm/auto_scheduler/compute_dag.h +++ b/include/tvm/auto_scheduler/compute_dag.h @@ -22,13 +22,12 @@ * \brief The auto-scheduler's computational graph and related program analyses. * * We convert a compute declaration described by `tvm.compute` (could be a single operator or a - * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration, - * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the - * total float operation count, consumer/producer relations of each operation stage, whether an - * operation stage should be tiled/compute inlined ...). These analyses can help the search policy - * to make decisions during search process. - * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and - * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing + * subgraph) to a ComputeDAG. It keeps the input/output tensors, all operations in the DAG, and + * some static analysis results for the DAG (e.g. the total float operation count, consumer/producer + * relations of operations, whether an operation stage should be tiled/compute inlined ...). + * These analyses can help the search policy to make decisions during the search. + * ComputeDAG is also responsible for the interaction between auto-scheduler's `LoopState` and + * TVM schedule (e.g. applying the `LoopState` transform steps to a TVM schedule, providing * `LoopState` with extra information got from TVM schedule ...). */ @@ -47,18 +46,18 @@ namespace tvm { namespace auto_scheduler { -/*! \brief Static analysis result for a ComputeDAG */ +/*! \brief Static analyzer for a ComputeDAG */ class AccessAnalyzerNode : public Object { public: template using OperationMap = std::unordered_map; /*! \brief Map an operation to all operations it reads from. - * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses + * For each operation pair, use a two-dimentional array for multiple multi-dimentional accesses * The inner vector represents the indices of multi-dimensional access.*/ OperationMap>>> read_from; /*! \brief Map an operation to all operations it is read by. - * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses + * For each operation pair, use a two-dimentional array for multiple multi-dimentional accesses * The inner vector represents the indices of multi-dimensional access.*/ OperationMap>>> read_by; /*! \brief Store the number of common outer iterators for operation pairs that have @@ -92,7 +91,7 @@ class AccessAnalyzer : public ObjectRef { explicit AccessAnalyzer(const Array& tensors); /*! - * \brief Return whether this operation is an injective operation + * \brief Return whether this operation is an op with simple access * (e.g., injective, broadcast and elementwise ops without reduction) * \param op The operation */ @@ -113,13 +112,13 @@ class AccessAnalyzer : public ObjectRef { TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const; /*! - * \brief Return whether this operation is an output op + * \brief Return whether this operation is an output operation * \param op The operation */ TVM_DLL bool IsOutput(const te::Operation& op) const; /*! - * \brief Get all consumers of on operation + * \brief Get all consumers of an operation * \param state The current loop state * \param op The operation * \return The set of consumers @@ -129,7 +128,7 @@ class AccessAnalyzer : public ObjectRef { const State& state, const te::Operation& op) const; /*! - * \brief Get all producers of on operation + * \brief Get all producers of an operation * \param state The current loop state * \param op The operation * \return The set of producers @@ -139,7 +138,7 @@ class AccessAnalyzer : public ObjectRef { const State& state, const te::Operation& op) const; /*! - * \brief Get all direct producers of on operation + * \brief Get all direct producers of an operation * \param op The operation * \return The set of direct producers * \note This function DOES NOT propagate the relation for inlined ops @@ -158,7 +157,7 @@ class AccessAnalyzer : public ObjectRef { /*! * \brief Return whether two operations are elementwise-matched - * (e.g. conv2d and relu are elementwise matched) + * (e.g. conv2d and relu are elementwise-matched) * \note This function propagates the relation for chains with multiple ops. */ TVM_DLL bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const; @@ -166,7 +165,7 @@ class AccessAnalyzer : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode); }; -/*! \brief The TVM Auto-scheduler computational graph and related program analyses. */ +/*! \brief The auto-scheduler's computational graph and related program analyses. */ class ComputeDAGNode : public Object { public: /*! @@ -174,9 +173,9 @@ class ComputeDAGNode : public Object { * This is used as the input of `tvm.lower` or `tvm.build`. */ Array tensors; - /*! \brief All related operations in topo order. */ + /*! \brief All used operations in topo order. */ Array ops; - /*! \brief The number of total float operations for this ComputeDAG. */ + /*! \brief The number of float operations in this ComputeDAG. */ double flop_ct; /*! \brief The initial state without any transform steps. */ State init_state; diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h index 34e7e56c405c..ba58f373555b 100644 --- a/include/tvm/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -19,7 +19,7 @@ /*! * \file auto_scheduler/loop_state.h - * \brief The definition of the "state" in search. + * \brief The definition of the "state" in the search. * * Each LoopState corresponds to a schedule for its ComputeDAG. * A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to @@ -30,7 +30,7 @@ * During the schedule search process, the loop structure can provide search policy with necessary * information on how to manipulate the current state. * The transform history is a sequence of `TransformStep` which will finally be mapped to TVM - * schedule primitives. The steps can also be used for the serialization of a state. + * schedule primitives. The steps are also used for the serialization of a state. * * The LoopState can be seen as a lightweight loop structure IR specifically for schedule search. * We don't use the existing TVM IR but to extend a new structure on it is because: @@ -40,7 +40,7 @@ * 3. We may create some macro schedule primitives that represent the combination of several * TVM schedule primitives. * - * When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives. + * When the search is finished, we will lower the state to TVM IR with TVM's schedule primitives. * Since we share a lot of common objects during search, the transformation is implemented in * copy on write style. All objects are immutable, which is similar to TVM IR. */ @@ -131,7 +131,7 @@ class Stage : public ObjectRef { explicit Stage(te::Operation op); /*! * \brief The constructor. - * \param op A `te::Operation`. + * \param op The source operation * \param op_type The stage type of this op. * \param iters The iterators of this op. * \param compute_at The compute at type of this op. @@ -167,7 +167,7 @@ class AttachMapNode : public Object { /*! \brief A Map to store the mapping of stage to its attached iterator. */ std::unordered_map stage_to_attach_iter; - /*! \brief A Map to store the mapping of iterator to the stage attached to it. */ + /*! \brief A Map to store the mapping of iterator to the stages attached to it. */ std::unordered_map, IterKeyHash> iter_to_attached_stages; static constexpr const char* _type_key = "auto_scheduler.AttachMap"; @@ -182,15 +182,15 @@ class AttachMap : public ObjectRef { public: /*! * \brief Process the stage/iterator mapping after compute at. - * \param stage_id The index of the stage to be computed at. + * \param stage_id The index of the source stage of computed at. * \param target_stage_id The index of stage that this step will compute at to. - * \param target_iter_id The index of iterator in target stage that this step will compute at to. + * \param target_iter_id The index of target iterator in the target stage. */ void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id); /*! - * \brief This is a public wrapper of `DeleteStageEntry`. To delete the entry of a specific stage. - * \param stage_id The index of the stage to be computed at. + * \brief Delete the entry of a specific stage. This is a public wrapper of `DeleteStageEntry`. + * \param stage_id The index of the stage to be deleted. */ void DeleteStage(int stage_id); @@ -198,7 +198,7 @@ class AttachMap : public ObjectRef { * \brief Find the relations of original iterators in AttachMap, and update them with the new * iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated. * \param original_iters The original IterKey. - * \param new_iters The new IterKey to update. + * \param new_iters The new IterKey for replacing the old ones. */ void UpdateIters(const std::vector& original_iters, const std::vector& new_iters); @@ -206,9 +206,9 @@ class AttachMap : public ObjectRef { /*! * \brief Traverse through `stage_to_attach_iter` and `iter_to_attached_stages` map, add offset * to stage indexes that are larger than the start_id. Used for steps that insert new stages to - * ComputeDAG(e.g. CacheRead/CacheWrite step). - * \param start_id The index threshold, stage indexes in AttachMap which are larger than this - * will be applied the extra offset. + * ComputeDAG (e.g., CacheRead/CacheWrite step). + * \param start_id The index threshold. This function only adds offset for stages + * with indices larger then this threshold. * \param offset The index offset to be added to the stage index. * \return The updated AttachMap after applying stage index offset. */ @@ -219,7 +219,7 @@ class AttachMap : public ObjectRef { private: /*! - * \brief To delete the entry of a specific stage. This will remove the items related to this + * \brief Delete the entry of a specific stage. This will remove the items related to this * stage in both `stage_to_attach_iter` and `iter_to_attached_stages` map. * \param pnode A mutable pointer to AttachMapNode. * \param stage_id The index of stage that will be removed from the map. @@ -244,10 +244,10 @@ class StateNode : public Object { * operation. */ AttachMap attach_map; - /*! \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt, means - * no modification to the original ComputeDAG. - * Otherwise, it means some steps (e.g., CacheReadStep/CacheWriteStep) have modified the - * ComputeDAG, the stored value is the up-to-date ComputeDAG for this state. + /*! \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt, + * meaning the dag of this state is the same as the original ComputeDAG in the SearchTask. + * Otherwise, the stored value is the up-to-date ComputeDAG for this state, meaning some steps + * (e.g., CacheReadStep/CacheWriteStep) have modified the ComputeDAG. */ Optional current_compute_dag; /*! @@ -279,60 +279,47 @@ class State : public ObjectRef { explicit State(const Array& ops); /*! - * \brief Print the state to a human readable string. + * \brief Pretty-print the state to a human readable string. * \param delete_trivial_loop True for skipping the trivial loops. * (undefined or extent == 1, default set to True) - * \return The human readable state structure. + * \return The human readable string. */ String ToStr(bool delete_trivial_loop = true) const; + /********** Step APIs working on a single stage **********/ /*! - * \brief General call step functions with a runtime dynamic dispatcher. This will re-apply all - * the transform steps from the initial state. - * \param dag The original ComputeDAG of this state. - * \note The input `dag` is different from the class member `current_compute_dag`. - * This function takes the initial ComputeDAG as input to replay all the history. While the - * `current_compute_dag` is used to track the current stage status, for some transform step may - * change the op stage structure. - */ - void ApplySteps(const ComputeDAG& dag); - - /********** Step APIs working on single stage **********/ - - /*! - * \brief Schedule primitive corresponds to `te::Stage::bind`. + * \brief The schedule primitive corresponding to `te::Stage::bind`. * \param stage_id The index of the stage to be binded. * \param it The iterator to be binded. - * \param thread_type The thread type to be binded. We dirctly use the IteratorAnnotation as - * this input. - * \return The iterator result after binded. + * \param thread_type The thread type. + * \return The new iterator after binding. */ TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type); /*! - * \brief Schedule primitive corresponds to `te::Stage::parallel`. + * \brief The schedule primitive corresponding to `te::Stage::parallel`. * \param stage_id The index of the stage to be paralleled. * \param it The iterator to be paralleled. - * \return The iterator result after parallel. + * \return The new iterator after parallel. */ TVM_DLL Iterator parallel(int stage_id, const Iterator& it); /*! - * \brief Schedule primitive corresponds to `te::Stage::unroll`. + * \brief The schedule primitive corresponding to `te::Stage::unroll`. * \param stage_id The index of the stage to be unrolled. * \param it The iterator to be unrolled. * \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be * skipped. - * \return The iterator result after unrolled. + * \return The new iterator after unroll. */ TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); /*! - * \brief Schedule primitive corresponds to `te::Stage::vectorize`. + * \brief The schedule primitive corresponding to `te::Stage::vectorize`. * \param stage_id The index of the stage to be vectorized. * \param it The iterator to be vectorized. - * \return The iterator result after vectorize. + * \return The new iterator after vectorization. */ TVM_DLL Iterator vectorize(int stage_id, const Iterator& it); /*! - * \brief Schedule primitive corresponds to `te::Stage::fuse`. + * \brief The schedule primitive corresponding to `te::Stage::fuse`. * \param stage_id The index of the stage to be fused. * \param iters The iterators to be fused. * \return The iterator result after fuse. @@ -341,25 +328,25 @@ class State : public ObjectRef { */ TVM_DLL Iterator fuse(int stage_id, const Array& iters); /*! - * \brief Schedule primitive corresponds to `te.Stage.pragma`. + * \brief The schedule primitive corresponding to `te.Stage.pragma`. * \param stage_id The index of the stage to add pragma. * \param it The iterator to add pragma. * \param pragma_type The pragma string. */ TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type); /*! - * \brief Schedule primitive corresponds to `te::Stage::reorder`. + * \brief The schedule primitive corresponding to `te::Stage::reorder`. * \param stage_id The index of the stage to be reordered. * \param order The expected iterator order. */ TVM_DLL void reorder(int stage_id, const Array& order); /*! - * \brief Schedule primitive corresponds to `te::Stage::split`. + * \brief The schedule primitive corresponding to `te::Stage::split`. * \param stage_id The index of the stage to be split. * \param it The iterator to be split. * \param lengths The multiple split factors. Can be None to be filled by search policy. - * \param inner_to_outer Whether the factor go from inner to outer, or from outer to inner. - * \return The iterator results after split. + * \param inner_to_outer Whether the factors go from inner to outer, or from outer to inner. + * \return The new iterator after splitting. * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner * most iterator of split results will become the new attach point. */ @@ -367,30 +354,31 @@ class State : public ObjectRef { const Array>& lengths, bool inner_to_outer = true); /*! - * \brief Schedule primitive extends to split step. + * \brief The schedule primitive similar to split, but uses split factors from previous steps. * \param stage_id The index of the stage to be split. * \param it The iterator to be split. * \param src_step_id The index of the split step to be followed in the history. * \param n_split The number of split level. - * \return The splitted new Iterators. + * \return The split new Iterators. */ TVM_DLL Array follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split); /*! - * \brief Schedule primitive extends to split step. + * \brief The schedule primitive similar to split, but uses split factors from + * fused previous steps. * \param stage_id The index of the stage to be split. * \param it The iterator to be split. * \param src_step_ids The indices of the split steps to be followed in the history. * \param level Use the length in this split level. * \param factor_or_nparts True to use `factor` for split from inner to outer, False to use `nparts` for split from outer to inner. - * \return The splitted new Iterators. + * \return The split new Iterators. */ TVM_DLL Array follow_fused_split(int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts); /*! - * \brief Schedule primitive corresponds to `te.Stage.storage_align`. + * \brief The schedule primitive corresponding to `te.Stage.storage_align`. * \param stage_id The index of the stage to be aligned. * \param it The iterator to be aligned. * \param factor The factor in alignment specification. @@ -399,64 +387,62 @@ class State : public ObjectRef { TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset); /********** Step APIs working on multiple stages **********/ - /*! - * \brief Schedule primitive corresponds to `te::Stage::compute_at`. - * \param stage_id The index of the stage to be computed at. + * \brief The schedule primitive corresponding to `te::Stage::compute_at`. + * \param stage_id The index of the source stage of computed at. * \param target_stage_id The index of stage that this step will compute at to. - * \param target_iter The iterator in target stage that this step will compute at to. + * \param target_iter The indiex of the target iterator in the target stage. * \note After compute_at, we need careful dependency analysis to compute the accurate bound * information. However, it is relatively expensive and complicated, so we just fill "None" as * bound for the newly created iterators. - * Call ComputeDAG::InferBound on the updated state to get the complete bound information. + * Call ComputeDAG::InferBound on the updated state if you need the complete bound information. */ TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); /*! - * \brief Schedule primitive corresponds to `te::Stage::compute_inline`. + * \brief The schedule primitive corresponding to `te::Stage::compute_inline`. * \param stage_id The index of the stage to be marked compute inlined. */ TVM_DLL void compute_inline(int stage_id); /*! - * \brief Schedule primitive corresponds to `te::Stage::compute_root`. + * \brief The schedule primitive corresponding to `te::Stage::compute_root`. * \param stage_id The index of the stage to be marked compute at root. * \note After compute_root, we need careful dependency analysis to compute the accurate bound * information. However, it is relatively expensive and complicated, so we just fill "None" as * bound for the newly created iterators. - * Call ComputeDAG::InferBound on the updated state to get the complete bound information. + * Call ComputeDAG::InferBound on the updated state if you need the complete bound information. */ TVM_DLL void compute_root(int stage_id); /********** Step APIs adding new stages **********/ - /*! - * \brief Schedule primitive corresponds to `te::Schedule::cache_read`. - * \param stage_id The index of the stage to be cache read. - * \param scope_name The scope name of the newly added read stage. - * \param reader_stage_ids The indices of read stages. + * \brief The schedule primitive corresponding to `te::Schedule::cache_read`. + * \param stage_id The index of the stage to be cache_read. + * \param scope_name The scope name of the newly added stage. + * \param reader_stage_ids The indices of reader stages. * \param dag The original ComputeDAG of this state. * \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the - * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`. + * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`. */ TVM_DLL int cache_read(int stage_id, const String& scope_name, const Array& reader_stage_ids, const ComputeDAG& dag); /*! - * \brief Schedule primitive corresponds to `te::Schedule::cache_write`. - * \param stage_id The index of the stage to be cache write. - * \param scope_name The scope name of the newly added compute stage. + * \brief The schedule primitive corresponding to `te::Schedule::cache_write`. + * \param stage_id The index of the stage to be cache_write. + * \param scope_name The scope name of the newly added stage. * \param dag The original ComputeDAG of this state. * \note Cache write step will add an extra stage to the original ComputeDAG (in the front of the - * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`. + * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`. * This step will cache write all output tensors of the target stage. */ TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag); /*! - * \brief Schedule primitive corresponds to `te::Schedule::rfactor`. + * \brief The schedule primitive corresponding to `te::Schedule::rfactor`. * \param stage_id The index of the iterator to be factored. * \param it The iterator to be factored. * \param factor_iter_id The position where the new iterator is placed. * \param dag The original ComputeDAG of this state. * \note Rfactor step will add an extra stage to the original ComputeDAG (in the front of the - * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`. + * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`. */ TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag); diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h index a31765a4d44f..d4ef0329d451 100644 --- a/include/tvm/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -19,7 +19,7 @@ /*! * \file auto_scheduler/transform_step.h - * \brief Transformation steps. These steps are used to manipulate the LoopState. + * \brief Transformation steps. These steps are used to manipulate `LoopState`. * They are similar to the schedule primitives in te::Stage. * * \note How to add a new transform step: @@ -31,16 +31,15 @@ * 3. Implement `FuseStepNode::ApplyToState` and the state API `State::fuse`. * - In these two functions you need to incrementally update all data structures in State with * CopyOnWrite style. - * 4. Add your step implementation to `StepApplyToState`, `StepApplyToSchedule` and - * `StepPrintAsPythonAPI`, make sure it works. + * 4. Add your step to `StepApplyToState`, `StepApplyToSchedule`, and `StepPrintAsPythonAPI`. * 5. Log record serialization support: * - Add `FuseStepNode::WriteToRecord` which takes a mutable JSONWriter pointer as input and * output the record to it. * - Add another construction function that takes a mutable JSONReader as input, this will get a * step record from the reader and create the step. * - Add the step implementation to `StepReadFromRecord`. - * 6. Add its corresponding Python API to `loop_state.py` and necessary unit test, the test should - * at lease consists of two parts: the functional test and the record serialization test. + * 6. Add its corresponding Python API to `loop_state.py` with necessary unit tests. The test should + * at lease cover two parts: the functional test and the record serialization test. */ #ifndef TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_ @@ -58,8 +57,8 @@ typedef Map, ObjectHash, ObjectEqual> StageT /*! * \brief Update the current stage IterVar information to StageToAxesMap. - * \param stage A te::Stage Object. - * \param stage_to_axes A mutable pointer to StageToAxesMap, this map will be updated. + * \param stage The stage to be updated. + * \param stage_to_axes The map to be updated. */ void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes); @@ -106,7 +105,7 @@ enum class IteratorAnnotation : int { extern const char* IteratorAnnotationString[]; /*! - * \brief A for loop iterator + * \brief An iterator of a for-loop * Similar to tvm::IterVar in `include/tvm/tir/expr.h` */ class IteratorNode : public Object { @@ -188,7 +187,7 @@ class ComputeDAG; Step StepReadFromRecord(dmlc::JSONReader* reader); /*! - * \brief Apply the step to State. + * \brief Apply a general step to a State with runtime dynamic dispatching. * \param step The step to be applied to State. * \param state A mutable pointer to state, which will be updated. * \param dag The original ComputeDAG of this state. @@ -196,25 +195,23 @@ Step StepReadFromRecord(dmlc::JSONReader* reader); void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag); /*! - * \brief Apply the step to tvm.schedule. + * \brief Apply a general step to tvm.schedule with runtime dynamic dispatching. * \param step The step to be applied to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. - * \param schedule A mutable pointer to a `te::Schedule`. This is required by some steps which need - * `te::Schedule` API. (e.g. CacheRead/CacheWrite step) - * \param transform_steps An array record all transform steps. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. + * \param schedule A mutable point to the current schedule + * \param transform_steps An array of all history transform steps. */ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, te::Schedule* schedule, const Array& transform_steps); /*! - * \brief Print the step as equivalent python schedule API. - * \param step The step to be applied to python API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. - * \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g. - * CacheRead/CacheWrite step) - * \param transform_steps An array record all transform steps. + * \brief Print a general step as equivalent python schedule API with runtime dynamic dispatching. + * \param step The step to be printed as python API. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. + * \param schedule A mutable point to the current schedule + * \param transform_steps An array of all history transform steps. * \return Python schedule code. */ String StepPrintAsPythonAPI(const Step& step, Array* stages, @@ -245,15 +242,15 @@ class AnnotationStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -307,16 +304,16 @@ class FuseStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \return The iterator result after fuse. */ tir::IterVar ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -368,15 +365,15 @@ class PragmaStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -430,15 +427,15 @@ class ReorderStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -503,8 +500,8 @@ class SplitStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \return The iterator results after split. */ Array ApplyToSchedule(Array* stages, @@ -512,8 +509,8 @@ class SplitStepNode : public StepNode { /*! * \brief Print the current step as equivalent python schedule API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -557,7 +554,7 @@ class FollowSplitStepNode : public StepNode { public: /*! \brief The id of the iter to be split. */ int iter_id; - /*! \brief The index of the split step to follow in the history. */ + /*! \brief The index of the split step to be followed in the history. */ int src_step_id; /*! \brief The number of split level. */ int n_split; @@ -566,7 +563,7 @@ class FollowSplitStepNode : public StepNode { /*! * \brief Extract split lengths. - * \param transform_steps An array record all transform steps. + * \param transform_steps An array of history transform steps. * \return The multiple split factors. */ Array> ExtractSplitLengths(const Array& transform_steps) const; @@ -580,9 +577,9 @@ class FollowSplitStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. - * \param transform_steps An array record all transform steps. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. + * \param transform_steps An array of history transform steps. * \return The iterator results after split. */ Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, @@ -590,9 +587,9 @@ class FollowSplitStepNode : public StepNode { /*! * \brief Print the current step as equivalent python schedule API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. - * \param transform_steps An array record all transform steps. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. + * \param transform_steps An array of history transform steps. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, @@ -614,7 +611,7 @@ class FollowSplitStep : public Step { * \brief The constructor. * \param stage_id The index of the stage to be split. * \param iter_id The index of the iterator to be split. - * \param src_step_id The index of the split step to follow in the history. + * \param src_step_id The index of the split step to be followed in the history. * \param n_split The number of split level. */ FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); @@ -636,7 +633,7 @@ class FollowFusedSplitStepNode : public StepNode { public: /*! \brief The id of the iter to split. */ int iter_id; - /*! \brief The indices of the split steps to follow in the history. */ + /*! \brief The indices of the split steps to be followed in the history. */ Array src_step_ids; /*! \brief Use the length in this split level. */ int level; @@ -647,7 +644,7 @@ class FollowFusedSplitStepNode : public StepNode { /*! * \brief Extract split length. - * \param transform_steps An array record all transform steps. + * \param transform_steps An array of history transform steps. * \return Split factor. */ Optional ExtractSplitLength(const Array& transform_steps) const; @@ -661,9 +658,9 @@ class FollowFusedSplitStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. - * \param transform_steps An array record all transform steps. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. + * \param transform_steps An array of history transform steps. * \return The iterator results after split. */ Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, @@ -671,9 +668,9 @@ class FollowFusedSplitStepNode : public StepNode { /*! * \brief Print the current step as equivalent python schedule API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. - * \param transform_steps An array record all transform steps. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. + * \param transform_steps An array of history transform steps. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, @@ -695,7 +692,7 @@ class FollowFusedSplitStep : public Step { * \brief The constructor. * \param stage_id The index of the stage to be split. * \param iter_id The index of the iterator to be split. - * \param src_step_ids An array of index for split step to follow in the history. + * \param src_step_ids An array of index for split step to be followed in the history. * \param level Use the length in this split level. * \param factor_or_nparts If this is true, use factor. Otherwise, use nparts. */ @@ -732,15 +729,15 @@ class StorageAlignStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -794,21 +791,21 @@ class ComputeAtStepNode : public StepNode { * \note After compute_at, we need careful dependency analysis to compute the accurate bound * information. However, it is relatively expensive and complicated, so we just fill "None" as * bound for the newly created iterators. - * Call ComputeDAG::InferBound on the updated state to get the complete bound information. + * Call ComputeDAG::InferBound on the updated state if you need the complete bound information. */ void ApplyToState(State* state) const; /*! * \brief Apply the current step to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -827,7 +824,7 @@ class ComputeAtStep : public Step { public: /*! * \brief The constructor. - * \param stage_id The index of the stage to be computed at. + * \param stage_id The index of the source stage. * \param target_stage_id The index of stage that this step will compute at to. * \param target_iter_id The index of iterator in target stage that this step will compute at to. */ @@ -856,16 +853,16 @@ class ComputeInlineStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \return The iterator result after fuse. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -909,22 +906,22 @@ class ComputeRootStepNode : public StepNode { * \note After compute_root, we need careful dependency analysis to compute the accurate bound * information. However, it is relatively expensive and complicated, so we just fill "None" as * bound for the newly created iterators. - * Call ComputeDAG::InferBound on the updated state to get the complete bound information. + * Call ComputeDAG::InferBound on the updated state if you need the complete bound information. */ void ApplyToState(State* state) const; /*! * \brief Apply the current step to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \return The iterator result after fuse. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -961,12 +958,12 @@ class ComputeRootStep : public Step { /*! * \brief Cache read step that corresponds to te::Schedule::cache_read. - * \note Cache read step will add an extra stage to the original ComputeDAG, a up-to-date ComputeDAG - * is stored in State's `current_compute_dag`. + * \note Cache read step adds an extra stage to the original ComputeDAG, + * an up-to-date ComputeDAG will be stored in State's `current_compute_dag`. */ class CacheReadStepNode : public StepNode { public: - /*! \brief The scope name of the newly added read stage. (e.g. local, shared, global) */ + /*! \brief The scope name of the newly added read stage. (e.g., local, shared, global) */ String scope_name; /*! \brief The indices of read stages. */ Array reader_stage_ids; @@ -983,8 +980,8 @@ class CacheReadStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \param schedule A mutable pointer to a te::Schedule. * \return The output Tensor of the new added stage. */ @@ -993,8 +990,8 @@ class CacheReadStepNode : public StepNode { /*! * \brief Print the current step as equivalent python schedule API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \param schedule A mutable pointer to a te::Schedule. * \return Python schedule code. */ @@ -1015,9 +1012,9 @@ class CacheReadStep : public Step { public: /*! * \brief The constructor. - * \param stage_id The index of the stage to be cache read. - * \param scope_name The scope name of the newly added read stage. - * \param reader_stage_ids The indices of read stages. + * \param stage_id The index of the stage to be cache_read. + * \param scope_name The scope name of the newly added stage. + * \param reader_stage_ids The indices of reader stages. */ CacheReadStep(int stage_id, String scope_name, const Array& reader_stage_ids); @@ -1054,8 +1051,8 @@ class CacheWriteStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \param schedule A mutable pointer to a te::Schedule. * \return The output Tensors of the new added stage. */ @@ -1064,8 +1061,8 @@ class CacheWriteStepNode : public StepNode { /*! * \brief Print the current step as equivalent python schedule API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \param schedule A mutable pointer to a te::Schedule. * \return Python schedule code. */ @@ -1086,8 +1083,8 @@ class CacheWriteStep : public Step { public: /*! * \brief The constructor. - * \param stage_id The index of the stage to be cache write. - * \param scope_name The scope name of the newly added compute stage. + * \param stage_id The index of the stage to be cache_write. + * \param scope_name The scope name of the newly added stage. */ CacheWriteStep(int stage_id, String scope_name); @@ -1121,8 +1118,8 @@ class RfactorStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \param schedule A mutable pointer to a te::Schedule. * \return The output Tensors of the new added stage. */ @@ -1131,8 +1128,8 @@ class RfactorStepNode : public StepNode { /*! * \brief Print the current step as equivalent python schedule API. - * \param stages The `te::Stage`s used in TVM scheduler applying. - * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param stages The list of current stages + * \param stage_to_axes A map that maps stage ot all its iterators. * \param schedule A mutable pointer to a te::Schedule. * \return Python schedule code. */ diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index e08454fb1d09..f56f430bebde 100644 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" The TVM Auto-scheduler computational graph and related program analyses. """ +""" The auto-scheduler's computational graph and related program analyses. """ import hashlib @@ -33,16 +33,16 @@ @tvm._ffi.register_object("auto_scheduler.ComputeDAG") class ComputeDAG(Object): """ - The TVM Auto-scheduler computational graph and related program analyses. + The auto-scheduler's computational graph and related program analyses. We convert a compute declaration described by `tvm.compute` (could be a single operator or a - subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration, - a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the - total float operation count, consumer/producer relations of each operation stage, whether an - operation stage should be tiled/compute inlined ...). These analyses can help the search policy - to make decisions during search process. - ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and - TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing + subgraph) to a ComputeDAG. It keeps the input/output tensors, all operations in the DAG, and + some static analysis results for the DAG (e.g. the total float operation count, + consumer/producer relations of operations, whether an operation stage should + be tiled/compute inlined ...). + These analyses can help the search policy to make decisions during the search. + ComputeDAG is also responsible for the interaction between auto-scheduler's `LoopState` and + TVM schedule (e.g. applying the `LoopState` transform steps to a TVM schedule, providing `LoopState` with extra information got from TVM schedule ...). Parameters @@ -90,7 +90,7 @@ def apply_steps_from_state(self, state): def print_python_code_from_state(self, state): """ - Print transform steps in the history of a State as TVM's python schedule primitive. + Print transform steps in the history of a State as TVM's python schedule code. This is used to print transformation steps for debugging. Use `apply_steps_from_state` if you want to get a schedule for code generation. diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 35ecacc9249e..da3a4bf5ff34 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -17,17 +17,18 @@ # pylint: disable=unused-import """ -The definition of the "state" in search. +The definition of the "state" in the search. Each LoopState corresponds to a schedule for its ComputeDAG. A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to construct the loop structure. The loop structure keeps a preview of how the schedule will finally look like after lowering the -current state (e.g. number of iterators, the extent of each iterator, the compute_at locations ...). +current state (e.g. number of iterators, the extent of each iterator, the compute_at locations +...). During the schedule search process, the loop structure can provide search policy with necessary information on how to manipulate the current state. -The transform history is a sequence of `TransformStep` which will finally be mapped to TVM schedule -primitives. The steps can also be used for the serialization of a state. +The transform history is a sequence of `TransformStep` which will finally be mapped to TVM +schedule primitives. The steps are also used for the serialization of a state. The LoopState can be seen as a lightweight loop structure IR specifically for schedule search. We don't use the existing TVM IR but to extend a new structure on it is because: @@ -37,7 +38,7 @@ 3. We may create some macro schedule primitives that represent the combination of several TVM schedule primitives. -When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives. +When the search is finished, we will lower the state to TVM IR with TVM's schedule primitives. Since we share a lot of common objects during search, the transformation is implemented in copy on write style. All objects are immutable, which is similar to TVM IR. """ @@ -136,8 +137,8 @@ def stage_ops(self): return [stage.op for stage in self.stages] def bind(self, stage, iterator, thread_name): - """ Schedule primitive corresponds to `te.Stage.bind`, see also the `te.Stage` for more - details. + """Schedule primitive corresponding to `te.Stage.bind`. + See also the `te.Stage` for more details. Parameters ---------- @@ -170,8 +171,8 @@ def bind(self, stage, iterator, thread_name): return res def parallel(self, stage, iterator): - """ Schedule primitive corresponds to `te.Stage.parallel`, see also the `te.Stage` for more - details. + """Schedule primitive corresponding to `te.Stage.parallel`. + See also the `te.Stage` for more details. Parameters ---------- @@ -191,8 +192,8 @@ def parallel(self, stage, iterator): return res def unroll(self, stage, iterator, max_unroll=None): - """ Schedule primitive corresponds to `te.Stage.unroll`, see also the `te.Stage` for more - details. + """Schedule primitive corresponding to `te.Stage.unroll`. + See also the `te.Stage` for more details. Parameters ---------- @@ -215,8 +216,8 @@ def unroll(self, stage, iterator, max_unroll=None): return res def vectorize(self, stage, iterator): - """ Schedule primitive corresponds to `te.Stage.vectorize`, see also the `te.Stage` for - more details. + """Schedule primitive corresponding to `te.Stage.vectorize`. + See also the `te.Stage` for more details. Parameters ---------- @@ -236,8 +237,8 @@ def vectorize(self, stage, iterator): return res def fuse(self, stage, iters): - """ Schedule primitive corresponds to `te.Stage.fuse`, see also the `te.Stage` for more - details. + """Schedule primitive corresponding to `te.Stage.fuse`. + See also the `te.Stage` for more details. Parameters ---------- @@ -262,8 +263,8 @@ def fuse(self, stage, iters): return res def pragma(self, stage, iterator, pragma_type): - """ Schedule primitive corresponds to `te.Stage.pragma`, see also the `te.Stage` for more - details. + """Schedule primitive corresponding to `te.Stage.pragma`. + See also the `te.Stage` for more details. Parameters ---------- @@ -279,8 +280,8 @@ def pragma(self, stage, iterator, pragma_type): iterator, pragma_type) def reorder(self, stage, order): - """ Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more - details. + """Schedule primitive corresponding to `te.Stage.reorder`. + See also the `te.Stage` for more details. Parameters ---------- @@ -294,8 +295,8 @@ def reorder(self, stage, order): order) def split(self, stage, iterator, lengths, inner_to_outer=True): - """ Schedule primitive corresponds to `te.Stage.split`, see also the `te.Stage` for more - details. + """Schedule primitive corresponding to `te.Stage.split`. + See also the `te.Stage` for more details. This API supports multiple split factors. (e.g. with 2 split factors, the original iterator will be split to 3 parts, use `inner_to_outer` to control the split order) @@ -328,7 +329,7 @@ def split(self, stage, iterator, lengths, inner_to_outer=True): return res def follow_split(self, stage, iterator, src_step_id, n_split): - """ Schedule primitive extends to split step. + """The schedule primitive similar to split, but uses split factors from previous steps. This step splits the iterator by the same factors as the given SplitStep. @@ -348,7 +349,7 @@ def follow_split(self, stage, iterator, src_step_id, n_split): iterator : Iterator The iterator to split. src_step_id : int - The index of the split step to follow in the history. + The index of the split step to be followed in the history. n_split : int The number of split level. @@ -394,7 +395,7 @@ def follow_fused_split(self, stage, iterator, src_step_ids, level, iterator : Iterator The iterator to split. src_step_ids : List[int] - The indices of the split steps to follow in the history. + The indices of the split steps to be followed in the history. level : int Use the length in this split level. factor_or_nparts : bool @@ -415,8 +416,8 @@ def follow_fused_split(self, stage, iterator, src_step_ids, level, return res def storage_align(self, stage, iterator, factor, offset): - """ Schedule primitive corresponds to `te.Stage.storage_align`, see also the `te.Stage` for - more details. + """Schedule primitive corresponding to `te.Stage.storage_align`. + See also the `te.Stage` for more details. Parameters ---------- @@ -435,14 +436,14 @@ def storage_align(self, stage, iterator, factor, offset): factor, offset) def compute_at(self, stage, target_stage, target_iter): - """ Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for - more details. + """Schedule primitive corresponding to `te.Stage.compute_at`. + See also the `te.Stage` for more details. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be computed at, which can be specified by the integer index, Operation, - or output tensor of the stage. + The source Stage of computed at, which can be specified by the integer index, + Operation, or output tensor of the stage. target_stage : Union[int, Operation, Tensor] The target stage of compute_at, which can be specified by the integer index, Operation, or output tensor of the stage. @@ -462,7 +463,7 @@ def compute_at(self, stage, target_stage, target_iter): target_iter) def compute_inline(self, stage): - """ Schedule primitive corresponds to `te.Stage.compute_inline`, see also the `te.Stage` + """Schedule primitive corresponding to `te.Stage.compute_inline`, see also the `te.Stage` for more details. Parameters @@ -475,8 +476,8 @@ def compute_inline(self, stage): self._resolve_stage_id(stage)) def compute_root(self, stage): - """ Schedule primitive corresponds to `te.Stage.compute_root`, see also the `te.Stage` for - more details. + """Schedule primitive corresponding to `te.Stage.compute_root`. + Ssee also the `te.Stage` for more details. Parameters ---------- @@ -495,13 +496,13 @@ def compute_root(self, stage): self._resolve_stage_id(stage)) def cache_read(self, stage, scope_name, reader_stages): - """ Schedule primitive corresponds to `te.Schedule.cache_read`, see also the `te.Schedule` - for more details. + """Schedule primitive corresponding to `te.Schedule.cache_read`. + See also the `te.Schedule` for more details. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be cache read, which can be specified by the integer index, Operation, + The Stage to be cache_read, which can be specified by the integer index, Operation, or output tensor of the stage. scope_name : str The scope name of the newly added read stage. @@ -531,13 +532,13 @@ def cache_read(self, stage, scope_name, reader_stages): return self.stages[int(new_stage_id)].op def cache_write(self, stage, scope_name): - """ Schedule primitive corresponds to `te.Schedule.cache_write`, see also the `te.Schedule` - for more details. + """Schedule primitive corresponding to `te.Schedule.cache_write`. + See also the `te.Schedule` for more details. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be cache write, which can be specified by the integer index, Operation, + The Stage to be cache_write, which can be specified by the integer index, Operation, or output tensor of the stage. scope_name : str The scope name of the newly added compute stage. @@ -563,8 +564,8 @@ def cache_write(self, stage, scope_name): return self.stages[int(new_stage_id)].op def rfactor(self, stage, iterator, factor_iter_id): - """ Schedule primitive corresponds to `te.Schedule.rfactor`, see also the `te.Schedule` for - more details. + """Schedule primitive corresponding to `te.Schedule.rfactor`. + See also the `te.Schedule` for more details. Parameters ---------- diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index f2815fbdec6b..b11dd7347504 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -740,7 +741,9 @@ State ComputeDAG::InferBound(const State& state) const { ret_state = operator->()->init_state; pstate = ret_state.CopyOnWrite(); pstate->transform_steps = state->transform_steps; - ret_state.ApplySteps(*this); + for (const auto& step : pstate->transform_steps) { + StepApplyToState(step, &ret_state, *this); + } } else { ret_state = state; pstate = ret_state.CopyOnWrite(); diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index f9d1f823fc3e..9e1a54fff15a 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -341,15 +341,6 @@ int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const C return step->ApplyToState(this, dag); } -void State::ApplySteps(const ComputeDAG& dag) { - CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages."; - - // Call each step's ApplyToState method - for (const auto& step : operator->()->transform_steps) { - StepApplyToState(step, this, dag); - } -} - // Print stage to ostream void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent, bool delete_trivial_loop) {