From 7cf3fe1d8388f41811f0ad3c84fe09b1daf0abd0 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Wed, 12 Aug 2020 04:53:00 +0800 Subject: [PATCH] [Ansor][AutoTVM v2.0] Phase 2: Basic CPU Sketch Search Policy (#6184) * Init commit to pass the compile * First commit to pass the test * Update * Add UTs for sketch generation * Update * Add ASF to new UT file. * Update rule for winograd * Update * File renamed * Lint fix --- include/tvm/auto_scheduler/auto_schedule.h | 16 +- include/tvm/auto_scheduler/compute_dag.h | 18 +- include/tvm/auto_scheduler/cost_model.h | 12 +- include/tvm/auto_scheduler/search_policy.h | 37 +- include/tvm/auto_scheduler/transform_step.h | 2 +- python/tvm/auto_scheduler/__init__.py | 2 +- python/tvm/auto_scheduler/auto_schedule.py | 126 +++- src/auto_scheduler/auto_schedule.cc | 24 +- src/auto_scheduler/compute_dag.cc | 94 ++- src/auto_scheduler/cost_model.cc | 14 +- src/auto_scheduler/loop_state.cc | 8 +- .../search_policy/empty_policy.cc | 35 +- .../search_policy/empty_policy.h | 7 +- .../search_policy/search_policy.cc | 14 +- .../search_policy/sketch_policy.cc | 401 ++++++++++++ .../search_policy/sketch_policy.h | 176 ++++++ .../search_policy/sketch_policy_rules.cc | 584 ++++++++++++++++++ .../search_policy/sketch_policy_rules.h | 207 +++++++ src/auto_scheduler/search_policy/utils.cc | 286 +++++++++ src/auto_scheduler/search_policy/utils.h | 484 +++++++++++++++ src/auto_scheduler/utils.h | 47 +- tests/cpp/auto_scheduler_test.cc | 4 +- .../unittest/test_auto_scheduler_common.py | 112 +++- .../test_auto_scheduler_loop_state.py | 7 +- .../test_auto_scheduler_search_policy.py | 31 +- .../test_auto_scheduler_sketch_generation.py | 102 +++ 26 files changed, 2713 insertions(+), 137 deletions(-) create mode 100644 src/auto_scheduler/search_policy/sketch_policy.cc create mode 100644 src/auto_scheduler/search_policy/sketch_policy.h create mode 100644 src/auto_scheduler/search_policy/sketch_policy_rules.cc create mode 100644 src/auto_scheduler/search_policy/sketch_policy_rules.h create mode 100644 src/auto_scheduler/search_policy/utils.cc create mode 100644 src/auto_scheduler/search_policy/utils.h create mode 100644 tests/python/unittest/test_auto_scheduler_sketch_generation.py diff --git a/include/tvm/auto_scheduler/auto_schedule.h b/include/tvm/auto_scheduler/auto_schedule.h index 8d458f1864ad1..2d7e5949aea46 100644 --- a/include/tvm/auto_scheduler/auto_schedule.h +++ b/include/tvm/auto_scheduler/auto_schedule.h @@ -42,10 +42,7 @@ class TuningOptionsNode : public Object { int early_stopping; /*! \brief The number of programs to be measured at each search round. */ int num_measures_per_round; - /*! - * \brief Verbosity level. - * 0 for silent, 1 to output information during schedule searching. - */ + /*! \brief Verbosity level. 0 for silent, 1 to output information during schedule searching. */ int verbose; /*! \brief ProgramBuilder which builds the program */ ProgramBuilder builder; @@ -53,8 +50,6 @@ class TuningOptionsNode : public Object { ProgramRunner runner; /*! \brief MeasureCallback functions to be called after each measure batch */ Optional> measure_callbacks; - /*! 
   * \brief SearchCallback functions to be called before schedule search */
-  Optional<Array<SearchCallback>> pre_search_callbacks;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("num_measure_trials", &num_measure_trials);
@@ -64,7 +59,6 @@ class TuningOptionsNode : public Object {
     v->Visit("builder", &builder);
     v->Visit("runner", &runner);
     v->Visit("measure_callbacks", &measure_callbacks);
-    v->Visit("pre_search_callbacks", &pre_search_callbacks);
   }
 
   static constexpr const char* _type_key = "auto_scheduler.TuningOptions";
@@ -87,26 +81,22 @@ class TuningOptions : public ObjectRef {
    * \param builder ProgramBuilder which builds the program.
    * \param runner ProgramRunner which runs the program and measures time costs.
    * \param measure_callbacks MeasureCallback functions to be called after each measure batch.
-   * \param pre_search_callbacks SearchCallback functions to be called before schedule search.
    */
   TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round,
                 int verbose, ProgramBuilder builder, ProgramRunner runner,
-                Optional<Array<MeasureCallback>> measure_callbacks,
-                Optional<Array<SearchCallback>> pre_search_callbacks);
+                Optional<Array<MeasureCallback>> measure_callbacks);
 
   TVM_DEFINE_OBJECT_REF_METHODS(TuningOptions, ObjectRef, TuningOptionsNode);
 };
 
 /*!
  * \brief Run schedule search for a given compute declaration.
- * \param task The search task of the compute declaration.
  * \param search_policy The search policy.
  * \param tuning_options Tuning and measurement options.
  * \return A `te::Schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or
  * `tvm.build`.
  */
-TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchTask task,
-                                                                SearchPolicy search_policy,
+TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchPolicy search_policy,
                                                                 TuningOptions tuning_options);
 }  // namespace auto_scheduler
 }  // namespace tvm
diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h
index 16bc7292f8899..34f1c9d8737da 100644
--- a/include/tvm/auto_scheduler/compute_dag.h
+++ b/include/tvm/auto_scheduler/compute_dag.h
@@ -66,10 +66,10 @@ class AccessAnalyzerNode : public Object {
   /*! \brief Store whether the operation is an op with only simple access.
    *  (e.g., injective, broadcast and elementwise ops without reduction) */
   OperationMap<bool> is_simple_access;
-  /*! \brief Store whether the operation is strictly-inlineable
+  /*! \brief Store whether the operation is strictly inlineable
    * (e.g., injective, broadcast and elementwise without reduction, branch or expensive operations)
    */
-  OperationMap<bool> is_strict_inlineable;
+  OperationMap<bool> is_strictly_inlineable;
   /*! \brief Store whether the operation needs multi-level tiling
    * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) */
   OperationMap<bool> needs_multi_level_tiling;
@@ -102,7 +102,7 @@ class AccessAnalyzer : public ObjectRef {
    * (e.g., injective, broadcast and elementwise without reduction, branch or expensive operations)
    * \param op The operation
    */
-  TVM_DLL bool IsStrictInlineable(const te::Operation& op) const;
+  TVM_DLL bool IsStrictlyInlineable(const te::Operation& op) const;
 
   /*!
    * \brief Return whether this operation needs multi-level tiling
@@ -187,6 +187,7 @@ class ComputeDAGNode : public Object {
     v->Visit("ops", &ops);
     v->Visit("flop_ct", &flop_ct);
     v->Visit("init_state", &init_state);
+    v->Visit("access_analyzer", &access_analyzer);
   }
 
   static constexpr const char* _type_key = "auto_scheduler.ComputeDAG";
@@ -237,6 +238,17 @@ class ComputeDAG : public ObjectRef {
    */
   State InferBound(const State& state) const;
 
+  /*!
+   * \brief Fill the correct bound information for the given states by calling ir_pass::InferBound.
+   * The states can lose complete bound information after some transform steps (e.g., compute_at).
+   * We can call this function to infer and fill all the bound information.
+   * This function calls the TVM InferBound pass internally to get the bounds.
+   * The returned states are guaranteed to have complete bound information.
+   * \param states The input states.
+   * \return The states with complete bound information.
+   */
+  Array<State> InferBound(const Array<State>& states) const;
+
   /*!
    * \brief Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial
    * ComputeDAG may not be up-to-date. This function replays the given transform steps from the
diff --git a/include/tvm/auto_scheduler/cost_model.h b/include/tvm/auto_scheduler/cost_model.h
index a6da93fed6af2..89dcab29265d4 100644
--- a/include/tvm/auto_scheduler/cost_model.h
+++ b/include/tvm/auto_scheduler/cost_model.h
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file tvm/auto_scheduler/cost_model.h
+ * \file auto_scheduler/cost_model.h
  * \brief Cost models that estimate the performance of programs
  */
@@ -54,7 +54,7 @@ class CostModelNode : public Object {
    * \param states The input states
    * \param scores The predicted scores for all states
    */
-  virtual void Predict(const SearchTask& task, const std::vector<State>& states,
+  virtual void Predict(const SearchTask& task, const Array<State>& states,
                        std::vector<float>* scores) = 0;
 
   /*!
@@ -64,7 +64,7 @@
    * \param state_scores The predicted scores for all states
    * \param stage_scores The predicted scores for all stages in all states
    */
-  virtual void PredictStages(const SearchTask& task, const std::vector<State>& states,
+  virtual void PredictStages(const SearchTask& task, const Array<State>& states,
                              std::vector<float>* state_scores,
                              std::vector<std::vector<float>>* stage_scores) {
     LOG(FATAL) << "Not implemented";
@@ -91,7 +91,7 @@ class RandomModelNode : public CostModelNode {
 
   void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;
 
-  void Predict(const SearchTask& task, const std::vector<State>& states,
+  void Predict(const SearchTask& task, const Array<State>& states,
                std::vector<float>* scores) final;
 
   static constexpr const char* _type_key = "auto_scheduler.RandomModel";
@@ -126,10 +126,10 @@ class PythonBasedModelNode : public CostModelNode {
 
   void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;
 
-  void Predict(const SearchTask& task, const std::vector<State>& states,
+  void Predict(const SearchTask& task, const Array<State>& states,
                std::vector<float>* scores) final;
 
-  void PredictStages(const SearchTask& task, const std::vector<State>& states,
+  void PredictStages(const SearchTask& task, const Array<State>& states,
                      std::vector<float>* state_scores,
                      std::vector<std::vector<float>>* stage_scores) final;
diff --git a/include/tvm/auto_scheduler/search_policy.h b/include/tvm/auto_scheduler/search_policy.h
index 457aca1e8f2ec..33a58aa16e6bf 100644
--- a/include/tvm/auto_scheduler/search_policy.h
+++ b/include/tvm/auto_scheduler/search_policy.h
@@ -31,6 +31,17 @@
  *
  * Candidate schedules are measured against the specific hardware target.
  *
+ * We intend to introduce different levels of automation in the schedule generation process:
+ * - Level 0 (the default level): For all kinds of ops/subgraphs, the search policy should be able
+ *   to generate schedules automatically.
+ * - Level 1: For some complicated ops/subgraphs (e.g. conv2d winograd), the default search space
+ *   of level 0 may be too large to find a high-performance schedule efficiently.
+ *   We provide some op attributes to help reduce the total search space; see `SearchPolicyKey`
+ *   below for more information.
+ * - Level 2: For some further special ops/subgraphs, users may be more likely to write their own
+ *   templates (just like AutoTVM). The search policy should be able to provide a flexible
+ *   approach for this as well.
+ *
  * \note How to add a new search policy.
  * In design, there's no need for users to implement their own search policy; our formal search
  * policy (will be brought later) should be enough to cover most use cases. Meanwhile, a custom rule
@@ -89,13 +100,24 @@ class SearchCallback : public ObjectRef {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode);
 };
 
+/*! \brief Attribute keys of ops used for SearchPolicy. */
+struct SearchPolicyKey {
+  /*! \brief Always apply unroll to the innermost iterator of the specified iterators. */
+  static constexpr const char* always_unroll_inner = "auto_scheduler_always_unroll_inner";
+  /*! \brief The specified iterators will be placed in the innermost tile without split. */
+  static constexpr const char* no_split_at_inner = "auto_scheduler_no_split_at_inner";
+  /*! \brief The specified iterators are indices of const tensors in "fake reduction". */
+  static constexpr const char* simplify_const_tensor_indices =
+      "auto_scheduler_simplify_const_tensor_indices";
+};
+
 /*!
  * \brief The base class of search policies.
  */
 class SearchPolicyNode : public Object {
  public:
   /*! \brief The current search task. */
-  SearchTask cur_task;
+  SearchTask search_task;
   /*!
    * \brief Verbose level to control the screen output during schedule search.
    * 0 for silent, 1 to output state & measure information during search process.
    */
   int verbose;
 
   void VisitAttrs(AttrVisitor* v) {
-    v->Visit("cur_task", &cur_task);
+    v->Visit("search_task", &search_task);
     v->Visit("verbose", &verbose);
   }
 
   /*!
    * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state
    * found during the search.
-   * \param task The SearchTask for the computation declaration
    * \param num_measure_trials The number of total measurement trials.
    * \param early_stopping Stops the tuning early if no improvement after n measurements.
    * \param num_measures_per_round The number of programs to be measured at each search round.
-   * \param verbose Verbose level. 0 for silent, 1 to output information during schedule
-   * search.
    * \param measurer A ProgramMeasurer to build and measure programs
-   * \param pre_search_callbacks SearchCallback to be called before schedule search.
    * \return The best state found.
    */
-  virtual State Search(SearchTask task, int num_measure_trials, int early_stopping,
-                       int num_measures_per_round, int verbose, ProgramMeasurer measurer,
-                       Optional<Array<SearchCallback>> pre_search_callbacks) = 0;
+  virtual State Search(int num_measure_trials, int early_stopping, int num_measures_per_round,
+                       ProgramMeasurer measurer) = 0;
 
   /*!
    * \brief Call SearchCallback with the current SearchPolicyNode
   * \param callbacks SearchCallback to be called.
*/ - void RunCallbacks(const Optional>& callbacks); + void RunCallbacks(const Array& callbacks); static constexpr const char* _type_key = "auto_scheduler.SearchPolicy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h index d4ef0329d4516..5c8850486ad74 100644 --- a/include/tvm/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -478,7 +478,7 @@ class SplitStepNode : public StepNode { /*! \brief The id of the iter to split. */ int iter_id; /*! \brief The extent length of the axis to split. */ - Optional extent; + Optional extent; /*! \brief The split factors. */ Array> lengths; /*! diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py index 62ebce0299d8b..32ac4f5a3e3a3 100644 --- a/python/tvm/auto_scheduler/__init__.py +++ b/python/tvm/auto_scheduler/__init__.py @@ -26,7 +26,7 @@ # Shortcut from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \ - auto_schedule, EmptyPolicy + auto_schedule, EmptyPolicy, SketchPolicy from .compute_dag import ComputeDAG from .cost_model import RandomModel from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, \ diff --git a/python/tvm/auto_scheduler/auto_schedule.py b/python/tvm/auto_scheduler/auto_schedule.py index 52aa62baf56f1..d3b18fe020fb6 100644 --- a/python/tvm/auto_scheduler/auto_schedule.py +++ b/python/tvm/auto_scheduler/auto_schedule.py @@ -28,9 +28,12 @@ Candidate schedules are measured against the specific hardware target. """ +import random + import tvm._ffi from tvm.runtime import Object from .measure import LocalBuilder, LocalRunner +from .cost_model import RandomModel from . import _ffi_api @@ -88,10 +91,94 @@ class SearchPolicy(Object): class EmptyPolicy(SearchPolicy): """ This is an example empty search policy which will always generate the init state of ComputeDAG. + + Parameters + ---------- + task : SearchTask + The SearchTask for the computation declaration. + init_search_callbacks : Optional[List[SearchCallback]] + Callback functions called before the search process. + """ + def __init__(self, task, init_search_callbacks=None): + self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy, task, init_search_callbacks) + + +@tvm._ffi.register_object("auto_scheduler.SketchPolicy") +class SketchPolicy(SearchPolicy): + """ The search policy that searches in a hierarchical search space defined by sketches. + The policy randomly samples programs from the space defined by sketches + and use evolutionary search to fine-tune them. + + Parameters + ---------- + task : SearchTask + The SearchTask for the computation declaration. + schedule_cost_model : CostModel = RandomModel() + The cost model to estimate the complete schedules. + params : Optional[Dict[str, Any]] + Parameters of the search policy. + See `src/auto_scheduler/search_policy/sketch_search_policy.h` for the definitions. + See `DEFAULT_PARAMS` below to find the default values. + seed : Optional[int] + Random seed. + verbose : int = 1 + Verbosity level. 0 for silent, 1 to output information during schedule search. + init_search_callbacks : Optional[List[SearchCallback]] + Callback functions called before the search process, usually used to do extra + initializations. + Possible callbacks: + - auto_scheduler.PreloadMeasuredStates + - auto_scheduler.PreloadCustomSketchRule + TODO(jcf94): Add these search callback implementations. 
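+
+    Examples
+    --------
+    A minimal usage sketch (assumes ``task`` is an already constructed SearchTask; the
+    parameter values below are illustrative only):
+
+    .. code-block:: python
+
+        policy = auto_scheduler.SketchPolicy(task, params={"eps_greedy": 0.05}, seed=0)
+        sketches = policy.generate_sketches(print_for_debug=True)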
""" - def __init__(self): - self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy) + DEFAULT_PARAMS = { + "eps_greedy": 0.05, + + 'evolutionary_search_population': 2048, + "evolutionary_search_use_measured_ratio": 0.2, + + 'cpu_multi_level_tiling_structure': 'SSRSRS', + 'gpu_multi_level_tiling_structure': 'SSSRRSRS', + + 'max_innermost_split_factor': 16, + 'max_vectorize_size': 16, + + 'disable_change_compute_location': 0, + } + + def __init__(self, task, schedule_cost_model=RandomModel(), params=None, seed=None, verbose=1, + init_search_callbacks=None): + if params is None: + params = SketchPolicy.DEFAULT_PARAMS + else: + for key, value in SketchPolicy.DEFAULT_PARAMS.items(): + if key not in params: + params[key] = value + + self.__init_handle_by_constructor__( + _ffi_api.SketchPolicy, task, schedule_cost_model, params, + seed or random.randint(1, 1 << 30), verbose, init_search_callbacks) + + def generate_sketches(self, print_for_debug=False): + """ Generate the sketches, this is mainly used for debug. + + Parameters + ---------- + print_for_debug : bool = False + Whether print out the sketches for debug. + + Returns + ------- + sketches : List[State] + The generated sketches of this search task. + """ + sketches = _ffi_api.SketchPolicyGenerateSketches(self) + if print_for_debug: + for i, s in enumerate(sketches): + print("=" * 20 + " %d " % i + "=" * 20) + print(s) + return sketches @tvm._ffi.register_object("auto_scheduler.TuningOptions") class TuningOptions(Object): @@ -121,16 +208,9 @@ class TuningOptions(Object): Callback functions called after each measurement. Candidates: - auto_scheduler.RecordToFile - pre_search_callbacks: Optional[List[SearchCallback]] - Callback functions called before the search process. - Candidates: - - auto_scheduler.PreloadMeasuredStates - - auto_scheduler.PreloadCustomSketchRule - TODO(jcf94): Add these implementation in later PRs. """ def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_round=64, - verbose=1, builder='local', runner='local', measure_callbacks=None, - pre_search_callbacks=None): + verbose=1, builder='local', runner='local', measure_callbacks=None): if isinstance(builder, str): if builder == 'local': builder = LocalBuilder() @@ -150,20 +230,20 @@ def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_r " . TuningOptions expects a ProgramRunner or string.") self.__init_handle_by_constructor__( - _ffi_api.TuningOptions, num_measure_trials, early_stopping if early_stopping else -1, - num_measures_per_round, verbose, builder, runner, measure_callbacks, - pre_search_callbacks) + _ffi_api.TuningOptions, num_measure_trials, early_stopping or -1, + num_measures_per_round, verbose, builder, runner, measure_callbacks) -def auto_schedule(task, search_policy='default', tuning_options=None): +def auto_schedule(task, search_policy=None, tuning_options=TuningOptions()): """ Do auto scheduling for a computation declaration. Parameters ---------- task : SearchTask The SearchTask for the computation declaration. - search_policy : Union[SearchPolicy, str] = 'default' - The search policy to be used for schedule search. + search_policy : Optional[SearchPolicy] + The search policy to be used for schedule search. Use EmptyPolicy as default, which always + returns an empty schedule. tuning_options : Optional[TuningOptions] Tuning and measurement options. @@ -175,17 +255,5 @@ def auto_schedule(task, search_policy='default', tuning_options=None): raise ValueError("Invalid task: " + task + " . 
`auto_scheduler.auto_schedule` expects a SearchTask.") - if isinstance(search_policy, str): - if search_policy == 'default': - # TODO(jcf94): This is an example policy for minimum system, will be upgrated to - # formal search policy later. - search_policy = EmptyPolicy() - else: - raise ValueError("Invalid search policy: " + search_policy) - elif not isinstance(search_policy, SearchPolicy): - raise ValueError("Invalid search policy: " + search_policy + - " . `auto_scheduler.auto_schedule` expects a SearchPolicy or a string.") - - sch, tensors = _ffi_api.AutoSchedule(task, search_policy, - tuning_options if tuning_options else TuningOptions()) + sch, tensors = _ffi_api.AutoSchedule(search_policy or EmptyPolicy(task), tuning_options) return sch, tensors diff --git a/src/auto_scheduler/auto_schedule.cc b/src/auto_scheduler/auto_schedule.cc index c537ca702b9da..48679597c3e18 100644 --- a/src/auto_scheduler/auto_schedule.cc +++ b/src/auto_scheduler/auto_schedule.cc @@ -34,8 +34,7 @@ TVM_REGISTER_NODE_TYPE(TuningOptionsNode); TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, ProgramBuilder builder, ProgramRunner runner, - Optional> measure_callbacks, - Optional> pre_search_callbacks) { + Optional> measure_callbacks) { auto node = make_object(); node->num_measure_trials = num_measure_trials; node->early_stopping = early_stopping; @@ -44,38 +43,35 @@ TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num node->builder = std::move(builder); node->runner = std::move(runner); node->measure_callbacks = std::move(measure_callbacks); - node->pre_search_callbacks = std::move(pre_search_callbacks); data_ = std::move(node); } -std::pair> AutoSchedule(SearchTask task, SearchPolicy search_policy, +std::pair> AutoSchedule(SearchPolicy search_policy, TuningOptions tuning_options) { // Create a ProgramMeasurer to handle the schedule build and performance measure ProgramMeasurer measurer = ProgramMeasurer(tuning_options->builder, tuning_options->runner, tuning_options->measure_callbacks, tuning_options->verbose); // Search for the best schedule - State state = search_policy->Search( - task, tuning_options->num_measure_trials, tuning_options->early_stopping, - tuning_options->num_measures_per_round, tuning_options->verbose, measurer, - tuning_options->pre_search_callbacks); - return task->compute_dag.ApplySteps(state->transform_steps); + State state = + search_policy->Search(tuning_options->num_measure_trials, tuning_options->early_stopping, + tuning_options->num_measures_per_round, measurer); + return search_policy->search_task->compute_dag.ApplySteps(state->transform_steps); } TVM_REGISTER_GLOBAL("auto_scheduler.TuningOptions") .set_body_typed([](int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, ProgramBuilder builder, ProgramRunner runner, - Optional> measure_callbacks, - Optional> pre_search_callbacks) { + Optional> measure_callbacks) { return TuningOptions(num_measure_trials, early_stopping, num_measures_per_round, verbose, - builder, runner, measure_callbacks, pre_search_callbacks); + builder, runner, measure_callbacks); }); TVM_REGISTER_GLOBAL("auto_scheduler.AutoSchedule") - .set_body_typed([](SearchTask task, SearchPolicy search_policy, TuningOptions tuning_options) { + .set_body_typed([](SearchPolicy search_policy, TuningOptions tuning_options) { te::Schedule sch; Array return_tensors; - std::tie(sch, return_tensors) = AutoSchedule(task, search_policy, tuning_options); + 
std::tie(sch, return_tensors) = AutoSchedule(search_policy, tuning_options); return Array{sch, return_tensors}; }); } // namespace auto_scheduler diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index b11dd7347504a..87b162aca6d55 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -309,12 +310,12 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { if (op->IsInstance()) { node->is_simple_access[op] = true; node->needs_multi_level_tiling[op] = false; - node->is_strict_inlineable[op] = false; + node->is_strictly_inlineable[op] = false; node->is_output[op] = false; } else if (auto cop = op.as()) { // check whether this op is element-wise and strict-inlineable bool is_simple_access = true; - bool is_strict_inlineable = true; + bool is_strictly_inlineable = true; bool axis_missing, axis_duplicated, same_order; for (const auto& pair : node->read_from[op]) { @@ -323,12 +324,12 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { if (!auto_scheduler::IsSimpleAccess(op, access, &axis_missing, &axis_duplicated, &same_order)) { is_simple_access = false; - is_strict_inlineable = false; + is_strictly_inlineable = false; break; } if (!same_order || axis_duplicated) { // do not strictly inline transpose - is_strict_inlineable = false; + is_strictly_inlineable = false; } } if (!is_simple_access) { @@ -342,11 +343,16 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { has_expensive_op |= HasExpensiveOp(expr); } if (has_expensive_op || has_branch[op]) { - is_strict_inlineable = false; + is_strictly_inlineable = false; + } + + // constant tensor is strict-inlineable + if (node->read_from[op].empty()) { + is_strictly_inlineable = true; } node->is_simple_access[op] = is_simple_access; - node->is_strict_inlineable[op] = is_strict_inlineable; + node->is_strictly_inlineable[op] = is_strictly_inlineable; // check whether the op needs multi-level tiling bool needs_multi_level_tiling = false; @@ -374,6 +380,11 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { } } + // do not perform multi-level tiling on "fake reduction" with const tensors + if (op->attrs.count(SearchPolicyKey::simplify_const_tensor_indices)) { + needs_multi_level_tiling = false; + } + node->needs_multi_level_tiling[op] = needs_multi_level_tiling; // check whether the op is output @@ -398,8 +409,8 @@ bool AccessAnalyzer::IsSimpleAccess(const te::Operation& op) const { return operator->()->is_simple_access.at(op); } -bool AccessAnalyzer::IsStrictInlineable(const te::Operation& op) const { - return operator->()->is_strict_inlineable.at(op); +bool AccessAnalyzer::IsStrictlyInlineable(const te::Operation& op) const { + return operator->()->is_strictly_inlineable.at(op); } OperationSet AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op) const { @@ -791,6 +802,21 @@ State ComputeDAG::InferBound(const State& state) const { return ret_state; } +Array ComputeDAG::InferBound(const Array& states) const { + Array out_states; + // TODO(jcf94, merrymercy): Use parallel_for to run this in parallel + for (const auto& state : states) { + State out_state; + try { + out_state = this->InferBound(state); + } catch (dmlc::Error& e) { + LOG(WARNING) << "InferBound fails on the state:\n" << state << "\n" << e.what() << std::endl; + } + out_states.push_back(std::move(out_state)); + } + return out_states; +} + ComputeDAG ComputeDAG::ReplayAndGetDAG(const Array& transform_steps) const { 
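  // As noted in compute_dag.h, steps such as cache_read/cache_write can change the ComputeDAG
  // itself, so we replay the given transform steps on a fresh schedule and rebuild an
  // up-to-date ComputeDAG from the resulting tensors.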
te::Schedule sch; Array old_tensors; @@ -808,6 +834,58 @@ ComputeDAG ComputeDAG::ReplayAndGetDAG(const Array& transform_steps) const return ComputeDAG(new_tensors); } +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + for (const auto& op : node->ops_topo_order) { + p->stream << op << std::endl; + p->stream << "is_simple_access:\t" << node->is_simple_access.at(op) << "\t\t"; + p->stream << "needs_multi_level_tiling:\t" << node->needs_multi_level_tiling.at(op) + << std::endl; + p->stream << "is_strictly_inlinable:\t" << node->is_strictly_inlineable.at(op) << "\t"; + p->stream << "is_output:\t" << node->is_output.at(op) << std::endl; + p->stream << "Read from:\t"; + for (const auto& pair : node->read_from.at(op)) { + for (const auto& index : pair.second) { + p->stream << pair.first->name << Array(index) << ", "; + } + } + p->stream << std::endl; + p->stream << "Read by:\t"; + for (const auto& pair : node->read_by.at(op)) { + for (const auto& index : pair.second) { + p->stream << pair.first->name << Array(index) << ", "; + } + } + p->stream << std::endl; + p->stream << Chars('=', 50) << std::endl; + } + + AccessAnalyzer ana = GetRef(node); + p->stream << "ElementwiseMatch: \n"; + for (size_t i = 0; i < node->ops_topo_order.size(); ++i) { + for (size_t j = 0; j < node->ops_topo_order.size(); ++j) { + if (i == j) { + continue; + } + if (ana.ElementWiseMatch(node->ops_topo_order[i], node->ops_topo_order[j])) { + p->stream << node->ops_topo_order[i]->name << " -> " << node->ops_topo_order[j]->name + << std::endl; + } + } + } + p->stream << Chars('=', 50) << std::endl; + + p->stream << "NumCommonOuterIterators: \n"; + for (const auto& src_pair : node->num_common_outer_iterators) { + for (const auto& dst_pair : src_pair.second) { + p->stream << src_pair.first->name << " " << dst_pair.first->name << " " << dst_pair.second + << std::endl; + } + } + p->stream << Chars('=', 50) << std::endl; + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); diff --git a/src/auto_scheduler/cost_model.cc b/src/auto_scheduler/cost_model.cc index 2f1c3091c785e..68c1d5c1f1186 100644 --- a/src/auto_scheduler/cost_model.cc +++ b/src/auto_scheduler/cost_model.cc @@ -42,7 +42,7 @@ RandomModel::RandomModel() { void RandomModelNode::Update(const Array& inputs, const Array& results) {} -void RandomModelNode::Predict(const SearchTask& task, const std::vector& states, +void RandomModelNode::Predict(const SearchTask& task, const Array& states, std::vector* scores) { scores->resize(states.size()); (*random_number_func)(states.size(), static_cast(scores->data())); @@ -62,14 +62,13 @@ void PythonBasedModelNode::Update(const Array& inputs, update_func(inputs, results); } -void PythonBasedModelNode::Predict(const SearchTask& task, const std::vector& states, +void PythonBasedModelNode::Predict(const SearchTask& task, const Array& states, std::vector* scores) { scores->resize(states.size()); - predict_func(task, Array(states.begin(), states.end()), - static_cast(scores->data())); + predict_func(task, states, static_cast(scores->data())); } -void PythonBasedModelNode::PredictStages(const SearchTask& task, const std::vector& states, +void PythonBasedModelNode::PredictStages(const SearchTask& task, const Array& states, std::vector* state_scores, std::vector>* stage_scores) { size_t n_states = states.size(); @@ -77,8 +76,7 @@ void 
PythonBasedModelNode::PredictStages(const SearchTask& task, const std::vect
   std::vector<float> flatten_scores;
   // Allocate sufficient spaces.
   flatten_scores.resize(n_states * n_stages * 2);
-  predict_stage_func(task, Array<State>(states.begin(), states.end()),
-                     static_cast<void*>(flatten_scores.data()));
+  predict_stage_func(task, states, static_cast<void*>(flatten_scores.data()));
 
   // Unpack flatten scores.
   state_scores->clear();
@@ -144,7 +142,7 @@ TVM_REGISTER_GLOBAL("auto_scheduler.CostModelUpdate")
 TVM_REGISTER_GLOBAL("auto_scheduler.CostModelPredict")
     .set_body_typed([](CostModel model, SearchTask task, Array<State> states) {
       std::vector<float> scores;
-      model->Predict(task, std::vector<State>(states.begin(), states.end()), &scores);
+      model->Predict(task, states, &scores);
       Array<FloatImm> ret;
       for (auto x : scores) {
         ret.push_back(FloatImm(DataType::Float(32), x));
diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc
index 9e1a54fff15a3..649e25e913096 100644
--- a/src/auto_scheduler/loop_state.cc
+++ b/src/auto_scheduler/loop_state.cc
@@ -113,6 +113,7 @@ void AttachMap::UpdateIters(const std::vector<IterKey>& original_iters,
                             const std::vector<IterKey>& new_iters) {
   CHECK_EQ(original_iters.size(), new_iters.size());
   AttachMapNode* pnode = CopyOnWrite();
+  std::unordered_map<IterKey, std::vector<StageKey>> new_iter_to_attached_stages;
   for (size_t i = 0; i < original_iters.size(); ++i) {
     auto entry = pnode->iter_to_attached_stages.find(original_iters[i]);
     // We get <IterKey, std::vector<StageKey>> from this map
@@ -130,7 +131,12 @@ void AttachMap::UpdateIters(const std::vector<IterKey>& original_iters,
     // iterator to it
     std::vector<StageKey> attached_stages = std::move(entry->second);
     pnode->iter_to_attached_stages.erase(entry);
-    pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages);
+    new_iter_to_attached_stages[new_iters[i]] = std::move(attached_stages);
+  }
+
+  // Update new entries
+  for (auto& it : new_iter_to_attached_stages) {
+    pnode->iter_to_attached_stages[it.first] = std::move(it.second);
   }
 }
 
diff --git a/src/auto_scheduler/search_policy/empty_policy.cc b/src/auto_scheduler/search_policy/empty_policy.cc
index 4c85af486a610..21a68ac21d919 100644
--- a/src/auto_scheduler/search_policy/empty_policy.cc
+++ b/src/auto_scheduler/search_policy/empty_policy.cc
@@ -27,20 +27,28 @@
 #include <tvm/auto_scheduler/measure.h>
 #include <tvm/runtime/registry.h>
 
+#include <utility>
+
 namespace tvm {
 namespace auto_scheduler {
 
 TVM_REGISTER_NODE_TYPE(EmptyPolicyNode);
 
-State EmptyPolicyNode::Search(SearchTask task, int num_measure_trials, int early_stopping,
-                              int num_measures_per_round, int verbose, ProgramMeasurer measurer,
-                              Optional<Array<SearchCallback>> pre_search_callbacks) {
-  cur_task = task;
+EmptyPolicy::EmptyPolicy(SearchTask task, Optional<Array<SearchCallback>> init_search_callbacks) {
+  auto node = make_object<EmptyPolicyNode>();
+  node->search_task = task;
 
-  // Run pre_search_callbacks before the search process
+  // Run init_search_callbacks before the search process
   // This interface is usually used to set some initial status
-  RunCallbacks(pre_search_callbacks);
+  if (init_search_callbacks) {
+    node->RunCallbacks(init_search_callbacks.value());
+  }
+
+  data_ = std::move(node);
+}
 
+State EmptyPolicyNode::Search(int num_measure_trials, int early_stopping,
+                              int num_measures_per_round, ProgramMeasurer measurer) {
   // Basic design principle: call `SearchOneRound()` several times to get candidate states,
   // measure them and return the best one
   // Measure is disabled if num_measure_trials <= 1
@@ -65,14 +73,14 @@
       for (const auto& state : res) {
         // The class members measured_states_set_ provided by SearchPolicy can be used to filter
         // out the
already measured states - inputs.push_back(MeasureInput(cur_task, state)); + inputs.push_back(MeasureInput(search_task, state)); } // ProgramMeasurer will record the state with best performance during measure process - measurer->Measure(cur_task, GetRef(this), inputs, &results); + measurer->Measure(search_task, GetRef(this), inputs, &results); } // Return a state with best measured performance - return measurer->best_state[cur_task->workload_key]; + return measurer->best_state[search_task->workload_key]; } } @@ -81,7 +89,7 @@ Array EmptyPolicyNode::SearchOneRound() { Array res; // 1. We will process `Program sampling` first to generate several initial schedules - res.push_back(cur_task->compute_dag->init_state); + res.push_back(search_task->compute_dag->init_state); // 2. Then `Performance Tuning`: use cost model and evolutionary search to seek for the schedule // with best performance @@ -91,9 +99,10 @@ Array EmptyPolicyNode::SearchOneRound() { return res; } -TVM_REGISTER_GLOBAL("auto_scheduler.EmptyPolicy").set_body_typed([]() { - return EmptyPolicy(make_object()); -}); +TVM_REGISTER_GLOBAL("auto_scheduler.EmptyPolicy") + .set_body_typed([](SearchTask task, Optional> init_search_callbacks) { + return EmptyPolicy(task, init_search_callbacks); + }); } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_policy/empty_policy.h b/src/auto_scheduler/search_policy/empty_policy.h index ef7d38ddf1166..3d138220dc0bf 100644 --- a/src/auto_scheduler/search_policy/empty_policy.h +++ b/src/auto_scheduler/search_policy/empty_policy.h @@ -40,9 +40,8 @@ namespace auto_scheduler { */ class EmptyPolicyNode : public SearchPolicyNode { public: - State Search(SearchTask task, int num_measure_trials, int early_stopping, - int num_measures_per_round, int verbose, ProgramMeasurer measurer, - Optional> pre_search_callbacks) final; + State Search(int num_measure_trials, int early_stopping, int num_measures_per_round, + ProgramMeasurer measurer) final; static constexpr const char* _type_key = "auto_scheduler.EmptyPolicy"; TVM_DECLARE_FINAL_OBJECT_INFO(EmptyPolicyNode, SearchPolicyNode); @@ -61,6 +60,8 @@ class EmptyPolicyNode : public SearchPolicyNode { */ class EmptyPolicy : public SearchPolicy { public: + explicit EmptyPolicy(SearchTask task, Optional> init_search_callbacks); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EmptyPolicy, SearchPolicy, EmptyPolicyNode); }; diff --git a/src/auto_scheduler/search_policy/search_policy.cc b/src/auto_scheduler/search_policy/search_policy.cc index 764b0a7fb97af..f21c8aecda86d 100644 --- a/src/auto_scheduler/search_policy/search_policy.cc +++ b/src/auto_scheduler/search_policy/search_policy.cc @@ -31,21 +31,21 @@ namespace auto_scheduler { TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode); TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); -void SearchPolicyNode::RunCallbacks(const Optional>& callbacks) { - if (callbacks) { - for (const auto& callback : callbacks.value()) { - callback->Callback(this); - } +void SearchPolicyNode::RunCallbacks(const Array& callbacks) { + for (const auto& callback : callbacks) { + callback->Callback(this); } } TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyRunCallbacks") .set_body_typed([](SearchPolicy policy, Optional> callbacks) { - policy->RunCallbacks(callbacks); + if (callbacks) { + policy->RunCallbacks(callbacks.value()); + } }); TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicySetTask") - .set_body_typed([](SearchPolicy policy, SearchTask task) { policy->cur_task = task; }); + .set_body_typed([](SearchPolicy policy, 
SearchTask task) { policy->search_task = task; }); TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicySetVerbose") .set_body_typed([](SearchPolicy policy, int verbose) { policy->verbose = verbose; }); diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc new file mode 100644 index 0000000000000..450c429f95c8e --- /dev/null +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -0,0 +1,401 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_scheduler/search_policy/sketch_search_policy.h + * \brief The search policy that searches in a hierarchical search space defined by sketches. + * The policy randomly samples programs from the space defined by sketches + * and use evolutionary search to fine-tune them. + */ + +#include "sketch_policy.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "sketch_policy_rules.h" + +namespace tvm { +namespace auto_scheduler { + +/********** Sketch generation rules **********/ + +static RuleSkipStage rule_skip_stage; +static RuleAlwaysInline rule_always_inline; +static RuleMultiLevelTiling rule_multi_level_tiling; +static RuleMultiLevelTilingWithFusion rule_multi_level_tiling_with_fusion; +static RuleAddCacheWrite rule_add_cache_write_stage; +static RuleAddRfactor rule_add_rfactor; +static RuleSimplifyComputeWithConstTensor rule_simplify_compute_with_const_tensor; + +/********** Init population rules **********/ + +static InitFillTileSize init_fill_tile_size; +static InitChangeComputeLocation init_change_compute_location; +static InitParallel init_parallel; +static InitUnroll init_unroll; +static InitVectorization init_vectorization; + +/********** Sketch policy **********/ + +TVM_REGISTER_NODE_TYPE(SketchPolicyNode); + +SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model, + Map params, int seed, int verbose, + Optional> init_search_callbacks) { + auto node = make_object(); + node->search_task = std::move(task); + node->schedule_cost_model = std::move(schedule_cost_model); + node->rand_gen = std::mt19937(seed); + node->params = std::move(params); + node->verbose = verbose; + + if (init_search_callbacks) { + PrintTitle("Call init-search callbacks", verbose); + // Candidates: + // - auto_scheduler.PreloadMeasuredStates: Load already measured states to + // `measured_states_set_`, `measured_states_vector_` and `measured_states_throughputs_`. + // - auto_scheduler.PreloadCustomSketchRule: Add user custom sketch rules to `sketch_rules`, + // these rules will be processed prior to the default rules. 
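+  // RunCallbacks (see search_policy.cc) simply invokes `callback->Callback(this)` on each
+  // entry in order, so a callback can freely mutate this policy node before the search starts.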
+    node->RunCallbacks(init_search_callbacks.value());
+  }
+
+  // The default sketch rules for CPU policy
+  // Notice: Some rules require us to skip all the remaining rules after they are applied.
+  // So the rules below should be ordered carefully.
+  node->sketch_rules.push_back(&rule_always_inline);
+  node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
+  node->sketch_rules.push_back(&rule_add_rfactor);
+  node->sketch_rules.push_back(&rule_add_cache_write_stage);
+  node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
+  node->sketch_rules.push_back(&rule_multi_level_tiling);
+  node->sketch_rules.push_back(&rule_skip_stage);  // This should always be the last rule
+
+  // The default init population rules for CPU policy
+  node->init_rules.push_back(&init_fill_tile_size);
+  node->init_rules.push_back(&init_change_compute_location);
+  node->init_rules.push_back(&init_parallel);
+  node->init_rules.push_back(&init_unroll);
+  node->init_rules.push_back(&init_vectorization);
+
+  data_ = std::move(node);
+}
+
+State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure_per_iter,
+                               ProgramMeasurer measurer) {
+  num_measure_per_iter_ = num_measure_per_iter;
+
+  if (n_trials <= 1) {
+    // No measurement is allowed
+    const Array<State>& best_states = SearchOneRound(0);
+    CHECK_GT(best_states.size(), 0);
+    return best_states[0];
+  } else {
+    int num_random =
+        static_cast<int>(GetDoubleParam(params, SketchParamKey::eps_greedy) * num_measure_per_iter);
+    early_stopping = early_stopping < 0 ? std::numeric_limits<int>::max() >> 1 : early_stopping;
+    measurer->Reset();
+
+    int ct = 0;
+    Array<MeasureInput> inputs;
+    Array<MeasureResult> results;
+    while (ct < n_trials) {
+      if (!inputs.empty()) {
+        // Retrain the cost model before the next search round
+        PrintTitle("Train cost model", verbose);
+        schedule_cost_model->Update(inputs, results);
+      }
+
+      // Search one round to get promising states
+      PrintTitle("Search", verbose);
+      Array<State> random_states;
+      Array<State> best_states = SearchOneRound(num_random, &random_states);
+
+      // Infer bound. This is necessary for computing the correct ToStr() for the redundancy check
+      best_states = search_task->compute_dag.InferBound(best_states);
+      random_states = search_task->compute_dag.InferBound(random_states);
+
+      // Pick `num_measure_per_iter` states to measure, checking hashes to remove already
+      // measured states. Also pick some random states for eps-greedy.
+      inputs = PickStatesWithEpsGreedy(best_states, random_states, n_trials - ct);
+
+      // We have traversed all of the search space
+      if (inputs.empty()) {
+        StdCout(verbose) << "All candidates in the search space have been measured." << std::endl;
+        break;
+      }
+
+      // Measure candidate states
+      PrintTitle("Measure", verbose);
+      measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs, &results);
+      ct += inputs.size();
+
+      // Check if we have reached the early stopping condition
+      if (ct - measurer->best_ct[search_task->workload_key] > early_stopping) {
+        StdCout(verbose) << "Stop early since no performance improvement in the last "
+                         << early_stopping << " measure steps.\n";
+        break;
+      }
+
+      // Update the throughputs of measured states. These states will join the
+      // EvolutionarySearch in later search rounds.
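+      // (The throughput score is simply 1 / mean(costs); e.g., a mean measured cost of
+      // 2e-3 seconds is recorded as a score of 500.)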
+      for (const auto& res : results) {
+        measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs));
+      }
+    }
+    PrintTitle("Done", verbose);
+
+    return measurer->best_state[search_task->workload_key];
+  }
+}
+
+Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State>* random_states) {
+  // Temporary object to be used if the input pointer is nullptr
+  Array<State> temp_random_states;
+  if (random_states == nullptr) {
+    random_states = &temp_random_states;
+  } else {
+    random_states->clear();
+  }
+
+  // Get parameters
+  int population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population);
+  int num_use_measured =
+      std::min(static_cast<int>(measured_states_vector_.size()),
+               static_cast<int>(
+                   GetDoubleParam(params, SketchParamKey::EvolutionarySearch::use_measured_ratio) *
+                   population));
+  bool is_cost_model_reasonable = !schedule_cost_model->IsInstance<RandomModelNode>();
+
+  // 1. Generate sketches
+  const Array<State>& sketches = GenerateSketches();
+
+  // 2. Sample the init population
+  Array<State> init_populations = SampleInitPopulation(
+      sketches, is_cost_model_reasonable ? population - num_use_measured : population);
+
+  // 3. If the cost model is useless (i.e. RandomModel), just randomly pick some generated
+  // states; otherwise, perform evolutionary search
+  if (is_cost_model_reasonable) {
+    // Also insert already measured good states into the initial population
+    std::vector<int> indices = Argsort(measured_states_throughputs_);
+    for (int i = 0; i < num_use_measured; i++) {
+      init_populations.push_back(measured_states_vector_[indices[i]]);
+    }
+    // Sample some random states for eps-greedy
+    *random_states = RandomSampleStates(init_populations, &rand_gen, num_random_states * 10);
+    return EvolutionarySearch(init_populations, num_measure_per_iter_ * 2);
+  } else {
+    return RandomSampleStates(init_populations, &rand_gen, num_measure_per_iter_ * 3);
+  }
+}
+
+Array<State> SketchPolicyNode::GenerateSketches() {
+  State init_state = search_task->compute_dag->init_state;
+
+  // Two ping-pong buffers to avoid copy
+  Array<State> states_buf1{init_state}, states_buf2;
+  Array<State>* pnow = &states_buf1;
+  Array<State>* pnext = &states_buf2;
+
+  // A map that maps a state to its current working position (stage_id)
+  std::unordered_map<State, int, ObjectHash, ObjectEqual> cur_stage_id_map;
+  cur_stage_id_map[init_state] = static_cast<int>(init_state->stages.size() - 1);
+
+  // Derivation-rule-based enumeration
+  Array<State> out_states;
+  while (!pnow->empty()) {
+    pnext->clear();
+    for (const State& state : *pnow) {
+      int stage_id = cur_stage_id_map[state];
+
+      // We have reached the terminal stage
+      if (stage_id < 0) {
+        out_states.push_back(state);
+        continue;
+      }
+
+      // Try all derivation rules
+      for (const auto& rule : sketch_rules) {
+        auto cond = rule->MeetCondition(*this, state, stage_id);
+        if (cond != SketchGenerationRule::ConditionKind::kSkip) {
+          for (const auto& pair : rule->Apply(*this, state, stage_id)) {
+            cur_stage_id_map[pair.first] = pair.second;
+            pnext->push_back(pair.first);
+          }
+          // Skip the rest of the rules
+          if (cond == SketchGenerationRule::ConditionKind::kApplyAndSkipRest) {
+            break;
+          }
+        }
+      }
+    }
+    std::swap(pnow, pnext);
+  }
+
+  // Hack for rfactor: Replace the split factor for rfactor with an undefined Expr(),
+  // so that later we can sample a random value for the split factor.
+  // Why don't we use an undefined Expr() when doing the split for rfactor in the first place?
+  // Because during ApplySteps, an rfactor with an undefined Expr() will crash TVM.
+  // So an rfactor with an undefined Expr() would conflict with cache_write, cache_read,
+  // and rfactor in other stages.
+  for (size_t i = 0; i < out_states.size(); ++i) {
+    auto state = out_states[i];
+    auto pstate = state.CopyOnWrite();
+    for (size_t step_id = 0; step_id < pstate->transform_steps.size(); ++step_id) {
+      if (pstate->transform_steps[step_id]->IsInstance<RfactorStepNode>()) {
+        CHECK_GE(step_id, 1);
+        int split_step_id = static_cast<int>(step_id - 1);
+        auto step = pstate->transform_steps[split_step_id].as<SplitStepNode>();
+        CHECK(step != nullptr);
+        pstate->transform_steps.Set(
+            split_step_id, SplitStep(step->stage_id, step->iter_id, step->extent, {NullOpt},
+                                     step->inner_to_outer));
+      }
+    }
+    out_states.Set(i, std::move(state));
+  }
+
+  StdCout(verbose) << "Generate Sketches\t\t#s: " << out_states.size() << std::endl;
+  return out_states;
+}
+
+Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches, int out_size) {
+  int fail_ct = 0;
+  Array<State> out_states;
+  auto tic_begin = std::chrono::high_resolution_clock::now();
+
+  // TODO(jcf94, merrymercy): Use parallel_for to run this loop in parallel
+  while (static_cast<int>(out_states.size()) < out_size && fail_ct < static_cast<int>(out_size)) {
+    // Randomly choose a starting sketch
+    // TODO(jcf94, merrymercy): Maybe choose sketches with different probabilities, as they may
+    // have different potential for generating states with better performance
+    State tmp_s = sketches[(rand_gen)() % sketches.size()];
+
+    // Derivation-rule-based enumeration
+    bool valid = true;
+    for (const auto& rule : init_rules) {
+      if (rule->Apply(this, &tmp_s) == InitPopulationRule::ResultKind::kInvalid) {
+        valid = false;
+        break;
+      }
+    }
+
+    if (valid) {
+      out_states.push_back(std::move(tmp_s));
+    } else {
+      fail_ct++;
+    }
+  }
+
+  double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
+                        std::chrono::high_resolution_clock::now() - tic_begin)
+                        .count();
+  StdCout(verbose) << "Sample Initial Population\t#s: " << out_states.size()
+                   << "\tfail_ct: " << fail_ct << "\tTime elapsed: " << std::fixed
+                   << std::setprecision(2) << duration << std::endl;
+  return out_states;
+}
+
+Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_populations,
+                                                  int out_size) {
+  Array<State> best_states;
+  auto tic_begin = std::chrono::high_resolution_clock::now();
+
+  // TODO(comaniac, merrymercy, jcf94): Since we haven't finished porting the cost model part
+  // yet, the implementation of EvolutionarySearch is omitted here for now. To be added later.
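+  // Roughly, the ported version is expected to: score the population with the cost model,
+  // keep the top states, and refill the rest by randomly mutating high-scoring states.
+  // (This outline follows the Ansor paper cited in sketch_policy.h, not code in this PR.)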
+ + double duration = std::chrono::duration_cast>( + std::chrono::high_resolution_clock::now() - tic_begin) + .count(); + StdCout(verbose) << "EvolutionarySearch\t\t#s: " << best_states.size() + << "\tTime elapsed: " << std::fixed << std::setprecision(2) << duration + << std::endl; + return best_states; +} + +Array SketchPolicyNode::PickStatesWithEpsGreedy(const Array& best_states, + const Array& random_states, + int remaining_n_trials) { + int num_random = + static_cast(GetDoubleParam(params, SketchParamKey::eps_greedy) * num_measure_per_iter_); + int num_good = num_measure_per_iter_ - num_random; + + Array inputs; + size_t offset_best = 0, offset_random = 0; + + while (static_cast(inputs.size()) < std::min(num_measure_per_iter_, remaining_n_trials)) { + State state; + + bool has_best = offset_best < best_states.size(); + bool has_random = offset_random < random_states.size(); + + if (static_cast(inputs.size()) < num_good) { + // prefer best states + if (has_best) { + state = best_states[offset_best++]; + } else if (has_random) { + state = random_states[offset_random++]; + } else { + break; + } + } else { + // prefer random states + if (has_random) { + state = random_states[offset_random++]; + } else if (has_best) { + state = best_states[offset_best++]; + } else { + break; + } + } + + // Check if it has already been measured + std::string state_str = state.ToStr(); + if (!measured_states_set_.count(state_str)) { + measured_states_set_.insert(std::move(state_str)); + measured_states_vector_.push_back(state); + inputs.push_back(MeasureInput(search_task, state)); + } + } + + return inputs; +} + +TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicy") + .set_body_typed([](SearchTask task, CostModel schedule_cost_model, + Map params, int seed, int verbose, + Optional> init_search_callbacks) { + return SketchPolicy(task, schedule_cost_model, params, seed, verbose, init_search_callbacks); + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyGenerateSketches") + .set_body_typed([](SketchPolicy policy) { return policy->GenerateSketches(); }); + +} // namespace auto_scheduler +} // namespace tvm diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h new file mode 100644 index 0000000000000..288de839a1b79 --- /dev/null +++ b/src/auto_scheduler/search_policy/sketch_policy.h @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file auto_scheduler/search_policy/sketch_policy.h + * \brief The search policy that searches in a hierarchical search space defined by sketches. + * The policy randomly samples programs from the space defined by sketches and use evolutionary + * search to fine-tune them. + * + * Reference: + * L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. 
"Ansor : Generating High-Performance Tensor + * Programs for Deep Learning." arXiv preprint arXiv:2006.06762 (2020). + */ + +#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_H_ +#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_H_ + +#include +#include + +#include +#include +#include +#include +#include + +#include "sketch_policy_rules.h" +#include "utils.h" + +namespace tvm { +namespace auto_scheduler { + +/*! \brief String keys used in parameter map of SketchPolicy. */ +struct SketchParamKey { + /*! \brief Always allocate this percentage of measurements to random sampled states. */ + static constexpr const char* eps_greedy = "eps_greedy"; + + struct EvolutionarySearch { + /*! \brief The population size for evolutionary search. */ + static constexpr const char* population = "evolutionary_search_population"; + /*! \brief The maximum percentage of measured states in the initial population for evolutionary + * search. */ + static constexpr const char* use_measured_ratio = "evolutionary_search_use_measured_ratio"; + }; + + struct MultiLevelTiling { + /*! \brief The structure of multi-level tiling for CPU. */ + static constexpr const char* cpu_structure = "cpu_multi_level_tiling_structure"; + /*! \brief The structure of multi-level tiling for GPU. */ + static constexpr const char* gpu_structure = "gpu_multi_level_tiling_structure"; + }; + + /*! \brief The max inner most split factor. */ + static constexpr const char* max_innermost_split_factor = "max_innermost_split_factor"; + /*! \brief The max vectorize size. */ + static constexpr const char* max_vectorize_size = "max_vectorize_size"; + /*! \brief Whether disable compute location changing. */ + static constexpr const char* disable_change_compute_location = "disable_change_compute_location"; +}; + +/*! + * \brief The search policy that searches in a hierarchical search space defined by sketches. + * The policy randomly samples programs from the space defined by sketches + * and use evolutionary search to fine-tune them. + */ +class SketchPolicyNode : public SearchPolicyNode { + public: + /*! \brief The cost model to estimate the complete schedules. */ + CostModel schedule_cost_model; + /*! \brief The parameters map for this search process. */ + Map params; + /*! \brief The rules to generate sketches. */ + std::vector sketch_rules; + /*! \brief The rules to generate initial states. */ + std::vector init_rules; + /*! \brief Random generator. */ + std::mt19937 rand_gen; + /*! \brief Memorize split space for Split. */ + SplitFactorizationMemo split_memo; + + State Search(int num_measure_trials, int early_stopping, int num_measures_per_round, + ProgramMeasurer measurer) final; + + /*! + * \brief Generate sketches. + * \return The generated sketches(states). + */ + Array GenerateSketches(); + + static constexpr const char* _type_key = "auto_scheduler.SketchPolicy"; + + TVM_DECLARE_FINAL_OBJECT_INFO(SketchPolicyNode, SearchPolicyNode); + + private: + /*! + * \brief Run one round of the search pipeline. + * \param num_random_states Number of states that are picked randomly, this is used for + * eps-greedy policy. + * \param random_states The picked random states, used as one of the output of this function. + * \return The best several states generated in this search round. + */ + Array SearchOneRound(int num_random_states, Array* random_states = nullptr); + + /*! + * \brief Sample init population. + * \param sketches The initial sketches to process population. + * \param out_size The number of expected output states. 
+   * \return The generated states after the initial population.
+   */
+  Array<State> SampleInitPopulation(const Array<State>& sketches, int out_size);
+
+  /*!
+   * \brief Perform the evolutionary search.
+   * \param init_populations The states generated from the initial population.
+   * \param out_size The number of expected output states.
+   * \return The generated states after the evolutionary search.
+   */
+  Array<State> EvolutionarySearch(const Array<State>& init_populations, int out_size);
+
+  /*!
+   * \brief Pick states from the best states and the random states with the eps-greedy policy.
+   * \param best_states States picked by the cost model.
+   * \param random_states States picked randomly.
+   * \param remaining_n_trials The remaining number of states that need to be generated.
+   * \return The generated states to be measured, wrapped in MeasureInput.
+   */
+  Array<MeasureInput> PickStatesWithEpsGreedy(const Array<State>& best_states,
+                                              const Array<State>& random_states,
+                                              int remaining_n_trials);
+
+  /*! \brief The number of states to measure per iteration. */
+  int num_measure_per_iter_;
+};
+
+/*!
+ * \brief Managed reference to SketchPolicyNode.
+ * \sa SketchPolicyNode
+ */
+class SketchPolicy : public SearchPolicy {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param task The SearchTask for the computation declaration.
+   * \param schedule_cost_model The cost model for complete programs.
+   * \param params The parameters map for this search process.
+   * \param seed The random seed of this search process.
+   * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule
+   * search.
+   * \param init_search_callbacks SearchCallback to be called before schedule search.
+   */
+  SketchPolicy(SearchTask task, CostModel schedule_cost_model, Map<String, ObjectRef> params,
+               int seed, int verbose, Optional<Array<SearchCallback>> init_search_callbacks);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchPolicy, SearchPolicy, SketchPolicyNode);
+};
+
+}  // namespace auto_scheduler
+}  // namespace tvm
+
+#endif  // TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_H_
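Since SketchParamKey only defines string constants, a parameter map is just string-keyed TVM objects. A minimal illustration of building one on the C++ side (the numeric values are placeholders for illustration, not defaults shipped by this patch):

    #include <tvm/ir/expr.h>
    #include <tvm/runtime/container.h>

    using namespace tvm;

    // Illustrative parameter map for SketchPolicy; integers go in as Integer,
    // ratios as 64-bit FloatImm, matching GetIntParam/GetDoubleParam below.
    Map<String, ObjectRef> MakeExampleParams() {
      Map<String, ObjectRef> params;
      params.Set("eps_greedy", FloatImm(DataType::Float(64), 0.05));
      params.Set("evolutionary_search_population", Integer(2048));
      params.Set("max_innermost_split_factor", Integer(16));
      params.Set("max_vectorize_size", Integer(16));
      params.Set("disable_change_compute_location", Integer(0));
      return params;
    }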
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
new file mode 100644
index 0000000000000..587e2c796e91f
--- /dev/null
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
@@ -0,0 +1,584 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_scheduler/search_policy/sketch_policy_rules.cc
+ * \brief Rules defined to generate the sketches and initial sampled states in SketchPolicy.
+ */
+
+#include "sketch_policy_rules.h"
+
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "sketch_policy.h"
+
+namespace tvm {
+namespace auto_scheduler {
+
+/********** Sketch Generation Rule **********/
+/********** RuleSkipStage **********/
+
+SketchGenerationRule::ConditionKind RuleSkipStage::MeetCondition(const SketchPolicyNode& policy,
+                                                                 const State& state, int stage_id) {
+  // This rule should be the last rule; it always applies, which decreases the stage index count
+  return ConditionKind::kApply;
+}
+
+std::vector<std::pair<State, int>> RuleSkipStage::Apply(const SketchPolicyNode& policy,
+                                                        const State& state, int stage_id) const {
+  return {std::make_pair(state, stage_id - 1)};
+}
+
+/********** RuleAlwaysInline **********/
+
+SketchGenerationRule::ConditionKind RuleAlwaysInline::MeetCondition(const SketchPolicyNode& policy,
+                                                                    const State& state,
+                                                                    int stage_id) {
+  const Stage& stage = state->stages[stage_id];
+  // Check the inline limitation of TE first
+  if (stage->op_type == StageKind::kPlaceholder ||
+      IsOutputOp(policy.search_task, state, stage_id) || HasReduceIter(stage)) {
+    return ConditionKind::kSkip;
+  }
+
+  // TODO(jcf94): Greedily inline all inlinable ops on GPU when introducing GPU search policy.
+  return IsStrictlyInlineable(policy.search_task, state, stage_id)
+             ? ConditionKind::kApplyAndSkipRest
+             : ConditionKind::kSkip;
+}
+
+std::vector<std::pair<State, int>> RuleAlwaysInline::Apply(const SketchPolicyNode& policy,
+                                                           const State& state, int stage_id) const {
+  State tmp_s = state;
+  tmp_s.compute_inline(stage_id);
+  return {std::make_pair(std::move(tmp_s), stage_id - 1)};
+}
+
+/********** RuleMultiLevelTiling **********/
+
+SketchGenerationRule::ConditionKind RuleMultiLevelTiling::MeetCondition(
+    const SketchPolicyNode& policy, const State& state, int stage_id) {
+  return NeedsMultilevelTiling(policy.search_task, state, stage_id)
+             ? ConditionKind::kApplyAndSkipRest
+             : ConditionKind::kSkip;
+}
+
+std::vector<std::pair<State, int>> RuleMultiLevelTiling::Apply(const SketchPolicyNode& policy,
+                                                               const State& state,
+                                                               int stage_id) const {
+  // TODO(jcf94): Add support for GPU structure when introducing GPU search policy.
+  const std::string& multi_level_tiling_structure =
+      GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
+  State tmp_s = DoMultiLevelTiling(state, stage_id, multi_level_tiling_structure);
+  return {std::make_pair(std::move(tmp_s), stage_id - 1)};
+}
+
+/********** RuleMultiLevelTilingWithFusion **********/
+
+SketchGenerationRule::ConditionKind RuleMultiLevelTilingWithFusion::MeetCondition(
+    const SketchPolicyNode& policy, const State& state, int stage_id) {
+  if (NeedsMultilevelTiling(policy.search_task, state, stage_id) &&
+      HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id, &target_stage_id)) {
+    // Always do fusion for stage with cache_write
+    // TODO(jcf94): Always do fusion on GPU when introducing GPU search policy.
+    return HasCacheWriteStage(state, stage_id) ? ConditionKind::kApplyAndSkipRest
+                                               : ConditionKind::kApply;
+  }
+  return ConditionKind::kSkip;
+}
+
+std::vector<std::pair<State, int>> RuleMultiLevelTilingWithFusion::Apply(
+    const SketchPolicyNode& policy, const State& state, int stage_id) const {
+  // TODO(jcf94): Add support for GPU structure when introducing GPU search policy.
+  const std::string& multi_level_tiling_structure =
+      GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
+  std::vector<int> spatial_split_step_ids;
+  State base_state =
+      DoMultiLevelTiling(state, stage_id, multi_level_tiling_structure, &spatial_split_step_ids);
+
+  std::vector<std::pair<State, int>> ret;
+  // TODO(jcf94): Add follow_tiling_levels for GPU when introducing GPU search policy.
+  std::vector<int> follow_tiling_levels{1, 2};
+  for (int level : follow_tiling_levels) {
+    if (tolower(multi_level_tiling_structure[level - 1]) != 's') {
+      continue;
+    }
+    State tmp_s = base_state;
+    tmp_s = FollowTiling(tmp_s, target_stage_id, spatial_split_step_ids, level);
+    const Iterator& target_iter =
+        tmp_s->stages[target_stage_id]->iters[level * spatial_split_step_ids.size() - 1];
+    tmp_s.compute_at(stage_id, target_stage_id, target_iter);
+    ret.emplace_back(std::move(tmp_s), stage_id - 1);
+  }
+
+  return ret;
+}
+
+/********** RuleAddCacheWrite **********/
+
+SketchGenerationRule::ConditionKind RuleAddCacheWrite::MeetCondition(const SketchPolicyNode& policy,
+                                                                     const State& state,
+                                                                     int stage_id) {
+  // Add cache write if a stage needs multi-level tiling, but does not have an element-wise
+  // matched consumer
+  if (NeedsMultilevelTiling(policy.search_task, state, stage_id) &&
+      !HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id)) {
+    // The apply-and-skip-rest case is handled in RuleMultiLevelTilingWithFusion
+    // TODO(jcf94): Always do cache_write on GPU when introducing GPU search policy.
+    return ConditionKind::kApply;
+  }
+
+  return ConditionKind::kSkip;
+}
+
+std::vector<std::pair<State, int>> RuleAddCacheWrite::Apply(const SketchPolicyNode& policy,
+                                                            const State& state,
+                                                            int stage_id) const {
+  State tmp_s = state;
+  tmp_s.cache_write(stage_id, "local", policy.search_task->compute_dag);
+  return {std::make_pair(std::move(tmp_s), stage_id)};
+}
+
+/********** RuleAddRfactor **********/
+
+SketchGenerationRule::ConditionKind RuleAddRfactor::MeetCondition(const SketchPolicyNode& policy,
+                                                                  const State& state,
+                                                                  int stage_id) {
+  return (NeedsRfactor(policy.search_task, state, stage_id) && !HasCacheWriteStage(state, stage_id))
+             ? ConditionKind::kApply
+             : ConditionKind::kSkip;
+}
+
+std::vector<std::pair<State, int>> RuleAddRfactor::Apply(const SketchPolicyNode& policy,
+                                                         const State& state, int stage_id) const {
+  // Fuse all reduction iters
+  Array<Iterator> space_iters, reduce_iters;
+  Iterator fused_reduce_iter;
+  State base_state =
+      FuseAllReductionIterators(state, stage_id, &fused_reduce_iter, &space_iters, &reduce_iters);
+
+  // TODO(merrymercy): We can do more analysis here to generate fewer and more efficient sketches.
+  // In some cases, we only need rfactor for more parallelism.
+  // In some cases, we only need rfactor for vectorization.
+  // For now we generate two versions and let the search figure out the better one.
+
+  // Split reduction iters
+  const auto& split_res = base_state.split(stage_id, fused_reduce_iter, {Integer(1)});
+  int factor_axis_id = static_cast<int>(space_iters.size());
+  std::vector<std::pair<State, int>> ret;
+  for (const auto& split_iter : split_res) {
+    State tmp_s = base_state;
+    int rstage_id =
+        tmp_s.rfactor(stage_id, split_iter, factor_axis_id, policy.search_task->compute_dag);
+
+    // reorder the space iterator to innermost for vectorization
+    if (split_iter == split_res[1]) {
+      Array<Iterator> new_order;
+      for (size_t i = 0; i < tmp_s->stages[rstage_id]->iters.size(); ++i) {
+        if (i != space_iters.size()) {
+          new_order.push_back(tmp_s->stages[rstage_id]->iters[i]);
+        }
+      }
+      new_order.push_back(tmp_s->stages[rstage_id]->iters[space_iters.size()]);
+      tmp_s.reorder(rstage_id, new_order);
+    }
+
+    ret.emplace_back(std::move(tmp_s), rstage_id - 1);
+  }
+
+  return ret;
+}
+
+/********** RuleSimplifyComputeWithConstTensor **********/
+
+SketchGenerationRule::ConditionKind RuleSimplifyComputeWithConstTensor::MeetCondition(
+    const SketchPolicyNode& policy, const State& state, int stage_id) {
+  return state->stages[stage_id]->op->attrs.count(SearchPolicyKey::simplify_const_tensor_indices)
+             ? ConditionKind::kApplyAndSkipRest
+             : ConditionKind::kSkip;
+}
+
+std::vector<std::pair<State, int>> RuleSimplifyComputeWithConstTensor::Apply(
+    const SketchPolicyNode& policy, const State& state, int stage_id) const {
+  std::set<std::string> const_tensor_indices = GetIterNameSetParam(
+      state->stages[stage_id]->op->attrs, SearchPolicyKey::simplify_const_tensor_indices);
+
+  State tmp_s = state;
+  Array<Array<Iterator>> tiled_outer_iters;
+  Array<Iterator> unrolled_inner_iters;
+
+  // The tile level is currently fixed to 2
+  size_t tile_level = 2;
+
+  for (const auto& iter : state->stages[stage_id]->iters) {
+    if (const_tensor_indices.count(iter->name)) {
+      // unroll indices of const tensors
+      unrolled_inner_iters.push_back(tmp_s.unroll(stage_id, iter));
+    } else {
+      // tile other space indices
+      CHECK(iter->iter_kind == IteratorKind::kSpatial);
+      tiled_outer_iters.push_back(
+          tmp_s.split(stage_id, iter, Array<Optional<Integer>>(tile_level - 1, NullOpt)));
+    }
+  }
+
+  // reorder them
+  Array<Iterator> new_order;
+  for (size_t i = 0; i < tile_level; ++i) {
+    for (size_t j = 0; j < tiled_outer_iters.size(); ++j) {
+      new_order.push_back(tiled_outer_iters[j][i]);
+    }
+  }
+  new_order.insert(new_order.end(), unrolled_inner_iters.begin(), unrolled_inner_iters.end());
+  tmp_s.reorder(stage_id, new_order);
+
+  return {std::make_pair(tmp_s, stage_id - 1)};
+}
+
+/********** Init Population **********/
+
+InitPopulationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy,
+                                                       State* state) const {
+  StateNode* pstate = state->CopyOnWrite();
+  // Scan the transformation history and randomly fill tile sizes for all SplitSteps
+  // (e.g., a split of extent 1024 into three tile levels with max innermost factor 16
+  // might be filled with the lengths [16, 4, 16])
+  for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) {
+    if (auto ps = (*state)->transform_steps[step_id].as<SplitStepNode>()) {
+      bool all_defined = true;
+      for (const auto& len : ps->lengths) {
+        if (!len) {
+          all_defined = false;
+          break;
+        }
+      }
+      if (all_defined) {
+        continue;
+      }
+
+      CHECK(ps->extent);
+      int extent = GetIntImm(ps->extent.value());
+      const auto& candidate_lens = policy->split_memo.GetFactorizationSchemes(
+          extent, ps->lengths.size(),
+          GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor));
+      const auto& candidate_lengths = candidate_lens[(policy->rand_gen)() % candidate_lens.size()];
+
+      pstate->transform_steps.Set(
+          step_id,
+          SplitStep(ps->stage_id, ps->iter_id, ps->extent,
+                    Array<Optional<Integer>>(candidate_lengths.begin(), candidate_lengths.end()),
+                    ps->inner_to_outer));
+    }
+  }
+  pstate->concrete = true;
+
+  return ResultKind::kValid;
+}
+
+InitPopulationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy,
+                                                                State* state) const {
+  if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
+    return ResultKind::kValid;
+  }
+
+  for (int stage_id = static_cast<int>((*state)->stages.size()) - 1; stage_id >= 0; stage_id--) {
+    const Stage& stage = (*state)->stages[stage_id];
+    // Skip the inlined stages and placeholders
+    if (stage->op_type == StageKind::kPlaceholder || stage->compute_at == ComputeAtKind::kInlined) {
+      continue;
+    }
+    // Skip the tiled stages
+    if (IsTiled(stage) || NeedsMultilevelTiling(policy->search_task, *state, stage_id)) {
+      continue;
+    }
+
+    int target_stage_id = GetSingleConsumerId(policy->search_task, *state, stage_id);
+    if (target_stage_id < 0) {
+      continue;
+    }
+    const Stage& target_stage = (*state)->stages[target_stage_id];
+
+    std::vector<std::pair<int, int>> candidates;
+    bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter;
+    bool target_is_tiled = IsTiled(target_stage);
+
+    bool visited_reduce = false;
+    // Enumerate compute_at locations at target_stage
+    // TODO(merrymercy): More analysis here to make smarter choices
+    for (size_t i = 0; i < target_stage->iters.size(); ++i) {
+      const Iterator& target_iter = target_stage->iters[i];
+      if (target_iter->iter_kind == IteratorKind::kReduction) {
+        visited_reduce = true;
+        if (!target_is_tiled) {  // Do not go into reduce iter
+          break;
+        }
+      } else if (target_iter->iter_kind == IteratorKind::kSpatial) {
+        if (visited_reduce) {  // Do not go into inner tile
+          break;
+        }
+      }
+
+      if (target_iter->annotation == IteratorAnnotation::kUnroll) {
+        // Do not go into the unroll region of const tensor indices
+        break;
+      }
+
+      if (GetExtent(target_iter) == 1) {
+        // Skip iterators with a length of 1
+        continue;
+      }
+      if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial &&
+          StrEndsWith(target_iter->name, ".0")) {
+        // Skip the first-level iterators if the target stage is compute_at another stage.
+        // In this case, the lengths of the first-level iterators are always one.
+        continue;
+      }
+      candidates.emplace_back(target_stage_id, i);
+
+      if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) {
+        break;
+      }
+    }
+
+    // If the target_stage is already compute_at another stage X, also try compute_at X.
+    // We call stage X the `target_target_stage`.
+    if (target_compute_at_other) {
+      int target_target_stage_id;
+      target_target_stage_id = (*state)->attach_map->stage_to_attach_iter.at(target_stage_id).first;
+      const Stage& target_target_stage = (*state)->stages[target_target_stage_id];
+
+      for (size_t i = 0; i < target_target_stage->iters.size(); ++i) {
+        const Iterator& target_target_iter = target_target_stage->iters[i];
+        if (target_target_iter->iter_kind == IteratorKind::kReduction ||
+            (*state)->attach_map->iter_to_attached_stages.count(
+                std::make_pair(target_target_stage_id, i))) {
+          break;
+        }
+
+        if (target_target_iter->annotation == IteratorAnnotation::kUnroll) {
+          // Do not go into the unroll region of const tensor indices
+          break;
+        }
+
+        if (GetExtent(target_target_iter) == 1) {  // skip iterators with a length of 1
+          continue;
+        }
+
+        candidates.emplace_back(target_target_stage_id, i);
+      }
+    }
+
+    int choice = (policy->rand_gen)() % (candidates.size() + 2);
+
+    if (choice == 0) {
+      if (!HasReduceIter(stage)) {
+        const auto& stage_to_attach_iter = (*state)->attach_map->stage_to_attach_iter;
+        if (stage_to_attach_iter.find(stage_id) != stage_to_attach_iter.end()) {
+          state->compute_inline(stage_id);
+        }
+      }
+    } else if (choice == 1) {
+      state->compute_root(stage_id);
+    } else {
+      choice = choice - 2;
+      const Stage& stage = (*state)->stages[candidates[choice].first];
+      state->compute_at(stage_id, candidates[choice].first,
+                        stage->iters[candidates[choice].second]);
+    }
+  }
+
+  *state = policy->search_task->compute_dag.InferBound(*state);
+  return ResultKind::kValid;
+}
+
+InitPopulationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, State* state) const {
+  std::function<void(const SketchPolicyNode&, State*, int stage_id, int iter_offset)>
+      annotate_parallel;
+  annotate_parallel = [&annotate_parallel](const SketchPolicyNode& policy, State* state,
+                                           int stage_id, int iter_offset) {
+    const Stage& stage = (*state)->stages[stage_id];
+
+    Array<Iterator> to_fuse;
+    int64_t parallel_degree = 1;
+
+    // Try to fuse and parallelize the outermost n iterators.
+    // Stop if we meet a reduce iterator or we have enough parallel degree.
+    size_t iter_id = iter_offset;
+    for (; iter_id < stage->iters.size(); ++iter_id) {
+      const Iterator& it = stage->iters[iter_id];
+      if (it->iter_kind == IteratorKind::kReduction ||
+          it->annotation != IteratorAnnotation::kNone) {
+        break;
+      }
+      to_fuse.push_back(it);
+      parallel_degree *= GetExtent(it);
+
+      if (parallel_degree > policy.search_task->hardware_params->num_cores * 16) {
+        break;
+      }
+
+      if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id))) {
+        break;
+      }
+    }
+
+    if (parallel_degree == 1) {
+      auto res =
+          (*state)->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id));
+      if (res != (*state)->attach_map->iter_to_attached_stages.end()) {
+        for (int attached_stage_id : res->second) {
+          annotate_parallel(policy, state, attached_stage_id, 0);
+        }
+        annotate_parallel(policy, state, stage_id, iter_id + 1);
+      }
+    }
+
+    if (!to_fuse.empty()) {
+      if (to_fuse.size() == 1) {
+        state->parallel(stage_id, to_fuse[0]);
+      } else {
+        Iterator fused_iter = state->fuse(stage_id, to_fuse);
+        state->parallel(stage_id, fused_iter);
+      }
+    }
+  };
+
+  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
+    const Stage& stage = (*state)->stages[stage_id];
+    if (stage->compute_at != ComputeAtKind::kRoot || stage->op_type == StageKind::kPlaceholder) {
+      continue;
+    }
+
+    annotate_parallel(*policy, state, stage_id, 0);
+  }
+
+  return ResultKind::kValid;
+}
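The annotate_parallel helper above fuses outer spatial iterators until the fused extent exceeds num_cores * 16. A self-contained sketch of just that accumulation (core count and extents are invented for illustration; the real code also stops at reduce iterators, annotated iterators, and compute_at attach points):

    #include <cstdio>
    #include <vector>

    // Sketch: multiply outer loop extents until the parallel degree exceeds
    // the num_cores * 16 cap used above; return how many loops to fuse.
    int NumLoopsToParallelize(const std::vector<long>& extents, int num_cores) {
      long degree = 1;
      int n = 0;
      for (long e : extents) {
        degree *= e;
        ++n;
        if (degree > static_cast<long>(num_cores) * 16) break;  // enough parallelism
      }
      return n;
    }

    int main() {
      // With 8 cores the cap is 128: 4 * 16 = 64 <= 128 but 4 * 16 * 8 = 512 > 128,
      // so the three outermost loops are fused and parallelized.
      std::printf("%d\n", NumLoopsToParallelize({4, 16, 8, 32}, 8));  // prints 3
      return 0;
    }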
+
+InitPopulationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State* state) const {
+  std::vector<int> auto_unroll_configs = {0, 16, 64, 512};
+  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
+    const Stage& stage = (*state)->stages[stage_id];
+    // Skip the inlined stage and placeholder stage
+    if (stage->compute_at == ComputeAtKind::kInlined || stage->op_type == StageKind::kPlaceholder) {
+      continue;
+    }
+
+    // Handle the always_unroll_inner attr
+    if (stage->op->attrs.count(SearchPolicyKey::always_unroll_inner)) {
+      const auto& to_unroll_name_set =
+          GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::always_unroll_inner);
+
+      // Unroll the space iterators and reduce iterators listed in the attrs in the innermost
+      // tile
+      std::set<std::string> visited_names;
+      for (int n = static_cast<int>(stage->iters.size()) - 1; n >= 0; n--) {
+        const Iterator& it = stage->iters[n];
+
+        // If we meet two iterators that come from the same original iterator,
+        // then we are out of the innermost tile
+        size_t size_before = visited_names.size();
+        ExtractOriginalIterators(it->name, &visited_names);
+        if (size_before == visited_names.size()) {
+          break;
+        }
+
+        std::set<std::string> name;
+        ExtractOriginalIterators(it->name, &name);
+        if (name.size() == 1 && to_unroll_name_set.count(*name.begin())) {
+          if (it->annotation == IteratorAnnotation::kNone) {
+            state->unroll(stage_id, it);
+          }
+        }
+      }
+    }
+
+    if (HasReduceIter(stage)) {
+      // Use auto unroll for multi-level tiled stages
+      int value = auto_unroll_configs[(policy->rand_gen)() % auto_unroll_configs.size()];
+      state->pragma(stage_id, (*state)->stages[stage_id]->iters[0],
+                    std::string("auto_unroll_max_step") + "$" + std::to_string(value));
+    }
+  }
+
+  return ResultKind::kValid;
+}
+
+InitPopulationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy,
+                                                        State* state) const {
+  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
+    const Stage& stage = (*state)->stages[stage_id];
+    // Skip the inlined stage and placeholder stage
+    if (stage->compute_at == ComputeAtKind::kInlined || stage->op_type == StageKind::kPlaceholder) {
+      continue;
+    }
+
+    // Try to fuse and vectorize the space iterators in the innermost tile
+    int64_t cum_length_prod = 1;
+
+    int num_fusible = 0;
+    while (num_fusible < static_cast<int>(stage->iters.size())) {
+      int iter_id = static_cast<int>(stage->iters.size()) - 1 - num_fusible;
+      // Stop if this iterator has been a compute_at attach point
+      if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id))) {
+        break;
+      }
+
+      const Iterator& it = stage->iters[iter_id];
+      // Stop if we meet a reduce iterator or an annotated iterator
+      if (it->iter_kind == IteratorKind::kReduction ||
+          it->annotation != IteratorAnnotation::kNone) {
+        break;
+      }
+
+      // Stop if the memory access is no longer contiguous (i.e., not vectorizable).
+      // Note: an exact check is hard, so we use a heuristic here.
+      if (IsTiled(stage) && num_fusible != 0) {
+        // If the stage is tiled, the memory access cannot stay contiguous
+        // across the innermost two iterators
+        break;
+      }
+
+      cum_length_prod *= GetExtent(it);
+      if (cum_length_prod > GetIntParam(policy->params, SketchParamKey::max_vectorize_size)) {
+        break;
+      }
+
+      num_fusible++;
+    }
+
+    if (num_fusible > 1) {
+      // Select a random range to fuse
+      num_fusible = 1 + (policy->rand_gen)() % (num_fusible - 1);
+    }
+
+    if (num_fusible == 1) {
+      state->vectorize(stage_id, stage->iters.back());
+    } else if (num_fusible > 1) {
+      Array<Iterator> to_fuse(stage->iters.end() + (-num_fusible), stage->iters.end());
+      state->vectorize(stage_id, state->fuse(stage_id, to_fuse));
+    }
+  }
+
+  return ResultKind::kValid;
+}
+
+}  // namespace auto_scheduler
+}  // namespace tvm
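InitVectorization's inner loop above bounds the fused vector width by max_vectorize_size, walking from the innermost axis outward. The same accumulation in isolation (extents and limit are illustrative):

    #include <cstdio>
    #include <vector>

    // Sketch of the innermost-fusion bound used by InitVectorization: multiply
    // extents from the innermost loop outward, stopping before the product
    // exceeds `max_vectorize_size`. Returns the number of fusible loops.
    int MaxVectorizableLoops(const std::vector<long>& extents_inner_to_outer,
                             long max_vectorize_size) {
      long prod = 1;
      int n = 0;
      for (long e : extents_inner_to_outer) {
        prod *= e;
        if (prod > max_vectorize_size) break;
        ++n;
      }
      return n;
    }

    int main() {
      // 4 * 4 = 16 <= 16 but 4 * 4 * 8 = 128 > 16, so at most the two innermost
      // loops are fused and vectorized; the rule then picks a random sub-range.
      std::printf("%d\n", MaxVectorizableLoops({4, 4, 8}, 16));  // prints 2
      return 0;
    }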
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h
new file mode 100644
index 0000000000000..dac186ddf81fe
--- /dev/null
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_scheduler/search_policy/sketch_policy_rules.h
+ * \brief Rules defined to generate the sketches and initial sampled states in SketchPolicy.
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_
+#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_
+
+#include <tvm/auto_scheduler/loop_state.h>
+
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace auto_scheduler {
+
+class SketchPolicyNode;
+
+/********** Sketch Generation Rule **********/
+
+/*! \brief The base class for derivation rules used in the sketch generation. */
+class SketchGenerationRule {
+ public:
+  /*! \brief Result enumeration of the condition function. */
+  enum class ConditionKind : int {
+    /*! \brief Skip this rule and continue to try the next rules. */
+    kSkip = 0,
+    /*! \brief Apply this rule and continue to try the next rules. */
+    kApply = 1,
+    /*! \brief Apply this rule and skip the remaining rules. */
+    kApplyAndSkipRest = 2
+  };
+
+  /*!
+   * \brief Condition check function of this rule.
+   * \param policy The SketchPolicyNode of this rule; some of its information may be used
+   * during the condition check.
+   * \param state The original state to be checked.
+   * \param stage_id The index of the stage on which to run this condition check.
+   * \return The condition check result of this rule.
+   */
+  virtual ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
+                                      int stage_id) = 0;
+
+  /*!
+   * \brief Apply function of this rule.
+   * \param policy The SketchPolicyNode of this rule; some of its information may be used
+   * while applying the rule.
+   * \param state The original state to which this rule is applied.
+   * \param stage_id The index of the next stage to apply this rule.
+   * \return The state after applying this rule, and the index of the next stage.
+   */
+  virtual std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy,
+                                                   const State& state, int stage_id) const = 0;
+};
+
+/*! \brief The rule that simply skips the current stage. It returns an unchanged state and moves
+ * to the next stage. */
+class RuleSkipStage : public SketchGenerationRule {
+ public:
+  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
+                              int stage_id) final;
+
+  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
+                                           int stage_id) const final;
+};
+
+/*! \brief The rule that inlines simple elementwise ops.
+ * \note This rule only inlines the strictly inlineable stages. Stages marked as not strictly
+ * inlineable will get a chance to try different compute_at locations in the init population later.
+ */
+class RuleAlwaysInline : public SketchGenerationRule {
+ public:
+  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
+                              int stage_id) final;
+
+  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
+                                           int stage_id) const final;
+};
+
+/*! \brief The rule that performs multi-level tiling. */
+class RuleMultiLevelTiling : public SketchGenerationRule {
+ public:
+  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
+                              int stage_id) final;
+
+  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
+                                           int stage_id) const final;
+};
+
+/*! \brief The rule that performs multi-level tiling and fuses later consumers. */
+class RuleMultiLevelTilingWithFusion : public SketchGenerationRule {
+ public:
+  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
+                              int stage_id) final;
+
+  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
+                                           int stage_id) const final;
+
+ private:
+  int target_stage_id;
+};
+
+/*! \brief The rule that adds a cache write stage. */
+class RuleAddCacheWrite : public SketchGenerationRule {
+ public:
+  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
+                              int stage_id) final;
+
+  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
+                                           int stage_id) const final;
+};
+
+/*! \brief The rule that adds an rfactor stage. */
+class RuleAddRfactor : public SketchGenerationRule {
+ public:
+  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
+                              int stage_id) final;
+
+  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
+                                           int stage_id) const final;
+};
+
+/*! \brief The rule that deals with compute ops that perform "fake reduction" with const tensors.
+ * This kind of op comes from the winograd transformation.
+ */
+class RuleSimplifyComputeWithConstTensor : public SketchGenerationRule {
+ public:
+  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
+                              int stage_id) final;
+
+  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
+                                           int stage_id) const final;
+};
+
+/********** Init Population **********/
+
+/*! \brief The base class for derivation rules used in the initial population. */
+class InitPopulationRule {
+ public:
+  /*! \brief Result enumeration of the apply function. */
+  enum class ResultKind : int { kValid = 0, kInvalid = 1 };
+
+  /*!
+   * \brief Apply function of this rule.
+   * \param policy The SketchPolicyNode of this rule; some of its members may be changed
+   * while applying the rule (e.g., the random number generator).
+   * \param state The state to apply this rule to, updated in place.
+   * \return The result of this rule, indicating whether a valid state was generated.
+   */
+  virtual ResultKind Apply(SketchPolicyNode* policy, State* state) const = 0;
+};
+
+/*! \brief The rule that fills the incomplete SplitSteps. */
+class InitFillTileSize : public InitPopulationRule {
+ public:
+  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
+};
+
+/*! \brief The rule that randomly changes the computation location for some stages that do not
+ * need tiling and are not strictly inlineable (e.g., data padding). */
+class InitChangeComputeLocation : public InitPopulationRule {
+ public:
+  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
+};
+
+/*! \brief The rule that annotates parallelization for CPU. */
+class InitParallel : public InitPopulationRule {
+ public:
+  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
+};
+
+/*! \brief The rule that annotates unrolling. */
+class InitUnroll : public InitPopulationRule {
+ public:
+  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
+};
+
+/*! \brief The rule that annotates vectorization. */
+class InitVectorization : public InitPopulationRule {
+ public:
+  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
+};
+
+}  // namespace auto_scheduler
+}  // namespace tvm
+
+#endif  // TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_
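Both rule interfaces are deliberately small, which keeps the rule sets extensible. A purely hypothetical example, shown only to illustrate the contract (the class name and no-op behavior are invented; only the InitPopulationRule interface above is assumed):

    // Hypothetical rule: leaves the state untouched and reports it as valid.
    class InitNoopExample : public InitPopulationRule {
     public:
      ResultKind Apply(SketchPolicyNode* policy, State* state) const final {
        // A real rule would mutate *state here (see InitParallel, InitUnroll, ...).
        return ResultKind::kValid;
      }
    };

Such a rule would presumably be wired in the same way as the built-in ones, by appending an instance to SketchPolicyNode::init_rules.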
diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc
new file mode 100644
index 0000000000000..6c2e68d09f2d2
--- /dev/null
+++ b/src/auto_scheduler/search_policy/utils.cc
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_scheduler/search_policy/utils.cc
+ * \brief Common utilities for search policies.
+ */
+
+#include "utils.h"
+
+#include <algorithm>
+
+namespace tvm {
+namespace auto_scheduler {
+
+State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
+                         std::vector<int>* spatial_split_step_ids) {
+  // Temporary object to be used when the input pointer is nullptr
+  std::vector<int> temp_split_step_ids;
+  if (spatial_split_step_ids == nullptr) {
+    spatial_split_step_ids = &temp_split_step_ids;
+  }
+  std::vector<std::vector<Iterator>> space_levels;
+  std::vector<std::vector<Iterator>> reduce_levels;
+  std::vector<Iterator> space_outer, space_inner, reduce_outer, reduce_inner;
+  Array<Iterator> split_res;
+
+  for (const auto c : format) {
+    if (tolower(c) == 's') {
+      space_levels.emplace_back();
+    } else if (tolower(c) == 'r') {
+      reduce_levels.emplace_back();
+    } else {
+      LOG(FATAL) << "Invalid multi-level tiling format: " << format;
+    }
+  }
+  size_t n_space = space_levels.size();
+  size_t n_reduce = reduce_levels.size();
+
+  spatial_split_step_ids->clear();
+
+  State tmp_s = state;
+  const Stage& stage = state->stages[stage_id];
+  const std::set<std::string>& no_split_at_inner_name_set =
+      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
+          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
+          : std::set<std::string>();
+
+  for (const auto& iter : state->stages[stage_id]->iters) {
+    if (!no_split_at_inner_name_set.count(iter->name)) {
+      if (iter->iter_kind == IteratorKind::kSpatial) {
+        CHECK_GE(n_space, 1);
+
+        if (n_space == 1) {
+          space_levels[0].push_back(iter);
+        } else {
+          split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_space - 1, NullOpt));
+          for (size_t i = 0; i < n_space; i++) {
+            space_levels[i].push_back(split_res[i]);
+          }
+          spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
+        }
+      } else if (iter->iter_kind == IteratorKind::kReduction) {
+        CHECK_GE(n_reduce, 1);
+
+        if (n_reduce == 1) {
+          reduce_levels[0].push_back(iter);
+        } else {
+          split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_reduce - 1, NullOpt));
+          for (size_t i = 0; i < n_reduce; i++) {
+            reduce_levels[i].push_back(split_res[i]);
+          }
+        }
+      } else {
+        LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
+      }
+    } else {
+      if (iter->iter_kind == IteratorKind::kSpatial) {
+        space_inner.push_back(iter);
+      } else if (iter->iter_kind == IteratorKind::kReduction) {
+        reduce_inner.push_back(iter);
+      } else {
+        LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
+      }
+    }
+  }
+
+  if (!space_outer.empty()) {
+    CHECK(!space_levels.empty());
+    space_levels.front().insert(space_levels.front().begin(),
+                                std::make_move_iterator(space_outer.begin()),
+                                std::make_move_iterator(space_outer.end()));
+  }
+  if (!space_inner.empty()) {
+    CHECK(!space_levels.empty());
+    space_levels.back().insert(space_levels.back().begin(),
+                               std::make_move_iterator(space_inner.begin()),
+                               std::make_move_iterator(space_inner.end()));
+  }
+
+  if (!reduce_outer.empty()) {
+    CHECK(!reduce_levels.empty());
+    reduce_levels.front().insert(reduce_levels.front().begin(),
+                                 std::make_move_iterator(reduce_outer.begin()),
+                                 std::make_move_iterator(reduce_outer.end()));
+  }
+  if (!reduce_inner.empty()) {
+    CHECK(!reduce_levels.empty());
+    reduce_levels.back().insert(reduce_levels.back().begin(),
+                                std::make_move_iterator(reduce_inner.begin()),
+                                std::make_move_iterator(reduce_inner.end()));
+  }
+
+  Array<Iterator> order;
+  int space_ct = 0, reduce_ct = 0;
+  for (const auto c : format) {
+    if (tolower(c) == 's') {
+      order.insert(order.end(), std::make_move_iterator(space_levels[space_ct].begin()),
+                   std::make_move_iterator(space_levels[space_ct].end()));
+      space_ct++;
+    } else if (tolower(c) == 'r') {
+      order.insert(order.end(), std::make_move_iterator(reduce_levels[reduce_ct].begin()),
+                   std::make_move_iterator(reduce_levels[reduce_ct].end()));
+      reduce_ct++;
+    } else {
+      LOG(FATAL) << "Invalid multi-level tiling format: " << format;
+    }
+  }
+
+  tmp_s.reorder(stage_id, order);
+  return tmp_s;
+}
+
+State FollowTiling(const State& state, int stage_id, const std::vector<int>& split_step_ids,
+                   int n_split) {
+  if (n_split < 1 || n_split > 3) {
+    LOG(FATAL) << "Invalid number of split parts; only 1, 2 and 3 are supported";
+  }
+  // Apply up to a three-level tiling structure: space_L0, space_L1, space_L2
+  std::vector<Iterator> space_0, space_1, space_2, space_3, tmp_order;
+  Array<Iterator> split_res;
+
+  auto pop = state->stages[stage_id]->op.as<te::ComputeOpNode>();
+  CHECK(pop != nullptr);
+  const Stage& stage = state->stages[stage_id];
+  const std::set<std::string>& no_split_at_inner_name_set =
+      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
+          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
+          : std::set<std::string>();
+  int no_split_at_inner_name_in_stage_cnt = 0;
+  for (const auto& iter : state->stages[stage_id]->iters) {
+    no_split_at_inner_name_in_stage_cnt += no_split_at_inner_name_set.count(iter->name);
+  }
+
+  CHECK_EQ(state->stages[stage_id]->iters.size() - no_split_at_inner_name_in_stage_cnt,
+           split_step_ids.size());
+
+  State tmp_s = state;
+  int ct = 0;
+  for (const auto& iter : state->stages[stage_id]->iters) {
+    if (iter->iter_kind == IteratorKind::kSpatial) {
+      // For a spatial iterator, split it into multiple iterators
+      if (!no_split_at_inner_name_set.count(iter->name)) {
+        IteratorAnnotation ann_type = iter->annotation;
+        split_res = tmp_s.follow_split(stage_id, iter, split_step_ids[ct], n_split);
+        // Restore the annotation. Move unroll and vectorize to inner, move parallel
+        // to outer
+        switch (ann_type) {
+          case IteratorAnnotation::kUnroll:
+            split_res.Set(n_split, tmp_s.unroll(stage_id, split_res[n_split]));
+            break;
+          case IteratorAnnotation::kVectorize:
+            split_res.Set(n_split, tmp_s.vectorize(stage_id, split_res[n_split]));
+            break;
+          case IteratorAnnotation::kParallel:
+            split_res.Set(0, tmp_s.parallel(stage_id, split_res[0]));
+            break;
+          default:
+            break;
+        }
+
+        space_0.push_back(split_res[0]);
+        space_1.push_back(split_res[1]);
+        if (n_split >= 2) {
+          space_2.push_back(split_res[2]);
+          if (n_split == 3) {
+            space_3.push_back(split_res[3]);
+          }
+        }
+        ct++;
+      } else {
+        if (no_split_at_inner_name_set.count(iter->name)) {
+          if (n_split == 1) {
+            space_1.push_back(iter);
+          } else if (n_split == 2) {
+            space_2.push_back(iter);
+          } else {
+            CHECK_EQ(n_split, 3);
+            space_3.push_back(iter);
+          }
+        }
+      }
+    } else {
+      LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
+    }
+  }
+
+  if (n_split == 3) {
+    ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2, &space_3);
+  } else if (n_split == 2) {
+    ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2);
+  } else {
+    ConcatenateMove(&tmp_order, &space_0, &space_1);
+  }
+  tmp_s.reorder(stage_id, tmp_order);
+  return tmp_s;
+}
+
+const Array<Array<Integer>>& SplitFactorizationMemo::GetFactorizationSchemes(
+    int extent, int n_lengths, int max_innermost_factor) {
+  QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor);
+  auto it = memory_.find(key);
+  if (it != memory_.end()) {
+    return it->second;
+  }
+
+  tmp_stack_ = Array<Integer>(n_lengths, Integer());
+  results_ = &memory_[key];
+  n_lengths_ = n_lengths;
+
+  DfsEnumerate(0, extent, max_innermost_factor);
+
+  return *results_;
+}
+
+void SplitFactorizationMemo::DfsEnumerate(int now, int remaining_length, int max_innermost_factor) {
+  if (now == n_lengths_) {
+    if (tmp_stack_.back().as<IntImmNode>()->value <= max_innermost_factor) {
+      results_->push_back(tmp_stack_);
+    }
+  } else {
+    for (const auto& f : GetFactors(remaining_length)) {
+      tmp_stack_.Set(now, Integer(f));
+      DfsEnumerate(now + 1, remaining_length / f, max_innermost_factor);
+    }
+  }
+}
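For intuition, DfsEnumerate walks divisor chains: at each of the n_lengths positions it picks a factor of the remaining extent and keeps a scheme when its innermost factor fits within max_innermost_factor (note it does not require the factors to multiply exactly to the extent). A self-contained re-implementation without the memoization, restricted to exact factorizations for brevity:

    #include <cstdio>
    #include <vector>

    // Enumerate ordered factorizations of `extent` into `n` lengths whose
    // product equals `extent`, keeping only schemes whose last (innermost)
    // factor does not exceed `max_innermost`.
    void Dfs(int pos, int n, int remaining, int max_innermost, std::vector<int>* cur,
             std::vector<std::vector<int>>* out) {
      if (pos == n) {
        if (remaining == 1 && cur->back() <= max_innermost) out->push_back(*cur);
        return;
      }
      for (int f = 1; f <= remaining; ++f) {
        if (remaining % f != 0) continue;
        (*cur)[pos] = f;
        Dfs(pos + 1, n, remaining / f, max_innermost, cur, out);
      }
    }

    int main() {
      std::vector<int> cur(2);
      std::vector<std::vector<int>> out;
      Dfs(0, 2, 8, 4, &cur, &out);
      for (const auto& s : out) std::printf("[%d, %d]\n", s[0], s[1]);
      // Prints [2, 4], [4, 2] and [8, 1], but not [1, 8], whose innermost
      // factor exceeds the limit of 4.
      return 0;
    }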
+
+const std::vector<int>& SplitFactorizationMemo::GetFactors(int n) {
+  auto it = factor_memory_.find(n);
+  if (it != factor_memory_.end()) {
+    return it->second;
+  }
+
+  std::vector<int>& res = factor_memory_[n];
+  int step = n % 2 == 0 ? 1 : 2;
+  for (size_t i = 1; i < static_cast<size_t>(std::sqrt(n)) + 1; i += step) {
+    if (n % i == 0) {
+      res.push_back(i);
+      if (n / i != i) {
+        res.push_back(n / i);
+      }
+    }
+  }
+  std::sort(res.begin(), res.end());
+  return res;
+}
+
+}  // namespace auto_scheduler
+}  // namespace tvm
diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h
new file mode 100644
index 0000000000000..6814d258c77f4
--- /dev/null
+++ b/src/auto_scheduler/search_policy/utils.h
@@ -0,0 +1,484 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file auto_scheduler/search_policy/utils.h
+ * \brief Common utilities for search policies.
+ */
+
+#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_
+#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_
+
+#include <tvm/auto_scheduler/compute_dag.h>
+#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/search_policy.h>
+#include <tvm/te/operation.h>
+#include <tvm/tir/expr.h>
+
+#include <algorithm>
+#include <random>
+#include <set>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace auto_scheduler {
+
+/*! \brief Argsort. Order: largest to smallest */
+template <typename T>
+inline std::vector<int> Argsort(const std::vector<T>& scores) {
+  std::vector<int> index;
+  index.reserve(scores.size());
+  for (size_t i = 0; i < scores.size(); ++i) {
+    index.push_back(i);
+  }
+  auto cmp = [&scores](int l, int r) { return scores[l] > scores[r]; };
+  std::sort(index.begin(), index.end(), cmp);
+  return index;
+}
+
+/*! \brief Convert an operation to its stage id. */
+inline int OperationToStage(const te::Operation& op, const State& state) {
+  for (size_t i = 0; i < state->stages.size(); ++i) {
+    if (op == state->stages[i]->op) {
+      return i;
+    }
+  }
+  LOG(FATAL) << "Cannot find op: " << op;
+  return -1;
+}
+
+/********** Get Parameters **********/
+
+/*! \brief Get an integer from a tvm str Map. */
+inline int GetIntParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
+  CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
+  auto pint = attr_dict[key].as<IntImmNode>();
+  CHECK(pint != nullptr);
+  return pint->value;
+}
+
+/*! \brief Get a double from a tvm str Map. */
+inline double GetDoubleParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
+  CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
+  auto pdouble = attr_dict[key].as<FloatImmNode>();
+  CHECK(pdouble != nullptr);
+  return pdouble->value;
+}
+
+/*! \brief Get a string from a tvm str Map. */
+inline std::string GetStringParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
+  CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
+  const auto& target = attr_dict[key];
+  if (auto pstr = target.as<tir::StringImmNode>()) {
+    return pstr->value;
+  }
+  auto pstr = target.as<StringObj>();
+  CHECK(pstr != nullptr);
+  return pstr->data;
+}
+
+/*! \brief Get an iterator name set from a tvm str Map. */
+inline std::set<std::string> GetIterNameSetParam(const Map<String, ObjectRef>& attr_dict,
+                                                 const std::string& key) {
+  std::set<std::string> ret;
+  CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
+  auto names = attr_dict[key].as<ArrayNode>();
+  CHECK(names != nullptr);
+  for (const auto& name : *names) {
+    ret.insert(name.as<StringObj>()->data);
+  }
+  return ret;
+}
+
+/********** Checks with ComputeDAG **********/
+
+/*! \brief Return whether an op is strictly inlineable. */
+inline bool IsStrictlyInlineable(const SearchTask& task, const State& state, int stage_id) {
+  if (state->current_compute_dag) {
+    return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.IsStrictlyInlineable(
+        state->stages[stage_id]->op);
+  } else {
+    return task->compute_dag->access_analyzer.IsStrictlyInlineable(state->stages[stage_id]->op);
+  }
+}
+
+/*! \brief Return whether an op is an output op. */
+inline bool IsOutputOp(const SearchTask& task, const State& state, int stage_id) {
+  if (state->current_compute_dag) {
+    return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.IsOutput(
+        state->stages[stage_id]->op);
+  } else {
+    return task->compute_dag->access_analyzer.IsOutput(state->stages[stage_id]->op);
+  }
+}
+
+/*! \brief Return whether an op needs multi-level tiling. */
+inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state, int stage_id) {
+  if (state->current_compute_dag) {
+    return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.NeedsMultiLevelTiling(
+        state->stages[stage_id]->op);
+  } else {
+    return task->compute_dag->access_analyzer.NeedsMultiLevelTiling(state->stages[stage_id]->op);
+  }
+}
+
+/*! \brief Get all consumers for a stage. This function propagates the relation for inlined ops. */
+inline std::set<int> GetConsumers(const SearchTask& task, const State& state, int stage_id) {
+  std::unordered_set<te::Operation, ObjectHash, ObjectEqual> consumers;
+  std::set<int> ret;
+
+  if (state->current_compute_dag) {
+    consumers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetConsumers(
+        state, state->stages[stage_id]->op);
+  } else {
+    consumers = task->compute_dag->access_analyzer.GetConsumers(state, state->stages[stage_id]->op);
+  }
+
+  for (const auto& op : consumers) {
+    ret.insert(OperationToStage(op, state));
+  }
+  return ret;
+}
+
+/*! \brief Check whether a stage has a single consumer or all of its consumers share a common
+ * root; return the target consumer root or -1. */
+inline int GetSingleConsumerId(const SearchTask& task, const State& state, int stage_id) {
+  const std::set<int>& consumers = GetConsumers(task, state, stage_id);
+  if (consumers.empty()) {
+    return -1;
+  }
+
+  if (consumers.size() == 1) {
+    return *consumers.begin();
+  } else {
+    // Check that all consumers share a common root
+    int common_root_id = -1;
+    bool mismatch = false;
+    for (const auto& consumer_stage_id : consumers) {
+      int root_id = -1;
+      if (state->stages[consumer_stage_id]->compute_at == ComputeAtKind::kRoot) {
+        root_id = consumer_stage_id;
+      } else if (state->stages[consumer_stage_id]->compute_at == ComputeAtKind::kIter) {
+        root_id = state->attach_map->stage_to_attach_iter.at(consumer_stage_id).first;
+      } else {
+        LOG(FATAL) << "Invalid case";
+      }
+
+      if (common_root_id == -1) {
+        common_root_id = root_id;
+      } else {
+        if (common_root_id != root_id) {
+          mismatch = true;
+          break;
+        }
+      }
+    }
+
+    return mismatch ? -1 : common_root_id;
+  }
+}
+
+/*! \brief Get all producers for a stage. This function propagates the relation for inlined ops. */
+inline std::set<int> GetProducers(const SearchTask& task, const State& state, int stage_id) {
+  std::unordered_set<te::Operation, ObjectHash, ObjectEqual> producers;
+  std::set<int> ret;
+
+  if (state->current_compute_dag) {
+    producers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetProducers(
+        state, state->stages[stage_id]->op);
+  } else {
+    producers = task->compute_dag->access_analyzer.GetProducers(state, state->stages[stage_id]->op);
+  }
+
+  for (const auto& op : producers) {
+    ret.insert(OperationToStage(op, state));
+  }
+  return ret;
+}
+
+/*! \brief Get all producers for a stage. This function DOES NOT propagate the relation for
+ * inlined ops. */
+inline std::set<int> GetDirectProducers(const SearchTask& task, const State& state, int stage_id) {
+  std::unordered_set<te::Operation, ObjectHash, ObjectEqual> producers;
+  std::set<int> ret;
+
+  if (state->current_compute_dag) {
+    producers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetDirectProducers(
+        state->stages[stage_id]->op);
+  } else {
+    producers = task->compute_dag->access_analyzer.GetDirectProducers(state->stages[stage_id]->op);
+  }
+
+  for (const auto& op : producers) {
+    ret.insert(OperationToStage(op, state));
+  }
+  return ret;
+}
+
+/*! \brief Get the number of common outer iterators. This function propagates the relation for
+ * chains with multiple ops. */
+inline int GetNumCommonOuterIterator(const SearchTask& task, const State& state, int stage_id,
+                                     int target_stage_id) {
+  if (state->current_compute_dag) {
+    return state->current_compute_dag.as<ComputeDAGNode>()
+        ->access_analyzer.GetNumCommonOuterIterator(state->stages[stage_id]->op,
+                                                    state->stages[target_stage_id]->op);
+  } else {
+    return task->compute_dag->access_analyzer.GetNumCommonOuterIterator(
+        state->stages[stage_id]->op, state->stages[target_stage_id]->op);
+  }
+}
+
+/*! \brief Return whether two ops are elementwise-matched. */
+inline bool ElementwiseMatch(const SearchTask& task, const State& state, int stage_id,
+                             int target_stage_id) {
+  const auto& op = state->stages[stage_id]->op;
+  const auto& target_op = state->stages[target_stage_id]->op;
+  if (state->current_compute_dag) {
+    return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.ElementWiseMatch(
+        op, target_op);
+  } else {
+    return task->compute_dag->access_analyzer.ElementWiseMatch(op, target_op);
+  }
+}
+
+/********** Get information from Stage/Iterator **********/
+
+/*! \brief Return the extent of an iterator. */
+inline int64_t GetExtent(const Iterator& it) {
+  if (it->range.defined()) {
+    if (auto pint = it->range->extent.as<IntImmNode>()) {
+      return pint->value;
+    }
+  }
+  return -1;
+}
+
+/*! \brief Compute the products of the lengths of all space iters and all reduce iters,
+ * respectively. */
+inline std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const Stage& stage) {
+  int64_t cum_space_len = 1, cum_reduce_len = 1;
+  for (const auto& iter : stage->iters) {
+    if (iter->iter_kind == IteratorKind::kSpatial) {
+      cum_space_len *= GetExtent(iter);
+    } else if (iter->iter_kind == IteratorKind::kReduction) {
+      cum_reduce_len *= GetExtent(iter);
+    }
+  }
+  return std::make_pair(cum_space_len, cum_reduce_len);
+}
+
+/*! \brief Return whether this stage needs rfactor. */
+inline bool NeedsRfactor(const SearchTask& task, const State& state, int stage_id) {
+  const auto& op = state->stages[stage_id]->op;
+  if (op->IsInstance<te::ComputeOpNode>()) {
+    // Compute the products of the lengths of all space iters and all reduce iters
+    int cum_space_len, cum_reduce_len;
+    std::tie(cum_space_len, cum_reduce_len) =
+        GetCumulativeSpaceAndReductionLength(state->stages[stage_id]);
+
+    if (NeedsMultilevelTiling(task, state, stage_id)) {
+      // Do not use rfactor if we have enough parallelism on space iters
+      if (cum_space_len > cum_reduce_len || cum_space_len > task->hardware_params->num_cores * 16) {
+        return false;
+      } else {
+        return true;
+      }
+    } else if (cum_reduce_len > 1) {
+      // Always try rfactor for reduction ops
+      return cum_reduce_len > task->hardware_params->num_cores;
+    }
+  }
+
+  return false;
+}
+
+/*! \brief Return whether the stage has reduce iterators. */
+inline bool HasReduceIter(const Stage& stage) {
+  for (const auto& iter : stage->iters) {
+    if (iter->iter_kind != IteratorKind::kSpatial) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*! \brief Return whether the stage has iterators with the specific annotation. */
+inline bool HasAnnotatedIter(const Stage& stage, IteratorAnnotation type) {
+  for (const auto& iter : stage->iters) {
+    if (iter->annotation == type) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*! \brief Return whether the stage has only one consumer and they are elementwise-matched. */
+inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task, const State& state,
+                                                int stage_id, int* target_stage_id = nullptr) {
+  // Temporary object to be used when the input pointer is nullptr
+  int temp_target_stage_id;
+  if (target_stage_id == nullptr) {
+    target_stage_id = &temp_target_stage_id;
+  }
+  const std::set<int>& consumers = GetConsumers(task, state, stage_id);
+  if (consumers.size() == 1) {
+    *target_stage_id = *consumers.begin();
+    if (ElementwiseMatch(task, state, stage_id, *target_stage_id) &&
+        (!(HasReduceIter(state->stages[stage_id]) &&
+           HasReduceIter(state->stages[*target_stage_id])))) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*! \brief Return whether the state does cache_write for stage_id. */
+inline bool HasCacheWriteStage(const State& s, int stage_id) {
+  for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
+    if (auto ps = s->transform_steps[i].as<CacheWriteStepNode>()) {
+      if (stage_id == ps->stage_id) {
+        return true;
+      }
+    }
+
+    if (s->transform_steps[i]->IsInstance<CacheWriteStepNode>() ||
+        s->transform_steps[i]->IsInstance<CacheReadStepNode>() ||
+        s->transform_steps[i]->IsInstance<RfactorStepNode>()) {
+      if (stage_id > s->transform_steps[i]->stage_id) {
+        stage_id--;
+      }
+    }
+  }
+  return false;
+}
+
+/*! \brief Return whether the stage has been tiled already. */
+inline bool IsTiled(const Stage& stage) {
+  auto op = stage->op.as<te::ComputeOpNode>();
+  CHECK(op != nullptr);
+  return stage->iters.size() != op->axis.size() + op->reduce_axis.size();
+}
+
+/*! \brief Extract primitive iterators from a nested fused or split iterator's name. */
+inline void ExtractOriginalIterators(const std::string& name, std::set<std::string>* rets) {
+  size_t last_pos = 0;
+  for (size_t i = 0; i < name.size(); ++i) {
+    if (name[i] == '@' || name[i] == '.') {  // '@' for fuse and '.' for split
+      if (!isdigit(name[last_pos]) && name[last_pos] != '@' && name[last_pos] != '.') {
+        rets->insert(name.substr(last_pos, i - last_pos));
+      }
+      last_pos = i + 1;
+    }
+  }
+
+  if (last_pos < name.size() && !isdigit(name[last_pos]) && name[last_pos] != '@' &&
+      name[last_pos] != '.') {
+    rets->insert(name.substr(last_pos, name.size() - last_pos));
+  }
+}
+
+/*! \brief Fuse all reduction iterators. */
+inline State FuseAllReductionIterators(const State& state, int stage_id, Iterator* fused_iter,
+                                       Array<Iterator>* space_iters,
+                                       Array<Iterator>* reduce_iters) {
+  space_iters->clear();
+  reduce_iters->clear();
+
+  for (const auto& iter : state->stages[stage_id]->iters) {
+    if (iter->iter_kind == IteratorKind::kSpatial) {
+      space_iters->push_back(iter);
+    } else if (iter->iter_kind == IteratorKind::kReduction) {
+      reduce_iters->push_back(iter);
+    }
+  }
+
+  CHECK(!reduce_iters->empty());
+  State tmp_s = state;
+  if (reduce_iters->size() > 1) {
+    *fused_iter = tmp_s.fuse(stage_id, *reduce_iters);
+  } else {
+    *fused_iter = (*reduce_iters)[0];
+  }
+  return tmp_s;
+}
+
+/*! \brief Randomly sample states. */
+inline Array<State> RandomSampleStates(const Array<State>& in_states, std::mt19937* random_gen,
+                                       size_t out_size) {
+  Array<State> out_states;
+  for (size_t i = 0; i < out_size; i++) {
+    out_states.push_back(in_states[(*random_gen)() % in_states.size()]);
+  }
+  return out_states;
+}
+
+/*! \brief Print a title */
+inline void PrintTitle(const std::string& title, int verbose) {
+  StdCout(verbose) << Chars('-', 60) << "\n"
+                   << Chars('-', 25) << " [ " << title << " ]\n"
+                   << Chars('-', 60) << std::endl;
+}
+
+/*!
+ * \brief Enumerate all possible factorization schemes for splitting an axis.
+ * \note This class will memorize the results for reuse.
+ */
+class SplitFactorizationMemo {
+ public:
+  using QueryKey = std::tuple<int, int, int>;
+
+  const Array<Array<Integer>>& GetFactorizationSchemes(int extent, int n_lengths,
+                                                       int max_innermost_factor);
+  const std::vector<int>& GetFactors(int n);
+
+ private:
+  void DfsEnumerate(int now, int remaining_length, int max_innermost_factor);
+
+  std::unordered_map<QueryKey, Array<Array<Integer>>> memory_;
+
+  int n_lengths_;
+  Array<Integer> tmp_stack_;
+  Array<Array<Integer>>* results_;
+  std::unordered_map<int, std::vector<int>> factor_memory_;
+};
+
+// Apply a multi-level tiling structure according to a string format,
+// where "S" stands for a space level and "R" stands for a reduction level.
+// For example, if the format is "SSRSRS", then we will
+// use the tiling structure: space_L0, space_L1, reduce_L0, space_L2, reduce_L1, space_L3.
+// For example, if we apply "SSRSRS" to matrix multiplication,
+// we have the space iterators i and j and the reduce iterator k.
+// Then the tiling structure is: i0, j0, i1, j1, k0, i2, j2, k1, i3, j3
+State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
+                         std::vector<int>* spatial_split_step_ids = nullptr);
+
+// Apply tiling structure: space, space, space, ..., with tile sizes from another SplitStep
+State FollowTiling(const State& state, int stage_id, const std::vector<int>& split_step_ids,
+                   int n_split);
+
+}  // namespace auto_scheduler
+}  // namespace tvm
+
+#endif  // TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_
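To make the "SSRSRS" comment above concrete, here is the loop nest it describes for C[i, j] = sum over k of A[i, k] * B[k, j], written out as plain C++ with hypothetical tile sizes (i and j each split 2x2x4x4, k split 8x8; the sizes are illustrative, not values the policy would necessarily pick):

    #include <cstdio>

    float A[64][64], B[64][64], C[64][64];

    int main() {
      for (int i0 = 0; i0 < 2; ++i0)
        for (int j0 = 0; j0 < 2; ++j0)              // space_L0
          for (int i1 = 0; i1 < 2; ++i1)
            for (int j1 = 0; j1 < 2; ++j1)          // space_L1
              for (int k0 = 0; k0 < 8; ++k0)        // reduce_L0
                for (int i2 = 0; i2 < 4; ++i2)
                  for (int j2 = 0; j2 < 4; ++j2)    // space_L2
                    for (int k1 = 0; k1 < 8; ++k1)  // reduce_L1
                      for (int i3 = 0; i3 < 4; ++i3)
                        for (int j3 = 0; j3 < 4; ++j3) {  // space_L3
                          // Reconstruct the original indices from the tiles.
                          int i = ((i0 * 2 + i1) * 4 + i2) * 4 + i3;
                          int j = ((j0 * 2 + j1) * 4 + j2) * 4 + j3;
                          int k = k0 * 8 + k1;
                          C[i][j] += A[i][k] * B[k][j];
                        }
      std::printf("%f\n", C[0][0]);
      return 0;
    }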
diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h
index aacdcf4265f9e..b3fa2dc6d9dca 100644
--- a/src/auto_scheduler/utils.h
+++ b/src/auto_scheduler/utils.h
@@ -32,6 +32,8 @@
 #include <exception>
 #include <future>
 #include <iomanip>
+#include <numeric>
+#include <random>
 #include <string>
 #include <thread>
 #include <tuple>
@@ -98,6 +100,38 @@ inline void FindAndDeleteItem(std::vector<T>* array, const T& to_delete) {
   }
 }
 
+/*! \brief Compute the product of all elements in a vector */
+inline int64_t ElementProduct(const std::vector<int>& array) {
+  int64_t ret = 1;
+  for (auto x : array) {
+    ret *= x;
+  }
+  return ret;
+}
+
+/*! \brief Move elements from multiple vectors to one vector */
+template <typename T>
+std::vector<T>& ConcatenateMove(std::vector<T>* out, std::vector<T>* in) {
+  out->insert(out->end(), std::make_move_iterator(in->begin()),
+              std::make_move_iterator(in->end()));
+  return *out;
+}
+
+/*! \brief Move elements from multiple vectors to one vector */
+template <typename T, typename... Args>
+std::vector<T>& ConcatenateMove(std::vector<T>* out, std::vector<T>* first, Args... args) {
+  ConcatenateMove(out, first);
+  ConcatenateMove(out, args...);
+  return *out;
+}
+
+/*! \brief Get a random permutation of integers [0, n-1] */
+template <typename G>
+void RandomPermutation(int n, std::vector<int>* out, G* gen) {
+  out->assign(n, 0);
+  std::iota(out->begin(), out->end(), 0);
+  std::shuffle(out->begin(), out->end(), *gen);
+}
+
 /*! \brief Replace a sub-string to another sub-string in a string */
 inline void StrReplace(std::string* base, const std::string& from, const std::string& to) {
   auto pos = base->find(from);
@@ -168,6 +202,12 @@ inline bool StrStartsWith(const String& a, const String& b) {
   return std::equal(a.c_str(), a.c_str() + b.size(), b.c_str());
 }
 
+/*! \brief Return whether a string ends with another substring */
+inline bool StrEndsWith(const String& a, const String& b) {
+  if (b.size() > a.size()) return false;
+  return std::equal(a.c_str() + a.size() - b.size(), a.c_str() + a.size(), b.c_str());
+}
+
 /********** Other Utilities **********/
 /*! \brief Get an int value from an Expr */
 inline int64_t GetIntImm(const PrimExpr& expr) {
@@ -230,13 +270,6 @@ inline std::string Chars(const char& str, int times) {
   return ret.str();
 }
 
-/*! \brief Print a title */
-inline void PrintTitle(const std::string& title, int verbose) {
-  StdCout(verbose) << Chars('-', 60) << "\n"
-                   << Chars('-', 25) << " [ " << title << " ]\n"
-                   << Chars('-', 60) << std::endl;
-}
-
 }  // namespace auto_scheduler
 }  // namespace tvm
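A small usage sketch for the variadic ConcatenateMove added above (the values are illustrative):

    #include <string>
    #include <vector>

    // Assuming the ConcatenateMove overloads declared above are visible here.
    void Example() {
      std::vector<std::string> out, a{"i0", "j0"}, b{"k0"}, c{"i1", "j1"};
      ConcatenateMove(&out, &a, &b, &c);
      // out == {"i0", "j0", "k0", "i1", "j1"}; a, b and c keep their sizes but
      // now hold moved-from (unspecified) strings.
    }

The variadic overload simply recurses pairwise, which is why FollowTiling above can pass two, three or four level vectors with the same call.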
\brief Print a title */ -inline void PrintTitle(const std::string& title, int verbose) { - StdCout(verbose) << Chars('-', 60) << "\n" - << Chars('-', 25) << " [ " << title << " ]\n" - << Chars('-', 60) << std::endl; -} - } // namespace auto_scheduler } // namespace tvm diff --git a/tests/cpp/auto_scheduler_test.cc b/tests/cpp/auto_scheduler_test.cc index b2de0ef6bf4de..aacc3b154463e 100644 --- a/tests/cpp/auto_scheduler_test.cc +++ b/tests/cpp/auto_scheduler_test.cc @@ -95,9 +95,9 @@ TEST(ComputeDAG, AccessAnalyzer) { std::set is_strictly_inlinable = {bias_add, bn_mul, bn_add, relu}; for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) { if (is_strictly_inlinable.count(stage_id)) { - CHECK(dag->access_analyzer.IsStrictInlineable(dag->ops[stage_id])); + CHECK(dag->access_analyzer.IsStrictlyInlineable(dag->ops[stage_id])); } else { - CHECK(!dag->access_analyzer.IsStrictInlineable(dag->ops[stage_id])); + CHECK(!dag->access_analyzer.IsStrictlyInlineable(dag->ops[stage_id])); } } diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py index b67178ec43707..f2db8d1f391d7 100644 --- a/tests/python/unittest/test_auto_scheduler_common.py +++ b/tests/python/unittest/test_auto_scheduler_common.py @@ -19,8 +19,11 @@ import threading +import tvm from tvm import te, auto_scheduler from tvm import topi +from tvm.topi.nn.winograd_util import winograd_transform_matrices +from tvm.topi.util import get_const_tuple @auto_scheduler.register_workload @@ -32,6 +35,7 @@ def matmul_auto_scheduler_test(N, M, K): return [A, B, C] +# Test for register_workload with different name @auto_scheduler.register_workload("matmul_auto_scheduler_test_rename_1") def matmul_auto_scheduler_test_rename_0(N, M, K): A = te.placeholder((N, K), name='A') @@ -40,8 +44,9 @@ def matmul_auto_scheduler_test_rename_0(N, M, K): C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C') return [A, B, C] + @auto_scheduler.register_workload -def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): +def conv2d_nchw_bn_relu_auto_scheduler_test(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): data = te.placeholder((N, CI, H, W), name='Data') kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='Kernel') bias = te.placeholder((CO, 1, 1), name='Bias') @@ -66,6 +71,111 @@ def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation return [data, kernel, bias, bn_offset, bn_scale, out] +@auto_scheduler.register_workload +def max_pool2d_auto_scheduler_test(N, H, W, CI, padding): + data = te.placeholder((N, CI, H, W), name='Data') + out = topi.nn.pool(data, [2, 2], [1, 1], [padding, padding, padding, padding], 'max') + + return [data, out] + + +@auto_scheduler.register_workload +def min_nm_auto_scheduler_test(N, M): + A = te.placeholder((N, M), name='A') + B = topi.min(A, axis=-1) + + return [A, B] + + +@auto_scheduler.register_workload +def softmax_nm_auto_scheduler_test(N, M): + A = te.placeholder((N, M), name='A') + B = topi.nn.softmax(A, axis=1) + + return [A, B] + + +@auto_scheduler.register_workload +def softmax_abcd_auto_scheduler_test(a, b, c, d): + A = te.placeholder((a, b, c, d), name='A') + B = topi.nn.softmax(A, axis=-1) + + return [A, B] + + +@auto_scheduler.register_workload +def conv2d_winograd_nhwc_auto_scheduler_test(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, dilation=1): + tile_size = 4 + inputs = te.placeholder((N, H, W, CI), name='inputs') + N, H, W, CI = 
get_const_tuple(inputs.shape)
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation"
+
+    KH = KW = kernel_size
+    HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW))
+    HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride
+    assert HSTR == 1 and WSTR == 1 and KH == KW
+
+    data_pad = topi.nn.pad(inputs, (0, HPAD, WPAD, 0), (0, HPAD, WPAD, 0), name="data_pad")
+
+    r = KW
+    m = tile_size
+    alpha = m + r - 1
+    A, B, G = winograd_transform_matrices(m, r, 'float32')
+
+    H = (H + 2 * HPAD - KH) // HSTR + 1
+    W = (W + 2 * WPAD - KW) // WSTR + 1
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
+    P = N * nH * nW
+    r_kh = te.reduce_axis((0, KH), name='r_kh')
+    r_kw = te.reduce_axis((0, KW), name='r_kw')
+    kshape = (alpha, alpha, CI, CO)
+    kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight")
+
+    idxdiv = te.indexdiv
+    idxmod = te.indexmod
+    # pack input tile
+    input_tile = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci:
+                            data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, nW), nH) * m + eps]
+                                    [idxmod(p, nW) * m + nu][ci], name='input_tile')
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci:
+                           te.sum(input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu],
+                                  axis=[r_a, r_b]), name='data_pack',
+                           attrs={"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"]})
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co:
+                       te.sum(data_pack[eps][nu][p][ci] *
+                              kernel_pack[eps][nu][ci][co],
+                              axis=[ci]), name='bgemm')
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    inverse = te.compute((m, m, P, CO), lambda vh, vw, p, co:
+                         te.sum(bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw],
+                                axis=[r_a, r_b]), name='inverse',
+                         attrs={"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"]})
+
+    # output
+    output = te.compute((N, H, W, CO), lambda n, h, w, co:
+                        inverse[idxmod(h, m),
+                                idxmod(w, m),
+                                n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
+                                co],
+                        name='conv2d_winograd')
+
+    return [inputs, kernel_pack, output]
+
 def get_tiled_matmul():
     A, B, C = matmul_auto_scheduler_test(512, 512, 512)
     dag = auto_scheduler.ComputeDAG([A, B, C])
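Because conv2d_winograd_nhwc_auto_scheduler_test is registered like any other workload, its ComputeDAG can be built from a workload key and inspected directly, which is a quick sanity check when adding workloads like the ones above. A sketch (the argument tuple mirrors the winograd sketch test later in this patch):

    import tvm
    from tvm import auto_scheduler
    from test_auto_scheduler_common import conv2d_winograd_nhwc_auto_scheduler_test

    key = auto_scheduler.make_workload_key(conv2d_winograd_nhwc_auto_scheduler_test,
                                           (1, 28, 28, 128, 128, 3, 1, 1))
    dag = auto_scheduler.ComputeDAG(key)
    print(dag)                   # the compute declaration, stage by stage
    print(dag.get_init_state())  # the initial, untransformed loop nest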
diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py
index a051e81894239..aeed42089de07 100644
--- a/tests/python/unittest/test_auto_scheduler_loop_state.py
+++ b/tests/python/unittest/test_auto_scheduler_loop_state.py
@@ -23,7 +23,7 @@
 from tvm import auto_scheduler, te
 from tvm import topi
 
-from test_auto_scheduler_common import matmul_auto_scheduler_test, conv2d_nchw_bn_relu
+from test_auto_scheduler_common import matmul_auto_scheduler_test, conv2d_nchw_bn_relu_auto_scheduler_test
 
 
 def test_split_fuse_reorder_annotation():
@@ -86,8 +86,9 @@ def test_split_fuse_reorder_annotation():
     assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vectorize"]
 
 def test_compute_at_root_inline():
-    dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(N=1, H=224, W=224, CI=3, CO=64,
-                                                        kernel_size=7, strides=2, padding=3))
+    dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu_auto_scheduler_test(N=1, H=224, W=224, CI=3,
+                                                                            CO=64, kernel_size=7, strides=2,
+                                                                            padding=3))
     s0 = dag.get_init_state()
 
     # data, padding, kernel = 0, 1, 2
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 8dfc07865ea8a..5dfc649449797 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -27,9 +27,9 @@
 from test_auto_scheduler_common import matmul_auto_scheduler_test, PropagatingThread
 
 def search_common(workload=matmul_auto_scheduler_test, target="llvm",
-                  search_policy=auto_scheduler.EmptyPolicy(),
-                  seed=random.randint(1, 1 << 30), runner='local', cost_model=None,
-                  num_measure_trials=2, params=None, pre_search_callbacks=None):
+                  search_policy='empty', seed=random.randint(1, 1 << 30), runner='local',
+                  cost_model=auto_scheduler.RandomModel(), num_measure_trials=2,
+                  init_search_callbacks=None):
-    print("Test %s schedule search with the default search policy" % (target))
+    print("Test %s schedule search with the %s search policy" % (target, search_policy))
 
     random.seed(seed)
@@ -39,13 +39,17 @@ def search_common(workload=matmul_auto_scheduler_test, target="llvm",
     target = tvm.target.create(target)
     task = auto_scheduler.SearchTask(dag, workload_key, target)
 
+    if search_policy == 'empty':
+        search_policy = auto_scheduler.EmptyPolicy(task)
+    elif search_policy == 'sketch':
+        search_policy = auto_scheduler.SketchPolicy(task,
+                                                    init_search_callbacks=init_search_callbacks)
+
     with tempfile.NamedTemporaryFile() as fp:
         log_file = fp.name
 
-        tuning_options = auto_scheduler.TuningOptions(num_measure_trials=num_measure_trials, runner=runner,
-                                                      verbose=0,
-                                                      measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
-                                                      pre_search_callbacks=pre_search_callbacks)
+        tuning_options = auto_scheduler.TuningOptions(num_measure_trials=num_measure_trials,
+            runner=runner, verbose=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)])
         sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options)
         inp, res = auto_scheduler.load_best(log_file, workload_key, target)
@@ -88,5 +92,18 @@ def test_workload_registry_search_basic():
     t.start()
     t.join()
 
+
+def test_sketch_search_policy_basic():
+    if not tvm.runtime.enabled("llvm"):
+        return
+    # Wrap the search in a new thread to avoid conflicts between
+    # Python's multiprocessing and TVM's thread pool.
+    t = PropagatingThread(target=search_common,
+                          kwargs={'seed': 944563397, 'search_policy': 'sketch'})
+    t.start()
+    t.join()
+
+
 if __name__ == "__main__":
     test_workload_registry_search_basic()
+    test_sketch_search_policy_basic()
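Outside the test harness, the same SketchPolicy flow reads as below. This is a minimal sketch: the trial count and log-file name are placeholders, and the policy's cost model is left at its default.

    import tvm
    from tvm import auto_scheduler
    from test_auto_scheduler_common import matmul_auto_scheduler_test

    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (512, 512, 512))
    dag = auto_scheduler.ComputeDAG(workload_key)
    task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.create("llvm"))

    policy = auto_scheduler.SketchPolicy(task)
    tuning_options = auto_scheduler.TuningOptions(
        num_measure_trials=4,  # placeholder; real tuning uses far more trials
        measure_callbacks=[auto_scheduler.RecordToFile("matmul_tuning.json")])

    sch, args = auto_scheduler.auto_schedule(task, policy, tuning_options)
    print(tvm.lower(sch, args, simple_mode=True))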
diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
new file mode 100644
index 0000000000000..4ef0cbc7d957e
--- /dev/null
+++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Test sketch generation. """
+
+import tvm
+from tvm import te, auto_scheduler
+
+from test_auto_scheduler_common import (matmul_auto_scheduler_test, conv2d_nchw_bn_relu_auto_scheduler_test,
+                                        max_pool2d_auto_scheduler_test, min_nm_auto_scheduler_test,
+                                        softmax_nm_auto_scheduler_test, softmax_abcd_auto_scheduler_test,
+                                        conv2d_winograd_nhwc_auto_scheduler_test)
+
+def generate_sketches(workload_func, args, target, print_for_debug=False):
+    workload_key = auto_scheduler.make_workload_key(workload_func, args)
+    dag = auto_scheduler.ComputeDAG(workload_key)
+    task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.create(target))
+    policy = auto_scheduler.SketchPolicy(task, verbose=0)
+    return policy.generate_sketches(print_for_debug)
+
+def test_cpu_matmul_sketch():
+    sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), 'llvm')
+    ''' 3 multi-level tiling sketches
+        0 - Multi-level tiling
+        1 - Multi-level tiling with cache write on position 0
+        2 - Multi-level tiling with cache write on position 1
+    '''
+    assert len(sketches) == 3
+
+    sketches = generate_sketches(matmul_auto_scheduler_test, (8, 8, 512), 'llvm')
+    ''' 2 rfactor sketches + 3 multi-level tiling sketches
+        0 - Rfactor with factor position 0
+        1 - Rfactor with factor position 1
+        2 - Multi-level tiling
+        3 - Multi-level tiling with cache write on position 0
+        4 - Multi-level tiling with cache write on position 1
+    '''
+    assert len(sketches) == 5
+
+def test_cpu_conv2d_bn_relu_sketch():
+    sketches = generate_sketches(conv2d_nchw_bn_relu_auto_scheduler_test,
+                                 (1, 56, 56, 512, 512, 3, 1, 1), 'llvm')
+    ''' 3 multi-level tiling sketches
+        0 - Conv2d multi-level tiling with fusion on position 0
+        1 - Conv2d multi-level tiling with fusion on position 1
+        2 - Conv2d multi-level tiling without fusion
+    '''
+    assert len(sketches) == 3
+
+def test_cpu_max_pool2d_sketch():
+    sketches = generate_sketches(max_pool2d_auto_scheduler_test, (1, 56, 56, 512, 1), 'llvm')
+    assert len(sketches) == 1  # 1 default sketch
+
+def test_cpu_min_sketch():
+    sketches = generate_sketches(min_nm_auto_scheduler_test, (10, 1024), 'llvm')
+    ''' 2 rfactor sketches + 1 default sketch
+        0 - Rfactor with factor position 0
+        1 - Rfactor with factor position 1
+        2 - Default sketch
+    '''
+    assert len(sketches) == 3
+
+def test_cpu_softmax_sketch():
+    sketches = generate_sketches(softmax_nm_auto_scheduler_test, (1, 1024), 'llvm')
+    ''' (2 rfactor sketches + 1 default sketch) * (2 rfactor sketches + 1 default sketch) '''
+    assert len(sketches) == (3 * 3)
+
+    sketches = generate_sketches(softmax_abcd_auto_scheduler_test, (1, 12, 128, 128), 'llvm')
+    ''' (2 rfactor sketches + 1 default sketch) * (2 rfactor sketches + 1 default sketch) '''
+    assert len(sketches) == (3 * 3)
+
+def test_cpu_conv2d_winograd_sketch():
+    sketches = generate_sketches(conv2d_winograd_nhwc_auto_scheduler_test,
+                                 (1, 28, 28, 128, 128, 3, 1, 1), 'llvm')
+    ''' 3 multi-level tiling sketches
+        0 - Bgemm multi-level tiling
+        1 - Bgemm multi-level tiling with cache write on position 0
+        2 - Bgemm multi-level tiling with cache write on position 1
+    '''
+    assert len(sketches) == 3
+
+if __name__ == "__main__":
+    test_cpu_matmul_sketch()
+    test_cpu_conv2d_bn_relu_sketch()
+    test_cpu_max_pool2d_sketch()
+    test_cpu_min_sketch()
+    test_cpu_softmax_sketch()
+    test_cpu_conv2d_winograd_sketch()
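The sketches returned by generate_sketches are partially specified states: the tiling structure is fixed, while concrete tile sizes are filled in later during search. Printing them is the quickest way to see what each rule produced; a small sketch, to be run alongside this test file so the imports above resolve:

    sketches = generate_sketches(matmul_auto_scheduler_test, (8, 8, 512), 'llvm')
    for i, s in enumerate(sketches):
        print("Sketch %d:" % i)
        print(s)  # the partially specified loop nest of this sketch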