Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ansor][AutoTVM v2.0] Phase 2: Basic CPU Sketch Search Policy #6184

Merged
merged 10 commits into from
Aug 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions include/tvm/auto_scheduler/auto_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,14 @@ class TuningOptionsNode : public Object {
int early_stopping;
/*! \brief The number of programs to be measured at each search round. */
int num_measures_per_round;
/*!
* \brief Verbosity level.
* 0 for silent, 1 to output information during schedule searching.
*/
/*! \brief Verbosity level. 0 for silent, 1 to output information during schedule searching. */
int verbose;
/*! \brief ProgramBuilder which builds the program */
ProgramBuilder builder;
/*! \brief ProgramRunner which runs the program and measures time costs */
ProgramRunner runner;
/*! \brief MeasureCallback functions to be called after each measure batch */
Optional<Array<MeasureCallback>> measure_callbacks;
/*! \brief SearchCallback functions to be called before schedule search */
Optional<Array<SearchCallback>> pre_search_callbacks;
jcf94 marked this conversation as resolved.
Show resolved Hide resolved

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("num_measure_trials", &num_measure_trials);
Expand All @@ -64,7 +59,6 @@ class TuningOptionsNode : public Object {
v->Visit("builder", &builder);
v->Visit("runner", &runner);
v->Visit("measure_callbacks", &measure_callbacks);
v->Visit("pre_search_callbacks", &pre_search_callbacks);
}

static constexpr const char* _type_key = "auto_scheduler.TuningOptions";
Expand All @@ -87,26 +81,22 @@ class TuningOptions : public ObjectRef {
* \param builder ProgramBuilder which builds the program.
* \param runner ProgramRunner which runs the program and measure time costs.
* \param measure_callbacks MeasureCallback functions to be called after each measure batch.
* \param pre_search_callbacks SearchCallback functions to be called before schedule search.
*/
TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose,
ProgramBuilder builder, ProgramRunner runner,
Optional<Array<MeasureCallback>> measure_callbacks,
Optional<Array<SearchCallback>> pre_search_callbacks);
Optional<Array<MeasureCallback>> measure_callbacks);

TVM_DEFINE_OBJECT_REF_METHODS(TuningOptions, ObjectRef, TuningOptionsNode);
};

