From 771c14da6b181c89e0be1acfc7d636deb0947505 Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Fri, 28 Aug 2020 10:59:42 +0800
Subject: [PATCH] Update

---
 .../search_policy/sketch_policy.cc           | 51 ++++++++-----------
 src/support/parallel_for.cc                  | 14 +++++
 tests/cpp/parallel_for_test.cc               | 17 +++++++
 .../test_auto_scheduler_search_policy.py     | 14 +++++
 4 files changed, 66 insertions(+), 30 deletions(-)

diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc
index 6232b44fbd59f..5ddc4eca949c9 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy.cc
@@ -337,40 +337,31 @@ Array<State> SketchPolicyNode::GenerateSketches() {
 }
 
 Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches, int out_size) {
-  std::atomic<int> fail_ct(0);
-  std::mutex m;
+  int fail_ct = 0;
   Array<State> out_states;
   auto tic_begin = std::chrono::high_resolution_clock::now();
 
-  support::parallel_for(
-      0, out_size, [this, &out_size, &sketches, &out_states, &fail_ct, &m](int i) {
-        if (fail_ct >= out_size) {
-          return;
-        }
-
-        // Random choose a starting sketch
-        // TODO(jcf94, merrymercy): Maybe choose sketches in different possibility for they may have
-        // different potential on generating state with better performance
-        State tmp_s = sketches[(rand_gen)() % sketches.size()];
-        // Derivation rule based enumeration
-        bool valid = true;
-        for (const auto& rule : init_rules) {
-          // Some rules use the random generator of SketchPolicyNode, so this part has to be
-          // protected
-          std::unique_lock<std::mutex> l(m);
-          if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) {
-            valid = false;
-            break;
-          }
-        }
+  while (static_cast<int>(out_states.size()) < out_size && fail_ct < out_size) {
+    // Randomly choose a starting sketch
+    // TODO(jcf94, merrymercy): Maybe choose sketches with different probabilities, since they may
+    // have different potential for generating states with better performance
+    State tmp_s = sketches[(rand_gen)() % sketches.size()];
+
+    // Derivation rule based enumeration
+    bool valid = true;
+    for (const auto& rule : init_rules) {
+      if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) {
+        valid = false;
+        break;
+      }
+    }
 
-        if (valid) {
-          std::unique_lock<std::mutex> l(m);
-          out_states.push_back(std::move(tmp_s));
-        } else {
-          fail_ct++;
-        }
-      });
+    if (valid) {
+      out_states.push_back(std::move(tmp_s));
+    } else {
+      fail_ct++;
+    }
+  }
 
   double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
                         std::chrono::high_resolution_clock::now() - tic_begin)
diff --git a/src/support/parallel_for.cc b/src/support/parallel_for.cc
index 90b7b449ab49a..0b8c810da70b7 100644
--- a/src/support/parallel_for.cc
+++ b/src/support/parallel_for.cc
@@ -49,6 +49,15 @@ std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int n
 
 void parallel_for(int begin, int end, const std::function<void(int)>& f, int step,
                   const PartitionerFuncType partitioner) {
+  static bool GLOBAL_PARALLEL_FOR_FLAG{false};
+  static std::mutex M_GLOBAL_PARALLEL_FOR_FLAG;
+  {
+    std::unique_lock<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG);
+    CHECK(!GLOBAL_PARALLEL_FOR_FLAG) << "There's another parallel_for running. Maybe you're "
+                                     << "currently inside another parallel_for loop.";
+    GLOBAL_PARALLEL_FOR_FLAG = true;
+  }
+
   int default_num_threads = std::thread::hardware_concurrency();
   const auto& run_partitions = partitioner(begin, end, step, default_num_threads);
 
@@ -70,6 +79,11 @@ void parallel_for(int begin, int end, const std::function<void(int)>& f, int ste
   for (auto&& thread : threads) {
     thread.join();
   }
+  {
+    std::unique_lock<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG);
+    CHECK(GLOBAL_PARALLEL_FOR_FLAG);
+    GLOBAL_PARALLEL_FOR_FLAG = false;
+  }
   try {
     for (auto&& i : res_vec) {
       i.get();
diff --git a/tests/cpp/parallel_for_test.cc b/tests/cpp/parallel_for_test.cc
index 3d586fc1aa158..82e95f9ab46e2 100644
--- a/tests/cpp/parallel_for_test.cc
+++ b/tests/cpp/parallel_for_test.cc
@@ -89,6 +89,23 @@ TEST(ParallelFor, NestedWithNormalForLoop) {
   }
 }
 
+TEST(ParallelFor, NestedWithParallelFor) {
+  // Currently do not support using nested parallel_for
+  using tvm::support::parallel_for;
+
+  bool exception = false;
+  try {
+    parallel_for(0, 100, [](int i) {
+      parallel_for(0, 100, [](int j) {
+        // Blank loop
+      });
+    });
+  } catch (const std::exception& e) {
+    exception = true;
+  }
+  CHECK(exception);
+}
+
 TEST(ParallelFor, Exception) {
   using tvm::support::parallel_for;
 
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 1516a4cb75359..21ac9844b54ce 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -132,8 +132,22 @@ def test_sketch_search_policy_cuda_rpc_runner():
     t.join()
 
 
+def test_sketch_search_policy_cuda_xgbmodel_rpc_runner():
+    if not tvm.runtime.enabled("cuda"):
+        return
+    measure_ctx = auto_scheduler.LocalRPCMeasureContext()
+    # wrap the search in a new thread to avoid the conflict
+    # between python's multiprocessing and tvm's thread pool
+    t = PropagatingThread(target=search_common,
+                          kwargs={'seed': 944563397, 'search_policy': 'sketch', 'target': 'cuda',
+                                  'runner': measure_ctx.runner, 'cost_model': auto_scheduler.XGBModel()})
+    t.start()
+    t.join()
+
+
 if __name__ == "__main__":
     test_workload_registry_search_basic()
     test_sketch_search_policy_basic()
     test_sketch_search_policy_xgbmodel()
     test_sketch_search_policy_cuda_rpc_runner()
+    test_sketch_search_policy_cuda_xgbmodel_rpc_runner()