Skip to content

Commit

Permalink
Add rest rules
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed Aug 20, 2020
1 parent 19f6c18 commit de6e232
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 11 deletions.
12 changes: 9 additions & 3 deletions src/auto_scheduler/search_policy/sketch_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ static InitVectorization init_vectorization;
/********** Mutation rules **********/

static MutateTileSize mutate_tile_size;
static MutateMaxUnrollFactor mutate_max_unroll_factor;
static MutateComputeLocation mutate_compute_location;
static MutateParallel mutate_parallel;

/********** Sketch policy **********/

Expand Down Expand Up @@ -110,6 +113,9 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model,

// The default mutation rules.
node->mutation_rules.push_back(&mutate_tile_size);
node->mutation_rules.push_back(&mutate_max_unroll_factor);
node->mutation_rules.push_back(&mutate_compute_location);
node->mutation_rules.push_back(&mutate_parallel);

data_ = std::move(node);
}
Expand Down Expand Up @@ -412,12 +418,12 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
rule_select_probs.reserve(mutation_rules.size());
std::vector<float> rule_levels;
for (const auto& rule : mutation_rules) {
rule_levels.push_back(rule->GetLevel());
rule_levels.push_back(rule->GetLevel(search_task));
}
assign_prob(rule_levels, &rule_select_probs);

// Evaluate the init populations.
search_task->compute_dag.InferBound(*pnow);
*pnow = search_task->compute_dag.InferBound(*pnow);
PruneInvalidState(search_task, pnow);
CHECK_GT(pnow->size(), 0) << "All initial populations are invalid";
schedule_cost_model->Predict(search_task, *pnow, &scores);
Expand Down Expand Up @@ -452,7 +458,7 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
}

// Evaluate the new populations.
search_task->compute_dag.InferBound(*pnext);
*pnext = search_task->compute_dag.InferBound(*pnext);
PruneInvalidState(search_task, pnext);

// Throw away all states generated in this iterations if all new states are invalid.
Expand Down
155 changes: 150 additions & 5 deletions src/auto_scheduler/search_policy/sketch_policy_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ InitPopulationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy
return ResultKind::kValid;
}

MutateTileSize::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State* state) const {
MutationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State* state) const {
int max_innermost_split_factor =
GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);

Expand All @@ -599,7 +599,7 @@ MutateTileSize::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State
}
if (split_step_ids.empty()) {
// No tile size could be mutated.
return MutateTileSize::ResultKind::kInvalid;
return ResultKind::kInvalid;
}

// Select a SplitStep with extent larger than one to mutate.
Expand All @@ -618,7 +618,7 @@ MutateTileSize::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State

if (extent <= 1) {
// Cannot find a step with extent larger than one.
return MutateTileSize::ResultKind::kInvalid;
return ResultKind::kInvalid;
}

// Fetch the current tile sizes.
Expand Down Expand Up @@ -679,9 +679,154 @@ MutateTileSize::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State
step_id, SplitStep(ps->stage_id, ps->iter_id, ps->extent,
Array<Optional<Integer>>(new_lengths.begin(), new_lengths.end()),
ps->inner_to_outer));
return MutateTileSize::ResultKind::kValid;
return ResultKind::kValid;
}
return ResultKind::kInvalid;
}

// Mutate one "auto_unroll_max_step" pragma step of the state in place by replacing its
// unroll factor with a randomly drawn candidate.
// Returns kInvalid when the state has no such pragma step, kValid otherwise.
MutationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy,
                                                      State* state) const {
  // Gather the indices of every transform step that is an auto_unroll_max_step pragma.
  std::vector<int> unroll_step_ids;
  const size_t num_steps = (*state)->transform_steps.size();
  for (size_t idx = 0; idx < num_steps; ++idx) {
    auto pragma = (*state)->transform_steps[idx].as<PragmaStepNode>();
    if (!pragma) {
      continue;
    }
    if (!StrStartsWith(pragma->pragma_type, "auto_unroll_max_step")) {
      continue;
    }
    unroll_step_ids.push_back(idx);
  }
  if (unroll_step_ids.empty()) {
    return ResultKind::kInvalid;
  }

  // First random draw: the new unroll factor, taken from the candidate list matching the
  // task's target kind.
  const std::vector<int>& candidates =
      IsGPUTask(policy->search_task) ? gpu_unroll_cands_ : cpu_unroll_cands_;
  auto factor_text = std::to_string(candidates[(policy->rand_gen)() % candidates.size()]);

  // Second random draw: which of the collected pragma steps gets rewritten.
  auto chosen_id = unroll_step_ids[(policy->rand_gen)() % unroll_step_ids.size()];
  auto old_step = (*state)->transform_steps[chosen_id].as<PragmaStepNode>();
  CHECK(old_step);

  // Overwrite the chosen step with a pragma of the form "auto_unroll_max_step$<factor>".
  StateNode* mutable_state = state->CopyOnWrite();
  mutable_state->transform_steps.Set(
      chosen_id, PragmaStep(old_step->stage_id, old_step->iter_id,
                            std::string("auto_unroll_max_step") + "$" + factor_text));
  return ResultKind::kValid;
}