/*!
* \brief Run schedule search for a given compute declaration.
* \param task The search task of the compute declaration.
* \param search_policy The search policy.
* \param tuning_options Tuning and measurement options.
* \return A `te::schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or
* `tvm.build`.
*/
TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchTask task,
SearchPolicy search_policy,
TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchPolicy search_policy,
TuningOptions tuning_options);
} // namespace auto_scheduler
} // namespace tvm
Expand Down
18 changes: 15 additions & 3 deletions include/tvm/auto_scheduler/compute_dag.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ class AccessAnalyzerNode : public Object {
/*! \brief Store whether the operation is an op with only simple access.
* (e.g., injective, broadcast and elementwise ops without reduction) */
OperationMap<bool> is_simple_access;
/*! \brief Store whether the operation is strictly-inlineable
/*! \brief Store whether the operation is strictly inlineable
* (e.g., injective, broadcast and elementwise without reduction, branch or expenive operations)
*/
OperationMap<bool> is_strict_inlineable;
OperationMap<bool> is_strictly_inlineable;
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
/*! \brief Store whether the operation needs multi-level tiling
* (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) */
OperationMap<bool> needs_multi_level_tiling;
Expand Down Expand Up @@ -102,7 +102,7 @@ class AccessAnalyzer : public ObjectRef {
* (e.g., injective, broadcast and elementwise without reduction, branch or expenive operations)
* \param op The operation
*/
TVM_DLL bool IsStrictInlineable(const te::Operation& op) const;
TVM_DLL bool IsStrictlyInlineable(const te::Operation& op) const;

/*!
* \brief Return whether this operation needs multi-level tiling
Expand Down Expand Up @@ -187,6 +187,7 @@ class ComputeDAGNode : public Object {
v->Visit("ops", &ops);
v->Visit("flop_ct", &flop_ct);
v->Visit("init_state", &init_state);
v->Visit("access_analyzer", &access_analyzer);
}

static constexpr const char* _type_key = "auto_scheduler.ComputeDAG";
Expand Down Expand Up @@ -237,6 +238,17 @@ class ComputeDAG : public ObjectRef {
*/
State InferBound(const State& state) const;

/*!
* \brief Fill the correct bound information for the given states by calling ir_pass::InferBound.
* The states can lose complete bound information after some transform steps (e.g., compute_at).
* We can call this function to infer and fill all the bound information.
* This function calls TVM InferBound pass internally to get the bound.
* The returned state of this function is guaranteed to have complete bound information.
* \param states The input states.
* \return The States with complete bound information
*/
Array<State> InferBound(const Array<State>& states) const;

/*!
* \brief Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial
* ComputeDAG may not be up-to-date. This function replays the given transform steps from the
Expand Down
12 changes: 6 additions & 6 deletions include/tvm/auto_scheduler/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

/*!
* \file tvm/auto_scheduler/cost_model.h
* \file auto_scheduler/cost_model.h
* \brief Cost models that estimate the performance of programs
*/

Expand Down Expand Up @@ -54,7 +54,7 @@ class CostModelNode : public Object {
* \param states The input states
* \param scores The predicted scores for all states
*/
virtual void Predict(const SearchTask& task, const std::vector<State>& states,
virtual void Predict(const SearchTask& task, const Array<State>& states,
std::vector<float>* scores) = 0;

/*!
Expand All @@ -64,7 +64,7 @@ class CostModelNode : public Object {
* \param state_scores The predicted scores for all states
* \param stage_scores The predicted scores for all stages in all stages
*/
virtual void PredictStages(const SearchTask& task, const std::vector<State>& states,
virtual void PredictStages(const SearchTask& task, const Array<State>& states,
std::vector<float>* state_scores,
std::vector<std::vector<float>>* stage_scores) {
LOG(FATAL) << "Not implemented";
Expand All @@ -91,7 +91,7 @@ class RandomModelNode : public CostModelNode {

void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;

void Predict(const SearchTask& task, const std::vector<State>& states,
void Predict(const SearchTask& task, const Array<State>& states,
std::vector<float>* scores) final;

static constexpr const char* _type_key = "auto_scheduler.RandomModel";
Expand Down Expand Up @@ -126,10 +126,10 @@ class PythonBasedModelNode : public CostModelNode {

void Update(const Array<MeasureInput>& inputs, const Array<MeasureResult>& results) final;

void Predict(const SearchTask& task, const std::vector<State>& states,
void Predict(const SearchTask& task, const Array<State>& states,
std::vector<float>* scores) final;

void PredictStages(const SearchTask& task, const std::vector<State>& states,
void PredictStages(const SearchTask& task, const Array<State>& states,
std::vector<float>* state_scores,
std::vector<std::vector<float>>* stage_scores) final;

Expand Down
37 changes: 27 additions & 10 deletions include/tvm/auto_scheduler/search_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@
*
* Candidate schedules are measured against the specific hardware target.
*
* We intend to introduce different level of automation on the schedule generation process:
* - Level 0(the default level): For all kinds of ops/subgraphs, the search policy should be able
* to generate schedule automatically.
* - Level 1: For some complicated ops/subgraphs(e.g. conv2d windograd), the default search space
* of level 0 may be too large to find a high performance schedule efficiently. We provide some
* op attributes to help reduce the total search space, see `SearchPolicyKey` below for more
* information.
* - Level 2: For some further special ops/subgraphs, users may more likely to write their own
* template(just like AutoTVM). Search policy should be able to provide a flexible approach as
* well.
*
* \note How to add a new search policy.
* In design, there's no need for users to implement their own search policy, our formal search
* policy(will be brought later) should be enough to cover most use cases. Meanwhile, a custom rule
Expand Down Expand Up @@ -89,46 +100,52 @@ class SearchCallback : public ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode);
};

/*! \brief Attribute keys of ops used for SearchPolicy. */
struct SearchPolicyKey {
/*! \brief Always apply unroll to the inner most iterator of the specificed iterators. */
static constexpr const char* always_unroll_inner = "auto_scheduler_always_unroll_inner";
/*! \brief The specified iterators will be placed in the inner most tile without split. */
static constexpr const char* no_split_at_inner = "auto_scheduler_no_split_at_inner";
/*! \brief The specified iterators are indices of const tensors in "fake reduction". */
static constexpr const char* simplify_const_tensor_indices =
"auto_scheduler_simplify_const_tensor_indices";
};

/*!
* \brief The base class of search policies.
*/
class SearchPolicyNode : public Object {
public:
/*! \brief The current search task. */
SearchTask cur_task;
SearchTask search_task;
/*!
* \brief Verbose level to control the screen output during schedule search.
* 0 for silent, 1 to output state & measure information during search process.
*/
int verbose;

void VisitAttrs(AttrVisitor* v) {
v->Visit("cur_task", &cur_task);
v->Visit("search_task", &search_task);
v->Visit("verbose", &verbose);
}

/*!
* \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state
* found during the search.
* \param task The SearchTask for the computation declaration
* \param num_measure_trials The number of total measurement trials.
* \param early_stopping Stops the tuning early if no improvement after n measurements.
* \param num_measures_per_round The number of programs to be measured at each search round.
* \param verbose Verbose level. 0 for silent, 1 to output information during schedule
* search.
* \param measurer A ProgramMeasurer to build and measure programs
* \param pre_search_callbacks SearchCallback to be called before schedule search.
* \return The best state found.
*/
virtual State Search(SearchTask task, int num_measure_trials, int early_stopping,
int num_measures_per_round, int verbose, ProgramMeasurer measurer,
Optional<Array<SearchCallback>> pre_search_callbacks) = 0;
virtual State Search(int num_measure_trials, int early_stopping, int num_measures_per_round,
ProgramMeasurer measurer) = 0;

/*!
* \brief Call SearchCallback with the current SearchPolicyNode
* \param callbacks SearchCallback to be called.
*/
void RunCallbacks(const Optional<Array<SearchCallback>>& callbacks);
void RunCallbacks(const Array<SearchCallback>& callbacks);

static constexpr const char* _type_key = "auto_scheduler.SearchPolicy";
TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object);
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/auto_scheduler/transform_step.h
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ class SplitStepNode : public StepNode {
/*! \brief The id of the iter to split. */
int iter_id;
/*! \brief The extent length of the axis to split. */
Optional<Integer> extent;
Optional<PrimExpr> extent;
/*! \brief The split factors. */
Array<Optional<Integer>> lengths;
/*!
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

# Shortcut
from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \
auto_schedule, EmptyPolicy
auto_schedule, EmptyPolicy, SketchPolicy
from .compute_dag import ComputeDAG
from .cost_model import RandomModel
from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, \
Expand Down
Loading