[Ansor][AutoTVM v2.0] Phase 2: Update heavy operations with parallel_for #6348

Merged · 6 commits · Aug 31, 2020
Changes from 5 commits
21 changes: 15 additions & 6 deletions src/auto_scheduler/compute_dag.cc
@@ -27,6 +27,7 @@
#include <tvm/auto_scheduler/search_policy.h>
#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/runtime/registry.h>
+ #include <tvm/support/parallel_for.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
@@ -812,16 +813,24 @@ State ComputeDAG::InferBound(const State& state) const {

Array<State> ComputeDAG::InferBound(const Array<State>& states) const {
Array<State> out_states;
- // TODO(jcf94, merrymercy): Use parallel_for to run this in parallel
- for (const auto& state : states) {
+ out_states.reserve(states.size());
+ std::mutex m;
+
+ support::parallel_for(0, states.size(), [this, &states, &out_states, &m](int i) {
State out_state;
try {
- out_state = this->InferBound(state);
+ out_state = this->InferBound(states[i]);
} catch (dmlc::Error& e) {
- LOG(WARNING) << "InferBound fails on the state:\n" << state << "\n" << e.what() << std::endl;
+ LOG(WARNING) << "InferBound fails on the state:\n"
+ << states[i] << "\n"
+ << "with: " << e.what() << std::endl;
}
- out_states.push_back(std::move(out_state));
- }
+ if (out_state.defined()) {
+ std::unique_lock<std::mutex> l(m);
@merrymercy (Member) commented on Aug 30, 2020:

Can we use low-level APIs from ArrayNode to remove this mutex lock?
Because we know the number of states, we can allocate n empty State() in advance.
Do something like

std::vector<State> out_states(n, State());
parallel_for i in 0...n:
    out_state[i] = process(states[i])

The PR author (Contributor) replied:

Our original implementation in the Ansor repo worked exactly this way; however, if any InferBound call fails, the output will contain some empty State() objects.

The PR author followed up:

Added a note to the docstring of ComputeDAG::InferBound and added the necessary PruneInvalidState calls to the search policy.

+ out_states.push_back(std::move(out_state));
+ }
+ });

return out_states;
}

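For reference, here is a minimal sketch of the mutex-free alternative discussed in the review thread above. It is hypothetical and not code from this PR: process() stands in for this->InferBound(states[i]), and each worker writes only its own pre-allocated slot, so no lock is needed.

std::vector<State> tmp(states.size(), State());
support::parallel_for(0, states.size(), [&](int i) {
  // Each index is written by exactly one worker, so no mutex is required.
  tmp[i] = process(states[i]);
});
// Single-threaded copy into the returned Array; states whose InferBound failed
// stay undefined and can later be dropped, e.g. by PruneInvalidState.
Array<State> out_states(tmp.begin(), tmp.end());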
17 changes: 11 additions & 6 deletions src/auto_scheduler/feature.cc
@@ -27,6 +27,7 @@
#include <tvm/auto_scheduler/measure.h>
#include <tvm/auto_scheduler/measure_record.h>
#include <tvm/runtime/registry.h>
+ #include <tvm/support/parallel_for.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/analysis.h>
@@ -1337,9 +1338,11 @@ void GetPerStoreFeaturesFromStates(const Array<State>& states, const SearchTask&

std::atomic<int> error_ct(0);

- for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) {
- GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs, &(*features)[i], &error_ct);
- }
+ support::parallel_for(skip_first_n_feature_extraction, states.size(),
+ [&task, &states, &max_n_bufs, &features, &error_ct](int i) {
+ GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs,
+ &(*features)[i], &error_ct);
+ });

if (error_ct > 0) {
std::cerr << "Encountered " << error_ct
@@ -1355,9 +1358,11 @@ void GetPerStoreFeaturesFromStates(const Array<State>& states, const std::vector

std::atomic<int> error_ct(0);

- for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) {
- GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs, &(*features)[i], &error_ct);
- }
+ support::parallel_for(skip_first_n_feature_extraction, states.size(),
+ [&tasks, &states, &max_n_bufs, &features, &error_ct](int i) {
+ GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs,
+ &(*features)[i], &error_ct);
+ });

if (error_ct > 0) {
std::cerr << "Encountered " << error_ct
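A side note on the pattern above: the shared error counter needs no lock because std::atomic<int> increments are thread-safe, unlike the Array push_back in ComputeDAG::InferBound, which is why that code takes a mutex. A self-contained sketch (assuming only that <tvm/support/parallel_for.h> is the public header, as included above):

#include <tvm/support/parallel_for.h>

#include <atomic>
#include <iostream>

int main() {
  std::atomic<int> error_ct(0);
  tvm::support::parallel_for(0, 1000, [&error_ct](int i) {
    if (i % 7 == 0) {
      error_ct++;  // atomic increment; safe from any worker thread
    }
  });
  std::cout << "error_ct = " << error_ct << std::endl;  // prints 143
  return 0;
}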
4 changes: 2 additions & 2 deletions src/auto_scheduler/search_policy/sketch_policy.cc
@@ -261,6 +261,7 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State
*random_states = RandomSampleStates(init_population, &rand_gen, num_random_states * 10);
return EvolutionarySearch(init_population, num_measure_per_iter_ * 2);
} else {
+ PruneInvalidState(search_task, &init_population);
return RandomSampleStates(init_population, &rand_gen, num_measure_per_iter_ * 3);
}
}
@@ -340,8 +341,7 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
Array<State> out_states;
auto tic_begin = std::chrono::high_resolution_clock::now();

- // TODO(jcf94, merrymercy): Use parallel_for to run this loop in parallel
- while (static_cast<int>(out_states.size()) < out_size && fail_ct < static_cast<int>(out_size)) {
+ while (static_cast<int>(out_states.size()) < out_size && fail_ct < out_size) {
// Random choose a starting sketch
// TODO(jcf94, merrymercy): Maybe choose sketches in different possibility for they may have
// different potential on generating state with better performance
18 changes: 16 additions & 2 deletions src/support/parallel_for.cc
@@ -34,8 +34,8 @@ namespace support {

std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int num_threads) {
int total_task_count = (end - begin) / step;
- CHECK_GT(total_task_count, 0) << "Infinite loop condition, check the input value of "
- << "`begin`, `end`, `step`.";
+ CHECK_GE(total_task_count, 0) << "Infinite loop condition with begin: " << begin
+ << " end: " << end << " step: " << step;
std::vector<std::vector<int>> ret;
ret.reserve(num_threads);
for (size_t thread = 0; begin < end; begin += step, thread = (thread + 1) % num_threads) {
@@ -49,6 +49,15 @@ std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int n

void parallel_for(int begin, int end, const std::function<void(int)>& f, int step,
const PartitionerFuncType partitioner) {
+ static bool GLOBAL_PARALLEL_FOR_FLAG{false};
+ static std::mutex M_GLOBAL_PARALLEL_FOR_FLAG;
+ {
+ std::unique_lock<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG);
+ CHECK(!GLOBAL_PARALLEL_FOR_FLAG) << "There's another parallel_for running. Maybe you're "
+ << "currently inside another parallel_for loop.";
+ GLOBAL_PARALLEL_FOR_FLAG = true;
+ }
+
int default_num_threads = std::thread::hardware_concurrency();
const auto& run_partitions = partitioner(begin, end, step, default_num_threads);

@@ -70,6 +79,11 @@ void parallel_for(int begin, int end, const std::function<void(int)>& f, int ste
for (auto&& thread : threads) {
thread.join();
}
+ {
+ std::unique_lock<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG);
+ CHECK(GLOBAL_PARALLEL_FOR_FLAG);
+ GLOBAL_PARALLEL_FOR_FLAG = false;
+ }
try {
for (auto&& i : res_vec) {
i.get();
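For context, a minimal usage sketch of the API changed above (assuming <tvm/support/parallel_for.h> is the public header and rr_partitioner is the default partitioner): indices are split round-robin over roughly hardware_concurrency() threads, and each index is handled by exactly one worker. After this change, calling parallel_for from inside another parallel_for trips the new CHECK, which is exactly what the test below exercises.

#include <tvm/support/parallel_for.h>

#include <vector>

int main() {
  // With the round-robin partitioner, indices 0..9 over 3 threads would be
  // grouped roughly as {0, 3, 6, 9}, {1, 4, 7}, {2, 5, 8}.
  std::vector<int> squares(100, 0);
  tvm::support::parallel_for(0, 100, [&squares](int i) {
    squares[i] = i * i;  // each slot is written by exactly one worker; no lock needed
  });
  return 0;
}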
17 changes: 17 additions & 0 deletions tests/cpp/parallel_for_test.cc
@@ -89,6 +89,23 @@ TEST(ParallelFor, NestedWithNormalForLoop) {
}
}

+ TEST(Parallelfor, NestedWithParallelFor) {
+ // Currently do not support using nested parallel_for
+ using tvm::support::parallel_for;
+
+ bool exception = false;
+ try {
+ parallel_for(0, 100, [](int i) {
+ parallel_for(0, 100, [](int j) {
+ // Blank loop
+ });
+ });
+ } catch (const std::exception& e) {
+ exception = true;
+ }
+ CHECK(exception);
+ }
+
TEST(ParallelFor, Exception) {
using tvm::support::parallel_for;

29 changes: 28 additions & 1 deletion tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -48,7 +48,7 @@ def search_common(workload=matmul_auto_scheduler_test, target="llvm",
if search_policy == 'empty':
search_policy = auto_scheduler.EmptyPolicy(task)
elif search_policy == 'sketch':
- search_policy = auto_scheduler.SketchPolicy(task,
+ search_policy = auto_scheduler.SketchPolicy(task, schedule_cost_model=cost_model,
init_search_callbacks=init_search_callbacks)

tuning_options = auto_scheduler.TuningOptions(num_measure_trials=num_measure_trials,
@@ -107,6 +107,18 @@ def test_sketch_search_policy_basic():
t.join()


+ def test_sketch_search_policy_xgbmodel():
+     if not tvm.runtime.enabled("llvm"):
+         return
+     # wrap the search in a new thread to avoid the conflict
+     # between python's multiprocessing and tvm's thread pool
+     t = PropagatingThread(target=search_common,
+                           kwargs={'seed': 944563397, 'search_policy': 'sketch',
+                                   'cost_model': auto_scheduler.XGBModel()})
+     t.start()
+     t.join()
+
+
def test_sketch_search_policy_cuda_rpc_runner():
if not tvm.runtime.enabled("cuda"):
return
@@ -120,7 +132,22 @@ def test_sketch_search_policy_cuda_rpc_runner():
t.join()


+ def test_sketch_search_policy_cuda_xgbmodel_rpc_runner():
+     if not tvm.runtime.enabled("cuda"):
+         return
+     measure_ctx = auto_scheduler.LocalRPCMeasureContext()
+     # wrap the search in a new thread to avoid the conflict
+     # between python's multiprocessing and tvm's thread pool
+     t = PropagatingThread(target=search_common,
+                           kwargs={'seed': 944563397, 'search_policy': 'sketch', 'target': 'cuda',
+                                   'runner': measure_ctx.runner, 'cost_model': auto_scheduler.XGBModel()})
+     t.start()
+     t.join()
+
+
if __name__ == "__main__":
test_workload_registry_search_basic()
test_sketch_search_policy_basic()
+ test_sketch_search_policy_xgbmodel()
test_sketch_search_policy_cuda_rpc_runner()
+ test_sketch_search_policy_cuda_xgbmodel_rpc_runner()