Skip to content

Commit

Permalink
[Support] Add parallel_for support to run a loop in parallel (#6275)
Browse files Browse the repository at this point in the history
  • Loading branch information
jcf94 authored Aug 18, 2020
1 parent 9b8eb81 commit c1d347f
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 0 deletions.
73 changes: 73 additions & 0 deletions include/tvm/support/parallel_for.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file parallel_for.h
* \brief An implementation to run loop in parallel.
*/
#ifndef TVM_SUPPORT_PARALLEL_FOR_H_
#define TVM_SUPPORT_PARALLEL_FOR_H_

#include <tvm/runtime/c_runtime_api.h>

#include <functional>
#include <vector>

namespace tvm {
namespace support {

using PartitionerFuncType = std::function<std::vector<std::vector<int>>(int, int, int, int)>;

/*!
* \brief A partitioner to split the task to each thread in Round-robin manner.
* \param begin The start index of this parallel loop(inclusive).
* \param end The end index of this parallel loop(exclusive).
* \param step The traversal step to the index.
* \param num_threads The number of threads(the number of tasks to be partitioned to).
* \return A list with `num_threads` elements, and each is a list of integers indicating the loop
* indexes for the corresponding thread to process.
*/
TVM_DLL std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int num_threads);

/*!
* \brief A runtime api provided to run the task function in parallel.
* e.g. A for loop:
* for (int i = 0; i < 10; i++) {
* a[i] = i;
* }
* should work the same as:
* parallel_for(0, 10, [&a](int index) {
* a[i] = i;
* });
* \param begin The start index of this parallel loop(inclusive).
* \param end The end index of this parallel loop(exclusive).
* \param f The task function to be excuted. Assert to take an int index as input with no output.
* \param step The traversal step to the index.
* \param partitioner A partition function to split tasks to different threads. Use Round-robin
* partitioner by default.
* \note 1. Currently do not support nested parallel_for; 2. The order of execution in each thread
* is not guaranteed, the for loop task should be thread independent and thread safe.
*/
TVM_DLL void parallel_for(int begin, int end, const std::function<void(int)>& f, int step = 1,
const PartitionerFuncType partitioner = rr_partitioner);

} // namespace support
} // namespace tvm

#endif // TVM_SUPPORT_PARALLEL_FOR_H_
83 changes: 83 additions & 0 deletions src/support/parallel_for.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file parallel_for.cc
* \brief An implementation to run loop in parallel.
*/
#include <dmlc/logging.h>
#include <tvm/support/parallel_for.h>

#include <future>
#include <thread>
#include <utility>
#include <vector>

namespace tvm {
namespace support {

std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int num_threads) {
int total_task_count = (end - begin) / step;
CHECK_GT(total_task_count, 0) << "Infinite loop condition, check the input value of "
<< "`begin`, `end`, `step`.";
std::vector<std::vector<int>> ret;
ret.reserve(num_threads);
for (size_t thread = 0; begin < end; begin += step, thread = (thread + 1) % num_threads) {
if (thread >= ret.size()) {
ret.push_back(std::vector<int>());
}
ret[thread].push_back(begin);
}
return ret;
}

void parallel_for(int begin, int end, const std::function<void(int)>& f, int step,
const PartitionerFuncType partitioner) {
int default_num_threads = std::thread::hardware_concurrency();
const auto& run_partitions = partitioner(begin, end, step, default_num_threads);

std::vector<std::thread> threads;
threads.reserve(run_partitions.size());
std::vector<std::future<void>> res_vec;
res_vec.reserve(run_partitions.size());
for (const auto& run_partition : run_partitions) {
std::packaged_task<void(const std::vector<int>&, const std::function<void(int)>&)> task(
[](const std::vector<int>& run_pattition, const std::function<void(int)>& f) {
for (const auto& i : run_pattition) {
f(i);
}
});
res_vec.emplace_back(task.get_future());
threads.emplace_back(std::move(task), run_partition, f);
}

for (auto&& thread : threads) {
thread.join();
}
try {
for (auto&& i : res_vec) {
i.get();
}
} catch (const std::exception& e) {
LOG(FATAL) << "Parallel_for error with " << e.what();
}
}

} // namespace support
} // namespace tvm
108 changes: 108 additions & 0 deletions tests/cpp/parallel_for_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/support/parallel_for.h>

#include <vector>

TEST(ParallelFor, Basic) {
using tvm::support::parallel_for;

int a[1000], b[1000];

// Check for a small size of parallel
for (int i = 0; i < 10; i++) {
a[i] = i;
}
parallel_for(0, 10, [&b](int i) { b[i] = i; });
for (int i = 0; i < 10; i++) {
CHECK_EQ(a[i], b[i]);
}

// Check for a large size of parallel
for (int i = 0; i < 1000; i++) {
a[i] = i;
}
parallel_for(0, 1000, [&b](int i) { b[i] = i; });
for (int i = 0; i < 1000; i++) {
CHECK_EQ(a[i], b[i]);
}

// Check for step != 1
for (int i = 0; i < 1000; i += 2) {
a[i] *= 2;
}
parallel_for(
0, 1000, [&b](int i) { b[i] *= 2; }, 2);
for (int i = 0; i < 1000; i++) {
CHECK_EQ(a[i], b[i]);
}
}

TEST(ParallelFor, NestedWithNormalForLoop) {
using tvm::support::parallel_for;

int a[500][500], b[500][500], c[500][500];

for (int i = 0; i < 500; i++) {
for (int j = 0; j < 500; j++) {
a[i][j] = i * j;
}
}

parallel_for(0, 500, [&b](int i) {
for (int j = 0; j < 500; j++) {
b[i][j] = i * j;
}
});
for (int i = 0; i < 500; i++) {
for (int j = 0; j < 500; j++) {
CHECK_EQ(a[i][j], b[i][j]);
}
}

for (int i = 0; i < 500; i++) {
parallel_for(0, 500, [&c, &i](int j) { c[i][j] = i * j; });
}
for (int i = 0; i < 500; i++) {
for (int j = 0; j < 500; j++) {
CHECK_EQ(a[i][j], c[i][j]);
}
}
}

TEST(ParallelFor, Exception) {
using tvm::support::parallel_for;

bool exception = false;
try {
parallel_for(0, 100, [](int i) { LOG(FATAL) << "error"; });
} catch (const std::exception& e) {
exception = true;
}
CHECK(exception);
}

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}

0 comments on commit c1d347f

Please sign in to comment.