// Mutate the compute location of the state by delegating to the init-population rule
// InitChangeComputeLocation and translating its result kind into a MutationRule result.
MutationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy,
                                                      State* state) const {
  // FIXME (@comaniac, @jc94): Combine initial population rules with the mutation rules.
  static InitChangeComputeLocation init_rule;
  const bool rejected =
      init_rule.Apply(policy, state) == InitPopulationRule::ResultKind::kInvalid;
  return rejected ? ResultKind::kInvalid : ResultKind::kValid;
}

// Mutate the fuse step that directly precedes a parallel annotation on the outermost loop
// (iter 0) of a stage, fusing either one iterator fewer or one iterator more, and then
// replay every transform step on a fresh state so the loop structure is rebuilt.
// On success *state is replaced by the replayed state and kValid is returned.
// Returns kInvalid when no eligible parallel step exists, when the fusion cannot be
// changed, or when a later step on the same stage is not an annotation step.
MutationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy,
                                               State* state) const {
  // This mutation rule only focuses on a case that parallel was added to
  // the outermost loop and the loop is generated by fusing other loops.
  // In short, we mutate the fusion step before the parallel step.

  // Extract all parallel steps.
  std::vector<int> parallel_steps;
  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
    auto ps = (*state)->transform_steps[s].as<AnnotationStepNode>();
    if (!ps || ps->annotation != IteratorAnnotation::kParallel) {
      continue;
    }

    // Skip non-outermost loop or the parallel step without fusion beforehand.
    if (ps->iter_id > 0 || s == 0 || !(*state)->transform_steps[s - 1].as<FuseStepNode>()) {
      continue;
    }
    parallel_steps.push_back(s);
  }
  if (parallel_steps.empty()) {
    return ResultKind::kInvalid;
  }

  // Randomly pick one parallel step.
  size_t step_id = parallel_steps[(policy->rand_gen)() % parallel_steps.size()];
  auto ps = (*state)->transform_steps[step_id].as<AnnotationStepNode>();
  CHECK(ps);
  size_t stage_id = ps->stage_id;
  size_t iter_id = ps->iter_id;
  const Stage& stage = (*state)->stages[stage_id];
  const Iterator& it = stage->iters[iter_id];

  // Replay a new state until the picked fuse step.
  // step_id >= 1 is guaranteed by the filter above, so `step_id - 1` cannot wrap around.
  State tmp_s = policy->search_task->compute_dag->init_state;
  for (size_t s = 0; s < step_id - 1; ++s) {
    auto step = (*state)->transform_steps[s];
    tmp_s.CopyOnWrite()->transform_steps.push_back(step);
    StepApplyToState(step, &tmp_s, policy->search_task->compute_dag);
  }

  // Determine the fusion mutation direction.
  // 0: fuse less; 1: fuse more.
  // NOTE(review): the entries look like cumulative thresholds consumed by RandomChoose
  // ({0.5, 1.0} = even odds) -- confirm against RandomChoose.
  auto fuse_step = (*state)->transform_steps[step_id - 1].as<FuseStepNode>();
  CHECK(fuse_step);
  auto fused_ids = fuse_step->fused_ids;
  std::vector<double> fuse_dir = {0.5, 1.0};

  // The case that we can only fuse more. This may happen after multiple mutations.
  if (fused_ids.size() == 1) {
    fuse_dir[0] = 0.0;
  }

  // The cases that we cannot fuse the next iters.
  if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id)) ||
      it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) {
    if (fuse_dir[0] == 0.0) {
      // No room to mutate this fusion.
      return ResultKind::kInvalid;
    }
    fuse_dir[0] = 1.0;
  }

  // Mutate the fusion iters and replay the mutated fused/annotation steps.
  // iter_offset is the shift applied to iterator ids of later steps on this stage:
  // fusing one iterator fewer leaves one more loop (+1), fusing one more removes one (-1).
  int iter_offset = 0;
  if (RandomChoose(fuse_dir, &(policy->rand_gen)) == 0) {
    fused_ids.pop_back();
    iter_offset = 1;
  } else {
    auto last_id = fused_ids.back().get()->value;
    fused_ids.push_back(last_id + 1);
    iter_offset = -1;
  }
  auto new_fuse_step = FuseStep(stage_id, fused_ids);
  tmp_s.CopyOnWrite()->transform_steps.push_back(new_fuse_step);
  StepApplyToState(new_fuse_step, &tmp_s, policy->search_task->compute_dag);
  tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[step_id]);
  StepApplyToState((*state)->transform_steps[step_id], &tmp_s, policy->search_task->compute_dag);

  // Replay the rest steps.
  for (size_t s = step_id + 1; s < (*state)->transform_steps.size(); ++s) {
    auto step = (*state)->transform_steps[s];
    if (step->stage_id == static_cast<int>(stage_id)) {
      // Since we changed the loop structure, iter ID in later steps to the same stage
      // has to be adjusted.
      auto annotation = step.as<AnnotationStepNode>();
      if (annotation) {
        if (annotation->iter_id == 0) {
          step = AnnotationStep(annotation->stage_id, 0, annotation->annotation);
        } else {
          CHECK_LE(annotation->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size());
          step = AnnotationStep(annotation->stage_id, annotation->iter_id + iter_offset,
                                annotation->annotation);
        }
      } else {
        // Unexpected step node that we did not process for now.
        return ResultKind::kInvalid;
      }
    }
    tmp_s.CopyOnWrite()->transform_steps.push_back(step);
    StepApplyToState(step, &tmp_s, policy->search_task->compute_dag);
  }

  *state = tmp_s;
  return ResultKind::kValid;
}

} // namespace auto_scheduler
Expand Down
47 changes: 44 additions & 3 deletions src/auto_scheduler/search_policy/sketch_policy_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@
#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_

