Skip to content

Commit

Permalink
[Ansor][AutoTVM v2.0] Phase 2: Basic CPU Sketch Search Policy (apache…
Browse files Browse the repository at this point in the history
…#6184)

* Init commit to pass the compile

* First commit to pass the test

* Update

* Add UTs for sketch generation

* Update

* Add ASF to new UT file.

* Update rule for winograd

* Update

* File renamed

* Lint fix
  • Loading branch information
jcf94 authored and Trevor Morris committed Sep 2, 2020
1 parent 34aad1a commit 6051861
Show file tree
Hide file tree
Showing 26 changed files with 2,713 additions and 137 deletions.
16 changes: 3 additions & 13 deletions include/tvm/auto_scheduler/auto_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,14 @@ class TuningOptionsNode : public Object {
int early_stopping;
/*! \brief The number of programs to be measured at each search round. */
int num_measures_per_round;
/*!
* \brief Verbosity level.
* 0 for silent, 1 to output information during schedule searching.
*/
/*! \brief Verbosity level. 0 for silent, 1 to output information during schedule searching. */
int verbose;
/*! \brief ProgramBuilder which builds the program */
ProgramBuilder builder;
/*! \brief ProgramRunner which runs the program and measures time costs */
ProgramRunner runner;
/*! \brief MeasureCallback functions to be called after each measure batch */
Optional<Array<MeasureCallback>> measure_callbacks;
/*! \brief SearchCallback functions to be called before schedule search */
Optional<Array<SearchCallback>> pre_search_callbacks;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("num_measure_trials", &num_measure_trials);
Expand All @@ -64,7 +59,6 @@ class TuningOptionsNode : public Object {
v->Visit("builder", &builder);
v->Visit("runner", &runner);
v->Visit("measure_callbacks", &measure_callbacks);
v->Visit("pre_search_callbacks", &pre_search_callbacks);
}

static constexpr const char* _type_key = "auto_scheduler.TuningOptions";
Expand All @@ -87,26 +81,22 @@ class TuningOptions : public ObjectRef {
* \param builder ProgramBuilder which builds the program.
* \param runner ProgramRunner which runs the program and measures time costs.
* \param measure_callbacks MeasureCallback functions to be called after each measure batch.
* \param pre_search_callbacks SearchCallback functions to be called before schedule search.
*/
TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose,
ProgramBuilder builder, ProgramRunner runner,
Optional<Array<MeasureCallback>> measure_callbacks,
Optional<Array<SearchCallback>> pre_search_callbacks);
Optional<Array<MeasureCallback>> measure_callbacks);

TVM_DEFINE_OBJECT_REF_METHODS(TuningOptions, ObjectRef, TuningOptionsNode);
};

