Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ansor][AutoTVM v2.0] Phase 2: Evolutionary Search #6310

Merged
merged 6 commits into from
Aug 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions python/tvm/auto_scheduler/search_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ class SketchPolicy(SearchPolicy):
"retry_search_one_round_on_empty": 10,

'evolutionary_search_population': 2048,
'evolutionary_search_num_iters': 10,
'evolutionary_search_mutation_prob': 0.85,
"evolutionary_search_use_measured_ratio": 0.2,

'cpu_multi_level_tiling_structure': 'SSRSRS',
Expand Down Expand Up @@ -178,3 +180,21 @@ def sample_initial_population(self, pop_size):
"""
states = _ffi_api.SketchPolicySampleInitialPopulation(self, pop_size)
return states

def evolutionary_search(self, init_populuations, out_size):
"""Evolutionary search.
This python interface is mainly used for debugging and testing.
The actual search is all doen in c++.
Parameters
----------
init_populations: List[State]
The initial population states
out_size : int
The size of generated states
Returns
-------
states: List[State]
The generated states
"""
states = _ffi_api.SketchPolicyEvolutionarySearch(self, init_populuations, out_size)
return states
166 changes: 163 additions & 3 deletions src/auto_scheduler/search_policy/sketch_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <algorithm>
#include <iomanip>
#include <limits>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -65,6 +66,13 @@ static InitUnroll init_unroll;
static InitVectorization init_vectorization;
static InitThreadBind init_thread_bind;

/********** Mutation rules **********/

static MutateTileSize mutate_tile_size;
static MutateMaxUnrollFactor mutate_max_unroll_factor;
static MutateComputeLocation mutate_compute_location;
static MutateParallel mutate_parallel;

/********** Sketch policy **********/

TVM_REGISTER_NODE_TYPE(SketchPolicyNode);
Expand Down Expand Up @@ -129,6 +137,12 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model,
LOG(FATAL) << "No default init rules for target: " << task->target;
}

// The default mutation rules.
node->mutation_rules.push_back(&mutate_tile_size);
node->mutation_rules.push_back(&mutate_max_unroll_factor);
node->mutation_rules.push_back(&mutate_compute_location);
node->mutation_rules.push_back(&mutate_parallel);

data_ = std::move(node);
}

Expand Down Expand Up @@ -336,7 +350,7 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
// Derivation rule based enumeration
bool valid = true;
for (const auto& rule : init_rules) {
if (rule->Apply(this, &tmp_s) == InitPopulationRule::ResultKind::kInvalid) {
if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) {
valid = false;
break;
}
Expand All @@ -363,8 +377,148 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
Array<State> best_states;
auto tic_begin = std::chrono::high_resolution_clock::now();

// TODO(comaniac, merrymercy, jcf94): Since we haven't finished porting the cost model part
// yet, currently delete the implementation of EvolutionarySearch. To be added later.
size_t population = init_population.size();
int num_iters = GetIntParam(params, SketchParamKey::EvolutionarySearch::num_iters);
double mutation_prob = GetDoubleParam(params, SketchParamKey::EvolutionarySearch::mutation_prob);

// Two ping pong buffers to avoid copy.
Array<State> states_buf1{init_population}, states_buf2;
states_buf1.reserve(population);
states_buf2.reserve(population);
Array<State>* pnow = &states_buf1;
Array<State>* pnext = &states_buf2;

// The set of explored states to avoid redundancy.
std::unordered_set<std::string> explored_set;

// The heap to maintain the so far best states.
using StateHeapItem = std::pair<State, float>;
auto cmp = [](const StateHeapItem& left, const StateHeapItem& right) {
return left.second > right.second;
};
using StateHeap = std::priority_queue<StateHeapItem, std::vector<StateHeapItem>, decltype(cmp)>;
StateHeap heap(cmp);
auto update_heap = [&heap, &explored_set](const Array<State>& states,
const std::vector<float>& scores, const int out_size) {
float max_score = 0.0;
for (size_t i = 0; i < states.size(); ++i) {
const State& state = states[i];
std::string state_str = state.ToStr();

// Skip redundant states.
if (explored_set.count(state_str) > 0) {
continue;
}
explored_set.insert(state_str);

if (static_cast<int>(heap.size()) < out_size) {
// Directly push item if the heap is not full yet.
heap.push({state, scores[i]});
} else if (scores[i] > heap.top().second) {
// Replace the worst state in the heap with the new state.
heap.pop();
heap.push({state, scores[i]});
}
max_score = (scores[i] > max_score) ? scores[i] : max_score;
}
return max_score;
};

// Cost model predicted scores.
std::vector<float> scores;
scores.reserve(population);

// The function to generate prefix sum probabilities based on the given scores.
auto assign_prob = [](const std::vector<float>& scores, std::vector<double>* prefix_sum_probs) {
// Compute selection probabilities.
double sum = 0.0;
prefix_sum_probs->resize(scores.size());
for (size_t i = 0; i < scores.size(); ++i) {
sum += std::max(scores[i], 0.0f);
(*prefix_sum_probs)[i] = sum;
}
for (size_t i = 0; i < scores.size(); ++i) {
(*prefix_sum_probs)[i] /= sum;
}
};

// State selection probabilities.
std::uniform_real_distribution<> uniform_dist(0.0, 1.0);
std::vector<double> state_select_probs;
state_select_probs.reserve(population);

// Mutation rule selection probabilities.
std::vector<double> rule_select_probs;
rule_select_probs.reserve(mutation_rules.size());
std::vector<float> rule_levels;
for (const auto& rule : mutation_rules) {
rule_levels.push_back(rule->GetLevel(search_task));
}
assign_prob(rule_levels, &rule_select_probs);

// Evaluate the init populations.
*pnow = search_task->compute_dag.InferBound(*pnow);
PruneInvalidState(search_task, pnow);
CHECK_GT(pnow->size(), 0) << "All initial populations are invalid";
schedule_cost_model->Predict(search_task, *pnow, &scores);

// Maintain the best states in the heap.
float max_score = update_heap(*pnow, scores, out_size);

// Genetic algorithm.
for (auto iter_idx = 1; iter_idx <= num_iters; ++iter_idx) {
// Assign the selection probability to each state based on the cost model scores.
assign_prob(scores, &state_select_probs);

// TODO(@comaniac): Perform cross over.

// Perform mutations.
size_t fail_ct = 0;
while (pnext->size() < population && fail_ct < population * 2) {
// Select a state to be mutated.
State tmp_s = (*pnow)[RandomChoose(state_select_probs, &rand_gen)];
if (uniform_dist(rand_gen) < mutation_prob) {
// Select a rule and mutate the state.
const auto& rule = mutation_rules[RandomChoose(rule_select_probs, &rand_gen)];
if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kValid) {
pnext->push_back(std::move(tmp_s));
} else {
fail_ct++;
}
} else {
// Do not mutate this state in this round.
pnext->push_back(std::move(tmp_s));
}
}

// Evaluate the new populations.
*pnext = search_task->compute_dag.InferBound(*pnext);
PruneInvalidState(search_task, pnext);

// Throw away all states generated in this iterations if all new states are invalid.
if (pnext->size() > 0) {
std::swap(pnext, pnow);
schedule_cost_model->Predict(search_task, *pnow, &scores);

// Maintain the best states in the heap.
float iter_max_score = update_heap(*pnow, scores, out_size);
max_score = (iter_max_score > max_score) ? iter_max_score : max_score;
}
pnext->clear();

if (iter_idx % 5 == 0 || iter_idx == num_iters) {
StdCout(verbose) << "GA Iter: " << iter_idx << std::fixed << std::setprecision(4)
<< "\tMax Score: " << max_score << "\tPop Size: " << pnow->size()
<< std::endl;
}
}

// Copy best states in the heap to the output.
while (!heap.empty()) {
auto item = heap.top();
heap.pop();
best_states.push_back(std::move(item.first));
}

double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
std::chrono::high_resolution_clock::now() - tic_begin)
Expand Down Expand Up @@ -441,5 +595,11 @@ TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicySampleInitialPopulation")
return init_population;
});

TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyEvolutionarySearch")
.set_body_typed([](SketchPolicy policy, Array<State> init_population, int out_size) {
Array<State> states = policy->EvolutionarySearch(init_population, out_size);
return states;
});

} // namespace auto_scheduler
} // namespace tvm
24 changes: 15 additions & 9 deletions src/auto_scheduler/search_policy/sketch_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ struct SketchParamKey {
struct EvolutionarySearch {
/*! \brief The population size for evolutionary search. */
static constexpr const char* population = "evolutionary_search_population";
/*! \brief The number of iterations performed by generic algorithm.*/
static constexpr const char* num_iters = "evolutionary_search_num_iters";
/*! \brief The mutation probability.*/
static constexpr const char* mutation_prob = "evolutionary_search_mutation_prob";
/*! \brief The maximum percentage of measured states in the initial population for evolutionary
* search. */
static constexpr const char* use_measured_ratio = "evolutionary_search_use_measured_ratio";
Expand Down Expand Up @@ -90,7 +94,9 @@ class SketchPolicyNode : public SearchPolicyNode {
/*! \brief The rules to generate sketches. */
std::vector<SketchGenerationRule*> sketch_rules;
/*! \brief The rules to generate initial states. */
std::vector<InitPopulationRule*> init_rules;
std::vector<PopulationGenerationRule*> init_rules;
/*! \brief The rules to mutate states. */
std::vector<PopulationMutationRule*> mutation_rules;
/*! \brief Random generator. */
std::mt19937 rand_gen;
/*! \brief Memorize split space for Split. */
Expand All @@ -113,6 +119,14 @@ class SketchPolicyNode : public SearchPolicyNode {
*/
Array<State> SampleInitPopulation(const Array<State>& sketches, int out_size);

/*!
* \brief Perform evolutionary search.
* \param init_populations The states generated from init population.
* \param out_size The number of expected output states.
* \return The generated states after evolutionary search.
*/
Array<State> EvolutionarySearch(const Array<State>& init_populations, int out_size);

static constexpr const char* _type_key = "auto_scheduler.SketchPolicy";

TVM_DECLARE_FINAL_OBJECT_INFO(SketchPolicyNode, SearchPolicyNode);
Expand All @@ -127,14 +141,6 @@ class SketchPolicyNode : public SearchPolicyNode {
*/
Array<State> SearchOneRound(int num_random_states, Array<State>* random_states = nullptr);

/*!
* \brief Perform evolutionary search.
* \param init_populations The states generated from init population.
* \param out_size The number of expected output states.
* \return The generated states after evolutionary search.
*/
Array<State> EvolutionarySearch(const Array<State>& init_populations, int out_size);

/*!
* \brief Pick states from best states and random states with eps-greedy policy.
* \param best_states States picked by cost model.
Expand Down
Loading