From 318b8ddbe4680361ef0425f14d4d95cd4ccb1834 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 27 Aug 2020 15:29:56 +0800 Subject: [PATCH 1/5] Update auto_scheduler with parallel_for --- src/auto_scheduler/compute_dag.cc | 21 +++++--- src/auto_scheduler/feature.cc | 17 +++--- .../search_policy/sketch_policy.cc | 52 +++++++++++-------- src/support/parallel_for.cc | 4 +- .../test_auto_scheduler_search_policy.py | 17 +++++- 5 files changed, 73 insertions(+), 38 deletions(-) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index f5f08840de86..2f2c4fb448a1 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -812,16 +813,24 @@ State ComputeDAG::InferBound(const State& state) const { Array ComputeDAG::InferBound(const Array& states) const { Array out_states; - // TODO(jcf94, merrymercy): Use parallel_for to run this in parallel - for (const auto& state : states) { + out_states.reserve(states.size()); + std::mutex m; + + tvm::support::parallel_for(0, states.size(), [this, &states, &out_states, &m](int index) { State out_state; try { - out_state = this->InferBound(state); + out_state = this->InferBound(states[index]); } catch (dmlc::Error& e) { - LOG(WARNING) << "InferBound fails on the state:\n" << state << "\n" << e.what() << std::endl; + LOG(WARNING) << "InferBound fails on the state:\n" + << states[index] << "\n" + << "with: " << e.what() << std::endl; } - out_states.push_back(std::move(out_state)); - } + if (out_state.defined()) { + std::unique_lock l(m); + out_states.push_back(std::move(out_state)); + } + }); + return out_states; } diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index bbef387d3f72..1b3657f62afc 100644 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -1337,9 +1338,11 @@ void GetPerStoreFeaturesFromStates(const Array& states, const SearchTask& std::atomic error_ct(0); - for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) { - GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs, &(*features)[i], &error_ct); - } + support::parallel_for(skip_first_n_feature_extraction, states.size(), + [&task, &states, &max_n_bufs, &features, &error_ct](int i) { + GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs, + &(*features)[i], &error_ct); + }); if (error_ct > 0) { std::cerr << "Encountered " << error_ct @@ -1355,9 +1358,11 @@ void GetPerStoreFeaturesFromStates(const Array& states, const std::vector std::atomic error_ct(0); - for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) { - GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs, &(*features)[i], &error_ct); - } + support::parallel_for(skip_first_n_feature_extraction, states.size(), + [&tasks, &states, &max_n_bufs, &features, &error_ct](int i) { + GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs, + &(*features)[i], &error_ct); + }); if (error_ct > 0) { std::cerr << "Encountered " << error_ct diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 51c138be70bb..16cb742ee9a3 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -27,6 +27,7 @@ #include "sketch_policy.h" #include +#include #include #include @@ -322,32 +323,39 @@ Array SketchPolicyNode::GenerateSketches() { } Array SketchPolicyNode::SampleInitPopulation(const Array& sketches, int out_size) { - int fail_ct = 0; + std::atomic fail_ct(0); + std::mutex m; Array out_states; auto tic_begin = std::chrono::high_resolution_clock::now(); - // TODO(jcf94, merrymercy): Use parallel_for to run this loop in parallel - while (static_cast(out_states.size()) < out_size && fail_ct < static_cast(out_size)) { - // Random choose a starting sketch - // TODO(jcf94, merrymercy): Maybe choose sketches in different possibility for they may have - // different potential on generating state with better performance - State tmp_s = sketches[(rand_gen)() % sketches.size()]; - - // Derivation rule based enumeration - bool valid = true; - for (const auto& rule : init_rules) { - if (rule->Apply(this, &tmp_s) == InitPopulationRule::ResultKind::kInvalid) { - valid = false; - break; - } - } + support::parallel_for( + 0, out_size, [this, &out_size, &sketches, &out_states, &fail_ct, &m](int i) { + if (fail_ct >= out_size) { + return; + } - if (valid) { - out_states.push_back(std::move(tmp_s)); - } else { - fail_ct++; - } - } + // Random choose a starting sketch + // TODO(jcf94, merrymercy): Maybe choose sketches in different possibility for they may have + // different potential on generating state with better performance + State tmp_s = sketches[(rand_gen)() % sketches.size()]; + // Derivation rule based enumeration + bool valid = true; + for (const auto& rule : init_rules) { + // Some rule needs use the random generator of SketchPolicyNode, which has to be protected + std::unique_lock l(m); + if (rule->Apply(this, &tmp_s) == InitPopulationRule::ResultKind::kInvalid) { + valid = false; + break; + } + } + + if (valid) { + std::unique_lock l(m); + out_states.push_back(std::move(tmp_s)); + } else { + fail_ct++; + } + }); double duration = std::chrono::duration_cast>( std::chrono::high_resolution_clock::now() - tic_begin) diff --git a/src/support/parallel_for.cc b/src/support/parallel_for.cc index 30f39fbee6f9..90b7b449ab49 100644 --- a/src/support/parallel_for.cc +++ b/src/support/parallel_for.cc @@ -34,8 +34,8 @@ namespace support { std::vector> rr_partitioner(int begin, int end, int step, int num_threads) { int total_task_count = (end - begin) / step; - CHECK_GT(total_task_count, 0) << "Infinite loop condition, check the input value of " - << "`begin`, `end`, `step`."; + CHECK_GE(total_task_count, 0) << "Infinite loop condition with begin: " << begin + << " end: " << end << " step: " << step; std::vector> ret; ret.reserve(num_threads); for (size_t thread = 0; begin < end; begin += step, thread = (thread + 1) % num_threads) { diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index a646c38c93bf..d680742c680f 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -48,7 +48,7 @@ def search_common(workload=matmul_auto_scheduler_test, target="llvm", if search_policy == 'empty': search_policy = auto_scheduler.EmptyPolicy(task) elif search_policy == 'sketch': - search_policy = auto_scheduler.SketchPolicy(task, + search_policy = auto_scheduler.SketchPolicy(task, schedule_cost_model=cost_model, init_search_callbacks=init_search_callbacks) tuning_options = auto_scheduler.TuningOptions(num_measure_trials=num_measure_trials, @@ -107,6 +107,18 @@ def test_sketch_search_policy_basic(): t.join() +def test_sketch_search_policy_xgbmodel(): + if not tvm.runtime.enabled("llvm"): + return + # wrap the search in a new thread to avoid the conflict + # between python's multiprocessing and tvm's thread pool + t = PropagatingThread(target=search_common, + kwargs={'seed': 944563397, 'search_policy': 'sketch', + 'cost_model': auto_scheduler.XGBModel()}) + t.start() + t.join() + + def test_sketch_search_policy_cuda_rpc_runner(): if not tvm.runtime.enabled("cuda"): return @@ -121,6 +133,7 @@ def test_sketch_search_policy_cuda_rpc_runner(): if __name__ == "__main__": - test_workload_registry_search_basic() + # test_workload_registry_search_basic() test_sketch_search_policy_basic() + test_sketch_search_policy_xgbmodel() test_sketch_search_policy_cuda_rpc_runner() From 8d59ef01efe47c7ed3524acd57933452cb77b344 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 27 Aug 2020 15:53:02 +0800 Subject: [PATCH 2/5] Update --- src/auto_scheduler/compute_dag.cc | 6 +++--- tests/python/unittest/test_auto_scheduler_search_policy.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 2f2c4fb448a1..b8ef74b74658 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -816,13 +816,13 @@ Array ComputeDAG::InferBound(const Array& states) const { out_states.reserve(states.size()); std::mutex m; - tvm::support::parallel_for(0, states.size(), [this, &states, &out_states, &m](int index) { + support::parallel_for(0, states.size(), [this, &states, &out_states, &m](int i) { State out_state; try { - out_state = this->InferBound(states[index]); + out_state = this->InferBound(states[i]); } catch (dmlc::Error& e) { LOG(WARNING) << "InferBound fails on the state:\n" - << states[index] << "\n" + << states[i] << "\n" << "with: " << e.what() << std::endl; } if (out_state.defined()) { diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index d680742c680f..1516a4cb7535 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -133,7 +133,7 @@ def test_sketch_search_policy_cuda_rpc_runner(): if __name__ == "__main__": - # test_workload_registry_search_basic() + test_workload_registry_search_basic() test_sketch_search_policy_basic() test_sketch_search_policy_xgbmodel() test_sketch_search_policy_cuda_rpc_runner() From 81a5d27d62915eca37b22aef578111abc5d1d682 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 27 Aug 2020 16:10:25 +0800 Subject: [PATCH 3/5] Update --- src/auto_scheduler/search_policy/sketch_policy.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 16cb742ee9a3..83b7e0ef8c8e 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -341,7 +341,8 @@ Array SketchPolicyNode::SampleInitPopulation(const Array& sketches // Derivation rule based enumeration bool valid = true; for (const auto& rule : init_rules) { - // Some rule needs use the random generator of SketchPolicyNode, which has to be protected + // Some rules use the random generator of SketchPolicyNode, so this part has to be + // protected std::unique_lock l(m); if (rule->Apply(this, &tmp_s) == InitPopulationRule::ResultKind::kInvalid) { valid = false; From 9c8dca0b8178ae72a182153561a9034d92e079f6 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 28 Aug 2020 10:59:42 +0800 Subject: [PATCH 4/5] Update --- .../search_policy/sketch_policy.cc | 53 ++++++++----------- src/support/parallel_for.cc | 14 +++++ tests/cpp/parallel_for_test.cc | 17 ++++++ .../test_auto_scheduler_search_policy.py | 14 +++++ 4 files changed, 67 insertions(+), 31 deletions(-) diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 6232b44fbd59..6e9de955fe83 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -27,7 +27,6 @@ #include "sketch_policy.h" #include -#include #include #include @@ -262,6 +261,7 @@ Array SketchPolicyNode::SearchOneRound(int num_random_states, Array SketchPolicyNode::GenerateSketches() { } Array SketchPolicyNode::SampleInitPopulation(const Array& sketches, int out_size) { - std::atomic fail_ct(0); - std::mutex m; + int fail_ct = 0; Array out_states; auto tic_begin = std::chrono::high_resolution_clock::now(); - support::parallel_for( - 0, out_size, [this, &out_size, &sketches, &out_states, &fail_ct, &m](int i) { - if (fail_ct >= out_size) { - return; - } - - // Random choose a starting sketch - // TODO(jcf94, merrymercy): Maybe choose sketches in different possibility for they may have - // different potential on generating state with better performance - State tmp_s = sketches[(rand_gen)() % sketches.size()]; - // Derivation rule based enumeration - bool valid = true; - for (const auto& rule : init_rules) { - // Some rules use the random generator of SketchPolicyNode, so this part has to be - // protected - std::unique_lock l(m); - if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) { - valid = false; - break; - } - } + while (static_cast(out_states.size()) < out_size && fail_ct < out_size) { + // Random choose a starting sketch + // TODO(jcf94, merrymercy): Maybe choose sketches in different possibility for they may have + // different potential on generating state with better performance + State tmp_s = sketches[(rand_gen)() % sketches.size()]; + + // Derivation rule based enumeration + bool valid = true; + for (const auto& rule : init_rules) { + if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) { + valid = false; + break; + } + } - if (valid) { - std::unique_lock l(m); - out_states.push_back(std::move(tmp_s)); - } else { - fail_ct++; - } - }); + if (valid) { + out_states.push_back(std::move(tmp_s)); + } else { + fail_ct++; + } + } double duration = std::chrono::duration_cast>( std::chrono::high_resolution_clock::now() - tic_begin) diff --git a/src/support/parallel_for.cc b/src/support/parallel_for.cc index 90b7b449ab49..0b8c810da70b 100644 --- a/src/support/parallel_for.cc +++ b/src/support/parallel_for.cc @@ -49,6 +49,15 @@ std::vector> rr_partitioner(int begin, int end, int step, int n void parallel_for(int begin, int end, const std::function& f, int step, const PartitionerFuncType partitioner) { + static bool GLOBAL_PARALLEL_FOR_FLAG{false}; + static std::mutex M_GLOBAL_PARALLEL_FOR_FLAG; + { + std::unique_lock l(M_GLOBAL_PARALLEL_FOR_FLAG); + CHECK(!GLOBAL_PARALLEL_FOR_FLAG) << "There's another parallel_for running. Maybe you're " + << "currently inside another parallel_for loop."; + GLOBAL_PARALLEL_FOR_FLAG = true; + } + int default_num_threads = std::thread::hardware_concurrency(); const auto& run_partitions = partitioner(begin, end, step, default_num_threads); @@ -70,6 +79,11 @@ void parallel_for(int begin, int end, const std::function& f, int ste for (auto&& thread : threads) { thread.join(); } + { + std::unique_lock l(M_GLOBAL_PARALLEL_FOR_FLAG); + CHECK(GLOBAL_PARALLEL_FOR_FLAG); + GLOBAL_PARALLEL_FOR_FLAG = false; + } try { for (auto&& i : res_vec) { i.get(); diff --git a/tests/cpp/parallel_for_test.cc b/tests/cpp/parallel_for_test.cc index 3d586fc1aa15..82e95f9ab46e 100644 --- a/tests/cpp/parallel_for_test.cc +++ b/tests/cpp/parallel_for_test.cc @@ -89,6 +89,23 @@ TEST(ParallelFor, NestedWithNormalForLoop) { } } +TEST(Parallelfor, NestedWithParallelFor) { + // Currently do not support using nested parallel_for + using tvm::support::parallel_for; + + bool exception = false; + try { + parallel_for(0, 100, [](int i) { + parallel_for(0, 100, [](int j) { + // Blank loop + }); + }); + } catch (const std::exception& e) { + exception = true; + } + CHECK(exception); +} + TEST(ParallelFor, Exception) { using tvm::support::parallel_for; diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index 1516a4cb7535..21ac9844b54c 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -132,8 +132,22 @@ def test_sketch_search_policy_cuda_rpc_runner(): t.join() +def test_sketch_search_policy_cuda_xgbmodel_rpc_runner(): + if not tvm.runtime.enabled("cuda"): + return + measure_ctx = auto_scheduler.LocalRPCMeasureContext() + # wrap the search in a new thread to avoid the conflict + # between python's multiprocessing and tvm's thread pool + t = PropagatingThread(target=search_common, + kwargs={'seed': 944563397, 'search_policy': 'sketch', 'target': 'cuda', + 'runner': measure_ctx.runner, 'cost_model': auto_scheduler.XGBModel()}) + t.start() + t.join() + + if __name__ == "__main__": test_workload_registry_search_basic() test_sketch_search_policy_basic() test_sketch_search_policy_xgbmodel() test_sketch_search_policy_cuda_rpc_runner() + test_sketch_search_policy_cuda_xgbmodel_rpc_runner() From 9e0a6857bc440626d22b0f06cdd19901845b7205 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Mon, 31 Aug 2020 09:54:46 +0800 Subject: [PATCH 5/5] Update inferbound --- include/tvm/auto_scheduler/compute_dag.h | 4 +++- src/auto_scheduler/compute_dag.cc | 13 +++---------- src/auto_scheduler/search_policy/search_policy.cc | 1 + src/auto_scheduler/search_policy/sketch_policy.cc | 2 ++ 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h index 34f1c9d8737d..da276ea139a7 100644 --- a/include/tvm/auto_scheduler/compute_dag.h +++ b/include/tvm/auto_scheduler/compute_dag.h @@ -245,7 +245,9 @@ class ComputeDAG : public ObjectRef { * This function calls TVM InferBound pass internally to get the bound. * The returned state of this function is guaranteed to have complete bound information. * \param states The input states. - * \return The States with complete bound information + * \return The States with complete bound information. + * \note The returned array will contains empty State, if there're infer bound failure on some + * states. */ Array InferBound(const Array& states) const; diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index b8ef74b74658..54f00e08a7a0 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -812,23 +812,16 @@ State ComputeDAG::InferBound(const State& state) const { } Array ComputeDAG::InferBound(const Array& states) const { - Array out_states; - out_states.reserve(states.size()); - std::mutex m; + Array out_states(states.size(), State()); - support::parallel_for(0, states.size(), [this, &states, &out_states, &m](int i) { - State out_state; + support::parallel_for(0, states.size(), [this, &states, &out_states](int i) { try { - out_state = this->InferBound(states[i]); + out_states.Set(i, this->InferBound(states[i])); } catch (dmlc::Error& e) { LOG(WARNING) << "InferBound fails on the state:\n" << states[i] << "\n" << "with: " << e.what() << std::endl; } - if (out_state.defined()) { - std::unique_lock l(m); - out_states.push_back(std::move(out_state)); - } }); return out_states; diff --git a/src/auto_scheduler/search_policy/search_policy.cc b/src/auto_scheduler/search_policy/search_policy.cc index 723f8eedc322..d73bd911a921 100644 --- a/src/auto_scheduler/search_policy/search_policy.cc +++ b/src/auto_scheduler/search_policy/search_policy.cc @@ -58,6 +58,7 @@ void SearchPolicyNode::PreloadMeasuredStates(const String& log_file) { res.second[i]->error_no == 0 ? (1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0); } } + // We can assume the recorded states will all be valid after infer bound measured_states = search_task->compute_dag.InferBound(measured_states); for (size_t i = 0; i < measured_states.size(); i++) { auto& state = measured_states[i]; diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 6e9de955fe83..ffc00941143c 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -179,7 +179,9 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure // Infer bound. This is necessary for computing the correct ToStr() for redundancy check best_states = search_task->compute_dag.InferBound(best_states); + PruneInvalidState(search_task, &best_states); random_states = search_task->compute_dag.InferBound(random_states); + PruneInvalidState(search_task, &random_states); // Pick `num_measure_per_iter` states to measure, check hash to remove already measured state // Also pick some random states to do eps-greedy