From 10955edea6a065468cda73a45dd3c8764905f4ce Mon Sep 17 00:00:00 2001
From: Chenfan
Date: Mon, 31 Aug 2020 14:07:24 +0800
Subject: [PATCH] [Ansor][AutoTVM v2.0] Phase 2: Update heavy operations with
 parallel_for (#6348)

* Update auto_scheduler with parallel_for

* Update

* Update

* Update

* Update inferbound
---
 include/tvm/auto_scheduler/compute_dag.h       |  4 ++-
 src/auto_scheduler/compute_dag.cc              | 18 +++++++-----
 src/auto_scheduler/feature.cc                  | 17 +++++++----
 .../search_policy/search_policy.cc             |  1 +
 .../search_policy/sketch_policy.cc             |  6 ++--
 src/support/parallel_for.cc                    | 18 ++++++++++--
 tests/cpp/parallel_for_test.cc                 | 17 +++++++++++
 .../test_auto_scheduler_search_policy.py       | 29 ++++++++++++++++++-
 8 files changed, 90 insertions(+), 20 deletions(-)

diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h
index 34f1c9d8737d..da276ea139a7 100644
--- a/include/tvm/auto_scheduler/compute_dag.h
+++ b/include/tvm/auto_scheduler/compute_dag.h
@@ -245,7 +245,9 @@ class ComputeDAG : public ObjectRef {
    * This function calls TVM InferBound pass internally to get the bound.
    * The returned state of this function is guaranteed to have complete bound information.
    * \param states The input states.
-   * \return The States with complete bound information
+   * \return The States with complete bound information.
+   * \note The returned array may contain empty States if infer bound fails on some of the
+   * input states.
    */
   Array<State> InferBound(const Array<State>& states) const;
 
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index f5f08840de86..54f00e08a7a0 100644
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -27,6 +27,7 @@
 #include <...>
 #include <...>
 #include <...>
+#include <tvm/support/parallel_for.h>
 #include <...>
 #include <...>
 #include <...>
@@ -811,17 +812,18 @@ State ComputeDAG::InferBound(const State& state) const {
 }
 
 Array<State> ComputeDAG::InferBound(const Array<State>& states) const {
-  Array<State> out_states;
-  // TODO(jcf94, merrymercy): Use parallel_for to run this in parallel
-  for (const auto& state : states) {
-    State out_state;
+  Array<State> out_states(states.size(), State());
+
+  support::parallel_for(0, states.size(), [this, &states, &out_states](int i) {
     try {
-      out_state = this->InferBound(state);
+      out_states.Set(i, this->InferBound(states[i]));
     } catch (dmlc::Error& e) {
-      LOG(WARNING) << "InferBound fails on the state:\n" << state << "\n" << e.what() << std::endl;
+      LOG(WARNING) << "InferBound fails on the state:\n"
+                   << states[i] << "\n"
+                   << "with: " << e.what() << std::endl;
     }
-    out_states.push_back(std::move(out_state));
-  }
+  });
+
   return out_states;
 }
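Note on the pattern in the InferBound hunk above: the output array is preallocated with one default-constructed State per input, each parallel_for worker writes only its own slot via Set(i, ...), and a state whose infer bound throws simply leaves its slot empty (which is why the header now documents that empty States may appear). Below is a minimal standalone sketch of that idea; it uses std::vector and std::thread in place of TVM's Array and support::parallel_for, and every name in it is illustrative rather than taken from the patch.

#include <iostream>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

int main() {
  // One default-constructed ("empty") output slot per input, preallocated so
  // each worker writes only its own index: no push_back, no ordering issues.
  std::vector<int> inputs = {1, 2, -1, 4};          // -1 stands in for a failing state
  std::vector<std::string> outputs(inputs.size());
  std::vector<std::thread> workers;
  for (size_t i = 0; i < inputs.size(); ++i) {
    workers.emplace_back([i, &inputs, &outputs] {
      try {
        if (inputs[i] < 0) throw std::runtime_error("infer bound failed");
        outputs[i] = std::to_string(inputs[i] * 2);
      } catch (const std::exception&) {
        // Swallow the error; outputs[i] stays empty, like the default State.
      }
    });
  }
  for (auto& w : workers) w.join();
  for (const auto& s : outputs) std::cout << (s.empty() ? "<empty>" : s) << '\n';
  return 0;
}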
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index bbef387d3f72..1b3657f62afc 100644
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -27,6 +27,7 @@
 #include <...>
 #include <...>
 #include <...>
+#include <tvm/support/parallel_for.h>
 #include <...>
 #include <...>
 #include <...>
@@ -1337,9 +1338,11 @@ void GetPerStoreFeaturesFromStates(const Array<State>& states, const SearchTask&
 
   std::atomic<int> error_ct(0);
 
-  for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) {
-    GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs, &(*features)[i], &error_ct);
-  }
+  support::parallel_for(skip_first_n_feature_extraction, states.size(),
+                        [&task, &states, &max_n_bufs, &features, &error_ct](int i) {
+                          GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs,
+                                                        &(*features)[i], &error_ct);
+                        });
 
   if (error_ct > 0) {
     std::cerr << "Encountered " << error_ct
@@ -1355,9 +1358,11 @@ void GetPerStoreFeaturesFromStates(const Array<State>& states,
                                    const std::vector<SearchTask>& tasks,
 
   std::atomic<int> error_ct(0);
 
-  for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) {
-    GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs, &(*features)[i], &error_ct);
-  }
+  support::parallel_for(skip_first_n_feature_extraction, states.size(),
+                        [&tasks, &states, &max_n_bufs, &features, &error_ct](int i) {
+                          GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs,
+                                                        &(*features)[i], &error_ct);
+                        });
 
   if (error_ct > 0) {
     std::cerr << "Encountered " << error_ct
diff --git a/src/auto_scheduler/search_policy/search_policy.cc b/src/auto_scheduler/search_policy/search_policy.cc
index 723f8eedc322..d73bd911a921 100644
--- a/src/auto_scheduler/search_policy/search_policy.cc
+++ b/src/auto_scheduler/search_policy/search_policy.cc
@@ -58,6 +58,7 @@ void SearchPolicyNode::PreloadMeasuredStates(const String& log_file) {
           res.second[i]->error_no == 0 ? (1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0);
     }
   }
+  // We can assume that the recorded states will all be valid after infer bound
   measured_states = search_task->compute_dag.InferBound(measured_states);
   for (size_t i = 0; i < measured_states.size(); i++) {
     auto& state = measured_states[i];
diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc
index 4f536e829be4..ffc00941143c 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy.cc
@@ -179,7 +179,9 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
 
     // Infer bound. This is necessary for computing the correct ToStr() for redundancy check
     best_states = search_task->compute_dag.InferBound(best_states);
+    PruneInvalidState(search_task, &best_states);
     random_states = search_task->compute_dag.InferBound(random_states);
+    PruneInvalidState(search_task, &random_states);
 
     // Pick `num_measure_per_iter` states to measure, check hash to remove already measured state
     // Also pick some random states to do eps-greedy
@@ -261,6 +263,7 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State>* random_states) {

@@ ... @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches,
   Array<State> out_states;
   auto tic_begin = std::chrono::high_resolution_clock::now();
 
-  // TODO(jcf94, merrymercy): Use parallel_for to run this loop in parallel
-  while (static_cast<int>(out_states.size()) < out_size && fail_ct < static_cast<int>(out_size)) {
+  while (static_cast<int>(out_states.size()) < out_size && fail_ct < out_size) {
     // Randomly choose a starting sketch
     // TODO(jcf94, merrymercy): Maybe choose sketches with different probabilities, as they may
     // have different potential for generating states with better performance
diff --git a/src/support/parallel_for.cc b/src/support/parallel_for.cc
index 30f39fbee6f9..0b8c810da70b 100644
--- a/src/support/parallel_for.cc
+++ b/src/support/parallel_for.cc
@@ -34,8 +34,8 @@ namespace support {
 
 std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int num_threads) {
   int total_task_count = (end - begin) / step;
-  CHECK_GT(total_task_count, 0) << "Infinite loop condition, check the input value of "
-                                << "`begin`, `end`, `step`.";
+  CHECK_GE(total_task_count, 0) << "Infinite loop condition with begin: " << begin
+                                << " end: " << end << " step: " << step;
   std::vector<std::vector<int>> ret;
   ret.reserve(num_threads);
   for (size_t thread = 0; begin < end; begin += step, thread = (thread + 1) % num_threads) {
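The partitioner above deals the indices begin, begin+step, ... out to the threads one at a time, round-robin, which balances the per-thread load when individual tasks cost roughly the same; the check is also relaxed from CHECK_GT to CHECK_GE so that an empty range is now legal. A standalone sketch of just the dealing logic (plain ints, no CHECK; rr_partition and the values in main are made up for illustration):

#include <iostream>
#include <vector>

// Deal indices begin, begin+step, ... into num_threads buckets in turn.
std::vector<std::vector<int>> rr_partition(int begin, int end, int step, int num_threads) {
  std::vector<std::vector<int>> buckets(num_threads);
  for (int t = 0; begin < end; begin += step, t = (t + 1) % num_threads) {
    buckets[t].push_back(begin);
  }
  return buckets;
}

int main() {
  // [0, 10) with step 1 over 3 threads -> {0,3,6,9}, {1,4,7}, {2,5,8}.
  for (const auto& bucket : rr_partition(0, 10, 1, 3)) {
    for (int i : bucket) std::cout << i << ' ';
    std::cout << '\n';
  }
  return 0;
}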
@@ -49,6 +49,15 @@ std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int n
 
 void parallel_for(int begin, int end, const std::function<void(int)>& f,
                   int step, const PartitionerFuncType partitioner) {
+  static bool GLOBAL_PARALLEL_FOR_FLAG{false};
+  static std::mutex M_GLOBAL_PARALLEL_FOR_FLAG;
+  {
+    std::unique_lock<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG);
+    CHECK(!GLOBAL_PARALLEL_FOR_FLAG) << "There's another parallel_for running. Maybe you're "
+                                     << "currently inside another parallel_for loop.";
+    GLOBAL_PARALLEL_FOR_FLAG = true;
+  }
+
   int default_num_threads = std::thread::hardware_concurrency();
   const auto& run_partitions = partitioner(begin, end, step, default_num_threads);
 
@@ -70,6 +79,11 @@ void parallel_for(int begin, int end, const std::function<void(int)>& f, int ste
   for (auto&& thread : threads) {
     thread.join();
   }
+  {
+    std::unique_lock<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG);
+    CHECK(GLOBAL_PARALLEL_FOR_FLAG);
+    GLOBAL_PARALLEL_FOR_FLAG = false;
+  }
   try {
     for (auto&& i : res_vec) {
       i.get();
diff --git a/tests/cpp/parallel_for_test.cc b/tests/cpp/parallel_for_test.cc
index 3d586fc1aa15..82e95f9ab46e 100644
--- a/tests/cpp/parallel_for_test.cc
+++ b/tests/cpp/parallel_for_test.cc
@@ -89,6 +89,23 @@ TEST(ParallelFor, NestedWithNormalForLoop) {
   }
 }
 
+TEST(ParallelFor, NestedWithParallelFor) {
+  // Nested parallel_for is currently not supported
+  using tvm::support::parallel_for;
+
+  bool exception = false;
+  try {
+    parallel_for(0, 100, [](int i) {
+      parallel_for(0, 100, [](int j) {
+        // Blank loop
+      });
+    });
+  } catch (const std::exception& e) {
+    exception = true;
+  }
+  CHECK(exception);
+}
+
 TEST(ParallelFor, Exception) {
   using tvm::support::parallel_for;
 
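The new test passes because of the guard added to parallel_for above: a static flag, protected by a mutex, records that a parallel_for region is already active, so a second entry from inside a worker fails fast instead of spawning threads from threads. A condensed sketch of just that guard logic (the body runs serially here, guarded_parallel_for is an invented name, and, like the real implementation, the flag is not reset if f throws):

#include <functional>
#include <iostream>
#include <mutex>
#include <stdexcept>

void guarded_parallel_for(int begin, int end, const std::function<void(int)>& f) {
  static bool running{false};
  static std::mutex m;
  {
    std::unique_lock<std::mutex> lock(m);
    if (running) throw std::runtime_error("nested parallel_for is not supported");
    running = true;
  }
  for (int i = begin; i < end; ++i) f(i);  // stand-in for the threaded execution
  {
    std::unique_lock<std::mutex> lock(m);
    running = false;
  }
}

int main() {
  try {
    guarded_parallel_for(0, 3, [](int) {
      guarded_parallel_for(0, 3, [](int) {});  // second entry trips the guard
    });
  } catch (const std::exception& e) {
    std::cout << "caught: " << e.what() << '\n';  // prints the guard's message
  }
  return 0;
}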
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index a646c38c93bf..21ac9844b54c 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -48,7 +48,7 @@ def search_common(workload=matmul_auto_scheduler_test, target="llvm",
     if search_policy == 'empty':
         search_policy = auto_scheduler.EmptyPolicy(task)
     elif search_policy == 'sketch':
-        search_policy = auto_scheduler.SketchPolicy(task,
+        search_policy = auto_scheduler.SketchPolicy(task, schedule_cost_model=cost_model,
                                                     init_search_callbacks=init_search_callbacks)
     tuning_options = auto_scheduler.TuningOptions(num_measure_trials=num_measure_trials,
@@ -107,6 +107,18 @@ def test_sketch_search_policy_basic():
     t.join()
 
 
+def test_sketch_search_policy_xgbmodel():
+    if not tvm.runtime.enabled("llvm"):
+        return
+    # wrap the search in a new thread to avoid the conflict
+    # between python's multiprocessing and tvm's thread pool
+    t = PropagatingThread(target=search_common,
+                          kwargs={'seed': 944563397, 'search_policy': 'sketch',
+                                  'cost_model': auto_scheduler.XGBModel()})
+    t.start()
+    t.join()
+
+
 def test_sketch_search_policy_cuda_rpc_runner():
     if not tvm.runtime.enabled("cuda"):
         return
@@ -120,7 +132,22 @@ def test_sketch_search_policy_cuda_rpc_runner():
     t.join()
 
 
+def test_sketch_search_policy_cuda_xgbmodel_rpc_runner():
+    if not tvm.runtime.enabled("cuda"):
+        return
+    measure_ctx = auto_scheduler.LocalRPCMeasureContext()
+    # wrap the search in a new thread to avoid the conflict
+    # between python's multiprocessing and tvm's thread pool
+    t = PropagatingThread(target=search_common,
+                          kwargs={'seed': 944563397, 'search_policy': 'sketch', 'target': 'cuda',
+                                  'runner': measure_ctx.runner, 'cost_model': auto_scheduler.XGBModel()})
+    t.start()
+    t.join()
+
+
 if __name__ == "__main__":
     test_workload_registry_search_basic()
     test_sketch_search_policy_basic()
+    test_sketch_search_policy_xgbmodel()
     test_sketch_search_policy_cuda_rpc_runner()
+    test_sketch_search_policy_cuda_xgbmodel_rpc_runner()
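One last implementation detail worth spelling out: in parallel_for.cc, each worker's result is collected in res_vec as a future, and the i.get() loop shown in its second hunk re-raises on the calling thread any exception a worker threw, which is how a failure inside a worker can surface to the caller as an ordinary catchable exception, as the nested-parallel_for test above expects. A minimal, TVM-independent sketch of that packaged_task/future pattern (all names invented):

#include <future>
#include <iostream>
#include <stdexcept>
#include <thread>
#include <vector>

int main() {
  std::vector<std::future<void>> res_vec;  // one future per worker
  std::vector<std::thread> threads;
  for (int t = 0; t < 4; ++t) {
    std::packaged_task<void()> task([t] {
      if (t == 2) throw std::runtime_error("worker failed");  // stored in the future
    });
    res_vec.push_back(task.get_future());
    threads.emplace_back(std::move(task));
  }
  for (auto& th : threads) th.join();
  try {
    for (auto& f : res_vec) f.get();  // rethrows the stored exception here
  } catch (const std::exception& e) {
    std::cout << "caught: " << e.what() << '\n';
  }
  return 0;
}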