/*!
* \brief Run schedule search for a given compute declaration.
* \param task The search task of the compute declaration.
* \param search_policy The search policy.
* \param tuning_options Tuning and measurement options.
* \return A `te::schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or
* `tvm.build`.
*/
TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchTask task,
SearchPolicy search_policy,
TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchPolicy search_policy,
TuningOptions tuning_options);
} // namespace auto_scheduler
} // namespace tvm
Expand Down
18 changes: 15 additions & 3 deletions include/tvm/auto_scheduler/compute_dag.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ class AccessAnalyzerNode : public Object {
/*! \brief Store whether the operation is an op with only simple access.
* (e.g., injective, broadcast and elementwise ops without reduction) */
OperationMap<bool> is_simple_access;
/*! \brief Store whether the operation is strictly-inlineable
/*! \brief Store whether the operation is strictly inlineable
* (e.g., injective, broadcast and elementwise without reduction, branch or expensive operations)
*/
OperationMap<bool> is_strict_inlineable;
OperationMap<bool> is_strictly_inlineable;
/*! \brief Store whether the operation needs multi-level tiling
* (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) */
OperationMap<bool> needs_multi_level_tiling;
Expand Down Expand Up @@ -102,7 +102,7 @@ class AccessAnalyzer : public ObjectRef {
* (e.g., injective, broadcast and elementwise without reduction, branch or expensive operations)
* \param op The operation
*/
TVM_DLL bool IsStrictInlineable(const te::Operation& op) const;
TVM_DLL bool IsStrictlyInlineable(const te::Operation& op) const;

/*!
* \brief Return whether this operation needs multi-level tiling
Expand Down Expand Up @@ -187,6 +187,7 @@ class ComputeDAGNode : public Object {
v->Visit("ops", &ops);
v->Visit("flop_ct", &flop_ct);
v->Visit("init_state", &init_state);
v->Visit("access_analyzer", &access_analyzer);
}

static constexpr const char* _type_key = "auto_scheduler.ComputeDAG";
Expand Down Expand Up @@ -237,6 +238,17 @@ class ComputeDAG : public ObjectRef {
*/
State InferBound(const State& state) const;

/*!
* \brief Fill the correct bound information for the given states by calling ir_pass::InferBound.
* The states can lose complete bound information after some transform steps (e.g., compute_at).
* We can call this function to infer and fill all the bound information.
* This function calls TVM InferBound pass internally to get the bound.
* The returned state of this function is guaranteed to have complete bound information.
* \param states The input states.
* \return The States with complete bound information
*/
Array<State> InferBound(const Array<State>& states) const;

/*!
* \brief Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial
* ComputeDAG may not be up-to-date. This function replays the given transform steps from the
Expand Down
12 changes: 6 additions & 6 deletions include/tvm/auto_scheduler/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

/*!
* \file tvm/auto_scheduler/cost_model.h
* \file auto_scheduler/cost_model.h
* \brief Cost models that estimate the performance of programs
*/

Expand Down Expand Up @@ -54,7 +54,7 @@ class CostModelNode : public Object {
* \param states The input states
* \param scores The predicted scores for all states
*/
virtual void Predict(const SearchTask& task, const std::vector<State>& states,
virtual void Predict(const SearchTask& task, const Array<State>& states,
std::vector<float>* scores) = 0;

/*!
Expand All @@ -64,7 +64,7 @@ class CostModelNode : public Object {
* \param state_scores The predicted scores for all states
* \param stage_scores The predicted scores for all stages in all states
*/
virtual void PredictStages(const SearchTask& task, const std::vector<State>& states,
virtual void PredictStages(const SearchTask& task, const Array<State>& states,
std::vector<float>* state_scores,
std::vector<std::vector<float>>* stage_scores) {
LOG(FATAL) << "Not implemented";
Expand All @@ -91,7 +91,7 @@ class RandomModelNode : public CostModelNode {

void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;

void Predict(const SearchTask& task, const std::vector<State>& states,
void Predict(const SearchTask& task, const Array<State>& states,
std::vector<float>* scores) final;

static constexpr const char* _type_key = "auto_scheduler.RandomModel";
Expand Down Expand Up @@ -126,10 +126,10 @@ class PythonBasedModelNode : public CostModelNode {

void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;

void Predict(const SearchTask& task, const std::vector<State>& states,
void Predict(const SearchTask& task, const Array<State>& states,
std::vector<float>* scores) final;

void PredictStages(const SearchTask& task, const std::vector<State>& states,
void PredictStages(const SearchTask& task, const Array<State>& states,
std::vector<float>* state_scores,
std::vector<std::vector<float>>* stage_scores) final;

Expand Down
37 changes: 27 additions & 10 deletions include/tvm/auto_scheduler/search_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@
*
* Candidate schedules are measured against the specific hardware target.
*
* We intend to introduce different levels of automation in the schedule generation process:
* - Level 0(the default level): For all kinds of ops/subgraphs, the search policy should be able
* to generate schedule automatically.
* - Level 1: For some complicated ops/subgraphs (e.g. conv2d winograd), the default search space
* of level 0 may be too large to find a high performance schedule efficiently. We provide some
* op attributes to help reduce the total search space, see `SearchPolicyKey` below for more
* information.
* - Level 2: For some further special ops/subgraphs, users may be more likely to write their own
* template (just like AutoTVM). The search policy should be able to provide a flexible approach
* as well.
*
* \note How to add a new search policy.
* In design, there's no need for users to implement their own search policy, our formal search
* policy(will be brought later) should be enough to cover most use cases. Meanwhile, a custom rule
Expand Down Expand Up @@ -89,46 +100,52 @@ class SearchCallback : public ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode);
};

/*! \brief Attribute keys of ops used for SearchPolicy. */
struct SearchPolicyKey {
/*! \brief Always apply unroll to the innermost iterator of the specified iterators. */
static constexpr const char* always_unroll_inner = "auto_scheduler_always_unroll_inner";
/*! \brief The specified iterators will be placed in the innermost tile without split. */
static constexpr const char* no_split_at_inner = "auto_scheduler_no_split_at_inner";
/*! \brief The specified iterators are indices of const tensors in "fake reduction". */
static constexpr const char* simplify_const_tensor_indices =
"auto_scheduler_simplify_const_tensor_indices";
};

/*!
* \brief The base class of search policies.
*/
class SearchPolicyNode : public Object {
public:
/*! \brief The current search task. */
SearchTask cur_task;
SearchTask search_task;
/*!
* \brief Verbose level to control the screen output during schedule search.
* 0 for silent, 1 to output state & measure information during search process.
*/
int verbose;

void VisitAttrs(AttrVisitor* v) {
v->Visit("cur_task", &cur_task);
v->Visit("search_task", &search_task);
v->Visit("verbose", &verbose);
}

/*!
* \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state
* found during the search.
* \param task The SearchTask for the computation declaration
* \param num_measure_trials The number of total measurement trials.
* \param early_stopping Stops the tuning early if no improvement after n measurements.
* \param num_measures_per_round The number of programs to be measured at each search round.
* \param verbose Verbose level. 0 for silent, 1 to output information during schedule
* search.
* \param measurer A ProgramMeasurer to build and measure programs
* \param pre_search_callbacks SearchCallback to be called before schedule search.
* \return The best state found.
*/
virtual State Search(SearchTask task, int num_measure_trials, int early_stopping,
int num_measures_per_round, int verbose, ProgramMeasurer measurer,
Optional<Array<SearchCallback>> pre_search_callbacks) = 0;
virtual State Search(int num_measure_trials, int early_stopping, int num_measures_per_round,
ProgramMeasurer measurer) = 0;

/*!
* \brief Call SearchCallback with the current SearchPolicyNode
* \param callbacks SearchCallback to be called.
*/
void RunCallbacks(const Optional<Array<SearchCallback>>& callbacks);
void RunCallbacks(const Array<SearchCallback>& callbacks);

static constexpr const char* _type_key = "auto_scheduler.SearchPolicy";
TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object);
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/auto_scheduler/transform_step.h
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ class SplitStepNode : public StepNode {
/*! \brief The id of the iter to split. */
int iter_id;
/*! \brief The extent length of the axis to split. */
Optional<Integer> extent;
Optional<PrimExpr> extent;
/*! \brief The split factors. */
Array<Optional<Integer>> lengths;
/*!
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

# Shortcut
from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \
auto_schedule, EmptyPolicy
auto_schedule, EmptyPolicy, SketchPolicy
from .compute_dag import ComputeDAG
from .cost_model import RandomModel
from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, \
Expand Down
Loading

0 comments on commit 6051861

Please sign in to comment.