From c4ff93fb311019ed101203243d087fef5889d04c Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 24 Aug 2020 18:06:57 +0000 Subject: [PATCH] refactor --- .../search_policy/sketch_policy.cc | 7 +- .../search_policy/sketch_policy.h | 4 +- .../search_policy/sketch_policy_rules.cc | 263 +++++++++--------- .../search_policy/sketch_policy_rules.h | 52 ++-- 4 files changed, 154 insertions(+), 172 deletions(-) diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 5e1e58c894235..5da7f154409b4 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -60,7 +60,6 @@ static RuleSpecialComputeLocationGPU rule_special_compute_location_gpu; /********** Init population rules **********/ static InitFillTileSize init_fill_tile_size; -static InitChangeComputeLocation init_change_compute_location; static InitParallel init_parallel; static InitUnroll init_unroll; static InitVectorization init_vectorization; @@ -125,7 +124,7 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model, node->init_rules.push_back(&init_fill_tile_size); // This should always be the first rule if (IsCPUTask(node->search_task)) { // The default init population rules for CPU policy - node->init_rules.push_back(&init_change_compute_location); + node->init_rules.push_back(&mutate_compute_location); node->init_rules.push_back(&init_parallel); node->init_rules.push_back(&init_unroll); node->init_rules.push_back(&init_vectorization); @@ -350,7 +349,7 @@ Array SketchPolicyNode::SampleInitPopulation(const Array& sketches // Derivation rule based enumeration bool valid = true; for (const auto& rule : init_rules) { - if (rule->Apply(this, &tmp_s) == InitPopulationRule::ResultKind::kInvalid) { + if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) { valid = false; break; } @@ -482,7 +481,7 @@ Array SketchPolicyNode::EvolutionarySearch(const Array& init_popul if (uniform_dist(rand_gen) < mutation_prob) { // Select a rule and mutate the state. const auto rule = mutation_rules[RandomChoose(rule_select_probs, &rand_gen)]; - if (rule->Apply(this, &tmp_s) == MutationRule::ResultKind::kValid) { + if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kValid) { pnext->push_back(std::move(tmp_s)); } else { fail_ct++; diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h index eee93f9b75f9a..2d93d8775c867 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.h +++ b/src/auto_scheduler/search_policy/sketch_policy.h @@ -94,9 +94,9 @@ class SketchPolicyNode : public SearchPolicyNode { /*! \brief The rules to generate sketches. */ std::vector sketch_rules; /*! \brief The rules to generate initial states. */ - std::vector init_rules; + std::vector init_rules; /*! \brief The rules to mutate states. */ - std::vector mutation_rules; + std::vector mutation_rules; /*! \brief Random generator. */ std::mt19937 rand_gen; /*! \brief Memorize split space for Split. */ diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 86bb897423b78..d954aec4a66b2 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -436,8 +436,8 @@ std::vector> RuleSpecialComputeLocationGPU::Apply( /********** Init Population **********/ -InitPopulationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, - State* state) const { +PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, + State* state) const { StateNode* pstate = state->CopyOnWrite(); // Scan the transformation history and randomly fill tiles size for all SplitStep for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) { @@ -472,123 +472,9 @@ InitPopulationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, return ResultKind::kValid; } -InitPopulationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy, - State* state) const { - if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) { - return ResultKind::kValid; - } - - for (int stage_id = static_cast((*state)->stages.size()) - 1; stage_id >= 0; stage_id--) { - const Stage& stage = (*state)->stages[stage_id]; - // Skip the inlined stages and placeholders - if (stage->op_type == StageKind::kPlaceholder || stage->compute_at == ComputeAtKind::kInlined) { - continue; - } - // Skip the tiled stages - if (IsTiled(stage) || NeedsMultilevelTiling(policy->search_task, *state, stage_id)) { - continue; - } - - int target_stage_id = GetSingleConsumerId(policy->search_task, *state, stage_id); - if (target_stage_id < 0) { - continue; - } - const Stage& target_stage = (*state)->stages[target_stage_id]; - - std::vector> candidates; - bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter; - bool target_is_tiled = IsTiled(target_stage); - - bool visited_reduce = false; - // enumerate compute_at location at target_stage - // TODO(merrymercy): More analysis here to make smarter choices - for (size_t i = 0; i < target_stage->iters.size(); ++i) { - const Iterator& target_iter = target_stage->iters[i]; - if (target_iter->iter_kind == IteratorKind::kReduction) { - visited_reduce = true; - if (!target_is_tiled) { // Do not go into reduce iter - break; - } - } else if (target_iter->iter_kind == IteratorKind::kSpatial) { - if (visited_reduce) { // Do not go into inner tile - break; - } - } - - if (target_iter->annotation == IteratorAnnotation::kUnroll) { - // Do not go into the unroll region of const tensor indices - break; - } - - if (GetExtent(target_iter) == 1) { - // Skip iterators with length of 1 - continue; - } - if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial && - StrEndsWith(target_iter->name, ".0")) { - // Skip the first level iterators if target stage compute_at another stage - // In this case, the lengths of first level iterators are always one - continue; - } - candidates.emplace_back(target_stage_id, i); - - if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) { - break; - } - } - - // if the target_stage is already compute_at another stage X, try also compute_at X - // We call stage X as `target_target_stage` - if (target_compute_at_other) { - int target_target_stage_id; - target_target_stage_id = (*state)->attach_map->stage_to_attach_iter.at(target_stage_id).first; - const Stage& target_target_stage = (*state)->stages[target_target_stage_id]; - - for (size_t i = 0; i < target_target_stage->iters.size(); ++i) { - const Iterator& target_target_iter = target_target_stage->iters[i]; - if (target_target_iter->iter_kind == IteratorKind::kReduction || - (*state)->attach_map->iter_to_attached_stages.count( - std::make_pair(target_target_stage_id, i))) { - break; - } - - if (target_target_iter->annotation == IteratorAnnotation::kUnroll) { - // Do not go into the unroll region of const tensor indices - break; - } - - if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1 - continue; - } - - candidates.emplace_back(target_target_stage_id, i); - } - } - - int choice = (policy->rand_gen)() % (candidates.size() + 2); - - if (choice == 0) { - if (!HasReduceIter(stage)) { - const auto& stage_to_attach_iter = (*state)->attach_map->stage_to_attach_iter; - if (stage_to_attach_iter.find(stage_id) != stage_to_attach_iter.end()) { - state->compute_inline(stage_id); - } - } - } else if (choice == 1) { - state->compute_root(stage_id); - } else { - choice = choice - 2; - const Stage& stage = (*state)->stages[candidates[choice].first]; - state->compute_at(stage_id, candidates[choice].first, - stage->iters[candidates[choice].second]); - } - } - - *state = policy->search_task->compute_dag.InferBound(*state); - return ResultKind::kValid; -} -InitPopulationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, State* state) const { +PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, + State* state) const { std::function annotate_parallel; annotate_parallel = [&annotate_parallel](const SketchPolicyNode& policy, State* state, @@ -652,7 +538,8 @@ InitPopulationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, Sta return ResultKind::kValid; } -InitPopulationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State* state) const { +PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, + State* state) const { std::vector auto_unroll_configs = IsGPUTask(policy->search_task) ? std::vector({0, 16, 64, 512, 1024}) : std::vector({0, 16, 64, 512}); @@ -703,8 +590,8 @@ InitPopulationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State return ResultKind::kValid; } -InitPopulationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy, - State* state) const { +PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy, + State* state) const { for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { const Stage& stage = (*state)->stages[stage_id]; // Skip the inlined stage and placeholder stage @@ -762,7 +649,8 @@ InitPopulationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy return ResultKind::kValid; } -InitPopulationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, State* state) const { +PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, + State* state) const { std::set multi_level_tiling_root_set; for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { if (NeedsMultilevelTiling(policy->search_task, *state, stage_id)) { @@ -911,7 +799,8 @@ InitPopulationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, S return ResultKind::kValid; } -MutationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State* state) const { +PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, + State* state) const { int max_innermost_split_factor = GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor); @@ -1015,8 +904,8 @@ MutationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State* return ResultKind::kInvalid; } -MutationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy, - State* state) const { +PopulationGenerationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy, + State* state) const { // Extract all auto_unroll_max_step pragma steps. std::vector annotate_steps; for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) { @@ -1031,7 +920,7 @@ MutationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy, } // Random pick up one unroll factor candidate. - auto cands = (IsGPUTask(policy->search_task))? &gpu_unroll_cands_: &cpu_unroll_cands_; + auto cands = (IsGPUTask(policy->search_task)) ? &gpu_unroll_cands_ : &cpu_unroll_cands_; auto new_factor = std::to_string((*cands)[(policy->rand_gen)() % cands->size()]); // Random pick up and mutate an unroll step. @@ -1045,18 +934,124 @@ MutationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy, return ResultKind::kValid; } -MutationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy, - State* state) const { - // FIXME (@comaniac, @jc94): Combine initial population rules with the mutation rules. - static InitChangeComputeLocation mutate_compute_location; - if (mutate_compute_location.Apply(policy, state) == InitPopulationRule::ResultKind::kInvalid) { - return ResultKind::kInvalid; +PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy, + State* state) const { + if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) { + return ResultKind::kValid; } + + for (int stage_id = static_cast((*state)->stages.size()) - 1; stage_id >= 0; stage_id--) { + const Stage& stage = (*state)->stages[stage_id]; + // Skip the inlined stages and placeholders + if (stage->op_type == StageKind::kPlaceholder || stage->compute_at == ComputeAtKind::kInlined) { + continue; + } + // Skip the tiled stages + if (IsTiled(stage) || NeedsMultilevelTiling(policy->search_task, *state, stage_id)) { + continue; + } + + int target_stage_id = GetSingleConsumerId(policy->search_task, *state, stage_id); + if (target_stage_id < 0) { + continue; + } + const Stage& target_stage = (*state)->stages[target_stage_id]; + + std::vector> candidates; + bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter; + bool target_is_tiled = IsTiled(target_stage); + + bool visited_reduce = false; + // enumerate compute_at location at target_stage + // TODO(merrymercy): More analysis here to make smarter choices + for (size_t i = 0; i < target_stage->iters.size(); ++i) { + const Iterator& target_iter = target_stage->iters[i]; + if (target_iter->iter_kind == IteratorKind::kReduction) { + visited_reduce = true; + if (!target_is_tiled) { // Do not go into reduce iter + break; + } + } else if (target_iter->iter_kind == IteratorKind::kSpatial) { + if (visited_reduce) { // Do not go into inner tile + break; + } + } + + if (target_iter->annotation == IteratorAnnotation::kUnroll) { + // Do not go into the unroll region of const tensor indices + break; + } + + if (GetExtent(target_iter) == 1) { + // Skip iterators with length of 1 + continue; + } + if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial && + StrEndsWith(target_iter->name, ".0")) { + // Skip the first level iterators if target stage compute_at another stage + // In this case, the lengths of first level iterators are always one + continue; + } + candidates.emplace_back(target_stage_id, i); + + if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) { + break; + } + } + + // if the target_stage is already compute_at another stage X, try also compute_at X + // We call stage X as `target_target_stage` + if (target_compute_at_other) { + int target_target_stage_id; + target_target_stage_id = (*state)->attach_map->stage_to_attach_iter.at(target_stage_id).first; + const Stage& target_target_stage = (*state)->stages[target_target_stage_id]; + + for (size_t i = 0; i < target_target_stage->iters.size(); ++i) { + const Iterator& target_target_iter = target_target_stage->iters[i]; + if (target_target_iter->iter_kind == IteratorKind::kReduction || + (*state)->attach_map->iter_to_attached_stages.count( + std::make_pair(target_target_stage_id, i))) { + break; + } + + if (target_target_iter->annotation == IteratorAnnotation::kUnroll) { + // Do not go into the unroll region of const tensor indices + break; + } + + if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1 + continue; + } + + candidates.emplace_back(target_target_stage_id, i); + } + } + + int choice = (policy->rand_gen)() % (candidates.size() + 2); + + if (choice == 0) { + if (!HasReduceIter(stage)) { + const auto& stage_to_attach_iter = (*state)->attach_map->stage_to_attach_iter; + if (stage_to_attach_iter.find(stage_id) != stage_to_attach_iter.end()) { + state->compute_inline(stage_id); + } + } + } else if (choice == 1) { + state->compute_root(stage_id); + } else { + choice = choice - 2; + const Stage& stage = (*state)->stages[candidates[choice].first]; + state->compute_at(stage_id, candidates[choice].first, + stage->iters[candidates[choice].second]); + } + } + + *state = policy->search_task->compute_dag.InferBound(*state); return ResultKind::kValid; } -MutationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, - State* state) const { +PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, + State* state) const { // This mutation rule only focuses on a case that parallel was added to // the outermost loop and the loop is generated by fusing other loops. // In short, we mutate the fusion step before the parallel step. diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h index ed26416726133..2312625e04edc 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.h +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h @@ -125,7 +125,7 @@ DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU); /********** Init Population **********/ /*! \brief The base class for derivation rules used in the initial population. */ -class InitPopulationRule { +class PopulationGenerationRule { public: /*! \brief Result enumeration of the apply function. */ enum class ResultKind : int { kValid = 0, kInvalid = 1 }; @@ -141,7 +141,7 @@ class InitPopulationRule { }; #define DEFINE_INIT_POPULATION_RULE(rule_name) \ - class rule_name : public InitPopulationRule { \ + class rule_name : public PopulationGenerationRule { \ public: \ ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \ }; @@ -149,10 +149,6 @@ class InitPopulationRule { /*! \brief The rule that fills the incomplete SplitSteps. */ DEFINE_INIT_POPULATION_RULE(InitFillTileSize); -/*! \brief The rule that randomly changes the computation location for some stages, which do not - * need tiling and are not strictly inlineable(e.g. data padding). */ -DEFINE_INIT_POPULATION_RULE(InitChangeComputeLocation); - /*! \brief The rule that annotates parallel for CPU. */ DEFINE_INIT_POPULATION_RULE(InitParallel); @@ -168,20 +164,8 @@ DEFINE_INIT_POPULATION_RULE(InitThreadBind); /********** Mutation **********/ /*! \brief The base class for mutation rules used in the evolutionary search. */ -class MutationRule { +class PopulationMutationRule: public PopulationGenerationRule { public: - /*! \brief Result enumeration of the apply function. */ - enum class ResultKind : int { kValid = 0, kInvalid = 1 }; - - /*! - * \brief Apply function of this rule. - * \param policy The SketchPolicyNode of this rule, some member may get changed during the - * rule applying. (e.g. random number generator) - * \param state The state to apply this rule, update inplace. - * \return The result of this rule, indicate if there's any valid state generated. - */ - virtual ResultKind Apply(SketchPolicyNode* policy, State* state) const = 0; - /*! * \brief Get the priority level of this mutation rule. * \return The priority level of this mutation rule. Higher the better. @@ -189,15 +173,23 @@ class MutationRule { virtual int GetLevel(const SearchTask& task) const = 0; }; +// A helper to define mutation rules with a constant rule level. +#define DEFINE_MUTATE_POPULATION_RULE(rule_name, rule_level) \ + class rule_name : public PopulationMutationRule { \ + public: \ + ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \ + int GetLevel(const SearchTask& task) const final { return rule_level; } \ + }; + /*! \brief The rule that mutates tile size by randomly dividing a tile size by a factor and multipling it to another tile size. */ -class MutateTileSize : public MutationRule { - public: - ResultKind Apply(SketchPolicyNode* policy, State* state) const final; - int GetLevel(const SearchTask& task) const final { return 100; } -}; +DEFINE_MUTATE_POPULATION_RULE(MutateTileSize, 100); + +/*! \brief The rule that mutates the fusion iterators annotated by parallel. */ +DEFINE_MUTATE_POPULATION_RULE(MutateParallel, 50); -class MutateMaxUnrollFactor : public MutationRule { +/*! \brief The rule that mutates the factor of a randomly selected auto max unroll step. */ +class MutateMaxUnrollFactor : public PopulationMutationRule { public: ResultKind Apply(SketchPolicyNode* policy, State* state) const final; int GetLevel(const SearchTask& task) const final { return 10; } @@ -206,7 +198,9 @@ class MutateMaxUnrollFactor : public MutationRule { const std::vector gpu_unroll_cands_ = {0, 16, 64, 512}; }; -class MutateComputeLocation : public MutationRule { +/*! \brief The rule that randomly changes the computation location for some stages, which do not + * need tiling and are not strictly inlineable(e.g. data padding). */ +class MutateComputeLocation : public PopulationMutationRule { public: ResultKind Apply(SketchPolicyNode* policy, State* state) const final; int GetLevel(const SearchTask& task) const final { @@ -217,12 +211,6 @@ class MutateComputeLocation : public MutationRule { } }; -class MutateParallel : public MutationRule { - public: - ResultKind Apply(SketchPolicyNode* policy, State* state) const final; - int GetLevel(const SearchTask& task) const final { return 50; } -}; - } // namespace auto_scheduler } // namespace tvm