#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/search_task.h>

#include <utility>
#include <vector>

#include "utils.h"

namespace tvm {
namespace auto_scheduler {

Expand Down Expand Up @@ -204,21 +207,59 @@ class InitVectorization : public InitPopulationRule {
/********** Mutation **********/

/*! \brief The base class for mutation rules used in the evolutionary search. */
class MutationRule : public InitPopulationRule {
class MutationRule {
public:
/*! \brief Result enumeration of the apply function. */
enum class ResultKind : int { kValid = 0, kInvalid = 1 };

/*!
* \brief Apply function of this rule.
* \param policy The SketchPolicyNode of this rule, some member may get changed during the
* rule applying. (e.g. random number generator)
* \param state The state to apply this rule, update inplace.
* \return The result of this rule, indicate if there's any valid state generated.
*/
virtual ResultKind Apply(SketchPolicyNode* policy, State* state) const = 0;

/*!
* \brief Get the priority level of this mutation rule.
* \return The priority level of this mutation rule. Higher the better.
*/
virtual int GetLevel() const = 0;
virtual int GetLevel(const SearchTask& task) const = 0;
};

/*! \brief The rule that mutates tile size by randomly dividing a tile size by a factor
    and multiplying it to another tile size. */
class MutateTileSize : public MutationRule {
 public:
  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
  /*! \brief Priority level of this rule; the same value is reported for every task. */
  int GetLevel(const SearchTask& task) const final { return 100; }
};

/*! \brief The rule that mutates an auto_unroll_max_step pragma step by replacing its
    unroll factor with a randomly selected candidate. */
class MutateMaxUnrollFactor : public MutationRule {
 public:
  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;

  /*! \brief Priority level of this rule; the same value is reported for every task. */
  int GetLevel(const SearchTask& task) const final {
    return 10;
  }

  /*! \brief Unroll factor candidates used when the task is not a GPU task. */
  const std::vector<int> cpu_unroll_cands_{0, 16, 64, 512, 1024};
  /*! \brief Unroll factor candidates used when the task is a GPU task. */
  const std::vector<int> gpu_unroll_cands_{0, 16, 64, 512};
};

/*! \brief The rule that mutates the compute location of the state. */
class MutateComputeLocation : public MutationRule {
 public:
  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;

  /*! \brief Priority level of this rule: 0 for GPU tasks, 5 for all other tasks. */
  int GetLevel(const SearchTask& task) const final { return IsGPUTask(task) ? 0 : 5; }
};

/*! \brief The rule that mutates the fuse step preceding a parallel annotation on the
    outermost loop. */
class MutateParallel : public MutationRule {
 public:
  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;

  /*! \brief Priority level of this rule; the same value is reported for every task. */
  int GetLevel(const SearchTask& task) const final {
    return 50;
  }
};

} // namespace auto_scheduler
Expand Down

0 comments on commit de6e232

Please sign in to comment.