Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
jcf94 committed Aug 28, 2020
1 parent 788e2c1 commit 9c8dca0
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 31 deletions.
53 changes: 22 additions & 31 deletions src/auto_scheduler/search_policy/sketch_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include "sketch_policy.h"

#include <tvm/runtime/registry.h>
#include <tvm/support/parallel_for.h>

#include <algorithm>
#include <iomanip>
Expand Down Expand Up @@ -262,6 +261,7 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State
*random_states = RandomSampleStates(init_population, &rand_gen, num_random_states * 10);
return EvolutionarySearch(init_population, num_measure_per_iter_ * 2);
} else {
PruneInvalidState(search_task, &init_population);
return RandomSampleStates(init_population, &rand_gen, num_measure_per_iter_ * 3);
}
}
Expand Down Expand Up @@ -337,40 +337,31 @@ Array<State> SketchPolicyNode::GenerateSketches() {
}

Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches, int out_size) {
std::atomic<int> fail_ct(0);
std::mutex m;
int fail_ct = 0;
Array<State> out_states;
auto tic_begin = std::chrono::high_resolution_clock::now();

support::parallel_for(
0, out_size, [this, &out_size, &sketches, &out_states, &fail_ct, &m](int i) {
if (fail_ct >= out_size) {
return;
}

// Random choose a starting sketch
// TODO(jcf94, merrymercy): Maybe choose sketches in different possibility for they may have
// different potential on generating state with better performance
State tmp_s = sketches[(rand_gen)() % sketches.size()];
// Derivation rule based enumeration
bool valid = true;
for (const auto& rule : init_rules) {
// Some rules use the random generator of SketchPolicyNode, so this part has to be
// protected
std::unique_lock<std::mutex> l(m);
if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) {
valid = false;
break;
}
}
while (static_cast<int>(out_states.size()) < out_size && fail_ct < out_size) {
// Random choose a starting sketch
// TODO(jcf94, merrymercy): Maybe choose sketches in different possibility for they may have
// different potential on generating state with better performance
State tmp_s = sketches[(rand_gen)() % sketches.size()];

// Derivation rule based enumeration
bool valid = true;
for (const auto& rule : init_rules) {
if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) {
valid = false;
break;
}
}

if (valid) {
std::unique_lock<std::mutex> l(m);
out_states.push_back(std::move(tmp_s));
} else {
fail_ct++;
}
});
if (valid) {
out_states.push_back(std::move(tmp_s));
} else {
fail_ct++;
}
}

double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
std::chrono::high_resolution_clock::now() - tic_begin)
Expand Down
14 changes: 14 additions & 0 deletions src/support/parallel_for.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int n

void parallel_for(int begin, int end, const std::function<void(int)>& f, int step,
const PartitionerFuncType partitioner) {
static bool GLOBAL_PARALLEL_FOR_FLAG{false};
static std::mutex M_GLOBAL_PARALLEL_FOR_FLAG;
{
std::unique_lock<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG);
CHECK(!GLOBAL_PARALLEL_FOR_FLAG) << "There's another parallel_for running. Maybe you're "
<< "currently inside another parallel_for loop.";
GLOBAL_PARALLEL_FOR_FLAG = true;
}

int default_num_threads = std::thread::hardware_concurrency();
const auto& run_partitions = partitioner(begin, end, step, default_num_threads);

Expand All @@ -70,6 +79,11 @@ void parallel_for(int begin, int end, const std::function<void(int)>& f, int ste
for (auto&& thread : threads) {
thread.join();
}
{
std::unique_lock<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG);
CHECK(GLOBAL_PARALLEL_FOR_FLAG);
GLOBAL_PARALLEL_FOR_FLAG = false;
}
try {
for (auto&& i : res_vec) {
i.get();
Expand Down
17 changes: 17 additions & 0 deletions tests/cpp/parallel_for_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,23 @@ TEST(ParallelFor, NestedWithNormalForLoop) {
}
}

// Verify that nesting parallel_for raises an error: nested use is explicitly
// unsupported, and the inner call is expected to throw.
// NOTE: suite name fixed from "Parallelfor" to "ParallelFor" so this test
// lives in the same gtest suite as the other parallel_for tests.
TEST(ParallelFor, NestedWithParallelFor) {
  // Currently do not support using nested parallel_for
  using tvm::support::parallel_for;

  bool exception = false;
  try {
    parallel_for(0, 100, [](int) {
      parallel_for(0, 100, [](int) {
        // Blank loop
      });
    });
  } catch (const std::exception&) {
    // Any std::exception (e.g. the CHECK failure) counts as the expected error.
    exception = true;
  }
  CHECK(exception);
}

TEST(ParallelFor, Exception) {
using tvm::support::parallel_for;

Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_auto_scheduler_search_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,22 @@ def test_sketch_search_policy_cuda_rpc_runner():
t.join()


def test_sketch_search_policy_cuda_xgbmodel_rpc_runner():
    # Exercise the sketch search policy with an XGBoost cost model on CUDA,
    # measuring through the local RPC runner. Skips when CUDA is unavailable.
    if not tvm.runtime.enabled("cuda"):
        return
    measure_ctx = auto_scheduler.LocalRPCMeasureContext()
    # Run the search in a separate thread to avoid the conflict between
    # Python's multiprocessing and TVM's thread pool.
    search_kwargs = {
        'seed': 944563397,
        'search_policy': 'sketch',
        'target': 'cuda',
        'runner': measure_ctx.runner,
        'cost_model': auto_scheduler.XGBModel(),
    }
    search_thread = PropagatingThread(target=search_common, kwargs=search_kwargs)
    search_thread.start()
    search_thread.join()


if __name__ == "__main__":
    # Run every search-policy test in order when executed as a script.
    for test_func in (
        test_workload_registry_search_basic,
        test_sketch_search_policy_basic,
        test_sketch_search_policy_xgbmodel,
        test_sketch_search_policy_cuda_rpc_runner,
        test_sketch_search_policy_cuda_xgbmodel_rpc_runner,
    ):
        test_func()

0 comments on commit 9c8dca0

Please sign in to comment.