forked from Stonepia/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathParallelNative.cpp
320 lines (281 loc) · 8.64 KB
/
ParallelNative.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
#include <ATen/Config.h>
#if AT_PARALLEL_NATIVE
#include <ATen/Parallel.h>
#include <ATen/ParallelFuture.h>
#include <ATen/PTThreadPool.h>
#ifndef C10_MOBILE
#include <c10/core/thread_pool.h>
#include <c10/util/irange.h>
#else
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#endif // C10_MOBILE
#include <atomic>
#include <utility>
#ifdef _OPENMP
#include <omp.h>
#endif
#if AT_MKL_ENABLED()
#include <mkl.h>
#endif
namespace at {
namespace {

// Set while the calling thread executes part of a parallel primitive. Together
// with _set_in_parallel_region it lets the master thread be flagged as inside
// a parallel region while it runs its own share of the work.
thread_local bool in_parallel_region_{false};

// Task id (thread number) assigned to the calling thread by the parallel
// primitive currently running on it.
thread_local int thread_num_{0};

// Marks (true) or unmarks (false) the calling thread as being inside a
// parallel region.
void _set_in_parallel_region(bool in_region) {
  in_parallel_region_ = in_region;
}

} // namespace (anonymous)
namespace internal {

// Records `thread_num` as the calling thread's task id; the value is what
// get_thread_num() subsequently reports on this thread.
void set_thread_num(int thread_num) {
  thread_num_ = thread_num;
}

} // namespace internal
namespace {
// Resets the calling thread's task id back to its initial value (0).
void _unset_thread_num() {
  constexpr int kDefaultThreadNum = 0;
  thread_num_ = kDefaultThreadNum;
}
#ifndef C10_MOBILE
// Sentinel values stored in num_intraop_threads.
constexpr int NOT_SET = -1;
constexpr int CONSUMED = -2;

// Number of intra-op threads requested by the user. Allowed transitions:
//   NOT_SET -> positive value -> CONSUMED
//   NOT_SET -> CONSUMED
// Meaning of each state:
//   NOT_SET        - pool not initialized, no user value recorded
//   positive value - pool not initialized, user value recorded
//   CONSUMED       - pool has been initialized
std::atomic<int> num_intraop_threads{NOT_SET};

// Translates a requested total thread count (or NOT_SET, meaning "use the
// default") into the number of worker threads the pool should own. The master
// thread also executes tasks, so it is subtracted from the total.
int _num_pool_threads(int nthreads) {
  if (nthreads != NOT_SET) {
    TORCH_INTERNAL_ASSERT(nthreads > 0);
    return nthreads - 1;
  }
  return intraop_default_num_threads() - 1;
}

// Returns the intra-op thread pool, building it on the first call. Building it
// swaps num_intraop_threads to CONSUMED and sizes the pool from the value that
// was stored there before.
TaskThreadPoolBase& _get_intraop_pool() {
  static std::shared_ptr<TaskThreadPoolBase> intraop_pool =
      ThreadPoolRegistry()->Create(
          "C10",
          /* device_id */ 0,
          /* pool_size */ _num_pool_threads(num_intraop_threads.exchange(CONSUMED)),
          /* create_new */ true); // a dedicated pool, separate from others
  return *intraop_pool;
}
#endif // C10_MOBILE
// Run lambda function `fn` over `task_id` in [0, `range`) with threadpool.
// `fn` will be called with params: (thread_pool_task_id, task_id).
void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
#ifndef C10_MOBILE
for (const auto i : c10::irange(1, range)) {
_get_intraop_pool().run([fn, i]() { fn((int)i, i); });
}
// Run the first task on the current thread directly.
fn(0, 0);
#else
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
pool->run(
// PThreadPool::run() is blocking. A std::function [const] reference to
// this lambda cannot go out of scope before PThreadPool::run() returns.
[&fn](const size_t task_id) {
fn(0 /* unused */, task_id);
}, range);
#endif // C10_MOBILE
}
// RAII guard backing the in_parallel_region() and get_thread_num() APIs.
// Construction records `task_id` as the calling thread's task id and marks the
// thread as inside a parallel region. Destruction clears that flag and resets
// the task id to 0.
//
// The guard is neither copyable nor movable. A duplicate would run the
// destructor's reset a second time, possibly while the original guard is still
// supposed to be active.
struct ParallelRegionGuard {
  explicit ParallelRegionGuard(int task_id) {
    internal::set_thread_num(task_id);
    _set_in_parallel_region(true);
  }

  ~ParallelRegionGuard() {
    _set_in_parallel_region(false);
    _unset_thread_num();
  }

  ParallelRegionGuard(const ParallelRegionGuard&) = delete;
  ParallelRegionGuard& operator=(const ParallelRegionGuard&) = delete;
  ParallelRegionGuard(ParallelRegionGuard&&) = delete;
  ParallelRegionGuard& operator=(ParallelRegionGuard&&) = delete;
};
} // namespace
namespace internal {
// Splits [begin, end) into tasks for invoke_parallel and returns
// (num_tasks, chunk_size):
//  - an empty range, or one smaller than `grain_size`, yields a single task
//    with chunk_size = max(0, end - begin);
//  - otherwise the range is divided across get_num_threads() threads, with each
//    chunk at least `grain_size` elements long.
inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
    int64_t begin, int64_t end, int64_t grain_size) {
  const int64_t range = end - begin;
  // An empty range must take this branch as well. Otherwise, with
  // grain_size == 0, chunk_size below would be 0 and divup(range, chunk_size)
  // would divide by zero.
  if (range <= 0 || range < grain_size) {
    return std::make_tuple(1, std::max((int64_t)0, range));
  }
  // Choose number of tasks based on grain size and number of threads.
  size_t chunk_size = divup(range, get_num_threads());
  // Make sure each task is at least grain_size size.
  chunk_size = std::max((size_t)grain_size, chunk_size);
  const size_t num_tasks = divup(range, chunk_size);
  return std::make_tuple(num_tasks, chunk_size);
}
// Runs `f(chunk_begin, chunk_end)` over [begin, end), split into chunks as
// chosen by calc_num_tasks_and_chunk_size, using the intra-op pool.
// Blocks until every chunk has reported completion. If any chunk throws, the
// first captured exception is rethrown on the calling thread afterwards.
void invoke_parallel(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const std::function<void(int64_t, int64_t)>& f) {
  at::internal::lazy_init_num_threads();

  size_t num_tasks = 0;
  size_t chunk_size = 0;
  std::tie(num_tasks, chunk_size) =
      internal::calc_num_tasks_and_chunk_size(begin, end, grain_size);

  // State shared between this thread and the pool tasks. It lives on this
  // stack frame, which is only safe because we wait below until every task has
  // decremented `remaining`.
  struct {
    std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
    std::exception_ptr eptr;
    std::mutex mutex;
    size_t remaining{0}; // guarded by `mutex` (no volatile needed)
    std::condition_variable cv;
  } state;

  // `f` is captured by reference for the same reason `state` is: this frame
  // outlives all tasks. This avoids copying the callable into every task.
  auto task = [&f, &state, begin, end, chunk_size]
      (int /* unused */, size_t task_id) {
    int64_t local_start = begin + task_id * chunk_size;
    if (local_start < end) {
      int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
      try {
        ParallelRegionGuard guard(task_id);
        f(local_start, local_end);
      } catch (...) {
        // Keep only the first exception; later ones are dropped.
        if (!state.err_flag.test_and_set()) {
          state.eptr = std::current_exception();
        }
      }
    }
    // Every task, including ones with an empty chunk, reports completion.
    {
      std::unique_lock<std::mutex> lk(state.mutex);
      if (--state.remaining == 0) {
        state.cv.notify_one();
      }
    }
  };
  state.remaining = num_tasks;
  _run_with_pool(std::move(task), num_tasks);

  // Wait for all tasks to finish. The predicate form re-checks `remaining`
  // after every wakeup. A spurious wakeup would otherwise let us return and
  // destroy `state` while tasks still reference it.
  {
    std::unique_lock<std::mutex> lk(state.mutex);
    state.cv.wait(lk, [&state] { return state.remaining == 0; });
  }

  if (state.eptr) {
    std::rethrow_exception(state.eptr);
  }
}
} // namespace internal
// Per-thread initialization for the native backend. OpenMP and MKL, when
// compiled in, are limited to one thread each (presumably so they do not
// oversubscribe alongside the native intra-op pool — TODO confirm). On mobile,
// caffe2::pthreadpool() is called, presumably to instantiate the pool up front.
void init_num_threads() {
#if defined(_OPENMP)
  omp_set_num_threads(1);
#endif // _OPENMP
#if AT_MKL_ENABLED()
  mkl_set_num_threads(1);
#endif // AT_MKL_ENABLED()
#if defined(C10_MOBILE)
  // Called for its side effect only; the returned pointer is not needed.
  (void)caffe2::pthreadpool();
#endif // C10_MOBILE
}
// Requests `nthreads` intra-op threads (the master thread included).
// Non-mobile behavior:
//  - a non-positive value fails TORCH_CHECK;
//  - the first request made before the pool exists is recorded;
//  - any later request only triggers a warning if it differs from the count
//    already recorded or already in effect.
// Mobile behavior: the count is forwarded to the caffe2 pthreadpool.
void set_num_threads(int nthreads) {
#ifndef C10_MOBILE
  TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
  int expected = NOT_SET;
  if (num_intraop_threads.compare_exchange_strong(expected, nthreads)) {
    return; // first request recorded; the pool will be sized from it
  }
  // The slot already holds either a positive user value or CONSUMED.
  int current = num_intraop_threads.load();
  if (current <= 0) {
    // Pool already built: its workers plus the master thread.
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    current = _get_intraop_pool().size() + 1;
  }
  if (current == nthreads) {
    return;
  }
  TORCH_WARN(
      "Cannot set number of intraop threads "
      "after parallel work has started or after set_num_threads call "
      "when using native parallel backend");
#else
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
  pool->set_thread_count(nthreads);
#endif // C10_MOBILE
}
// Returns the number of threads used for intra-op parallelism.
// Non-mobile:
//  - the user-requested value if one is recorded;
//  - the default if nothing was requested;
//  - the pool size plus the master thread once the pool exists.
// The pool is never created here, because its size cannot change afterwards.
// Mobile: 1 when called inside a parallel region, otherwise the pthreadpool's
// thread count.
int get_num_threads() {
  at::internal::lazy_init_num_threads();
#ifndef C10_MOBILE
  // not initializing pool unnecessarily,
  // because pool cannot be resized after initialization
  int nthreads = num_intraop_threads.load();
  if (nthreads > 0) {
    return nthreads;
  } else if (nthreads == NOT_SET) {
    return intraop_default_num_threads();
  } else {
    TORCH_INTERNAL_ASSERT(nthreads == CONSUMED);
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    return _get_intraop_pool().size() + 1;
  }
#else
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  // Terminated with ';' like every other use, rather than relying on how the
  // macro happens to expand.
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
  return in_parallel_region() ? 1 /* current thread */ : pool->get_thread_count();
#endif // C10_MOBILE
}
// Returns the task id recorded for the calling thread by the running parallel
// primitive (0 when none has been recorded).
int get_thread_num() {
  const int task_id = thread_num_;
  return task_id;
}
// Reports whether the calling thread is executing parallel work.
// Non-mobile: also true on any intra-op pool worker, because intraop_launch()
// does not set the thread-local flag. The pool is only queried once it exists
// (num_intraop_threads == CONSUMED), so this check never creates it.
bool in_parallel_region() {
#ifndef C10_MOBILE
  if (in_parallel_region_) {
    return true;
  }
  const bool pool_created = num_intraop_threads.load() == CONSUMED;
  return pool_created && _get_intraop_pool().inThreadPool();
#else
  return in_parallel_region_;
#endif // C10_MOBILE
}
// Runs `func` on the intra-op pool when the caller is outside a parallel
// region and more than one thread is available. Otherwise `func` runs inline
// on the calling thread.
void intraop_launch(std::function<void()> func) {
#ifndef C10_MOBILE
  const bool run_inline = in_parallel_region() || get_num_threads() <= 1;
  if (run_inline) {
    func();
    return;
  }
  _get_intraop_pool().run(std::move(func));
#else
  // TODO: caffe2::PThreadPool only provides a data-parallel API.
  // Task parallelism is not currently supported.
  func();
#endif // C10_MOBILE
}
// Like intraop_launch(), but returns a Future that completes once `func` has
// run.
// Non-mobile:
//  - when dispatched to the pool, an exception thrown by `func` is stored in
//    the Future via setError(). Previously the lambda exited before
//    markCompleted(), leaving the Future incomplete forever.
//  - when run inline, exceptions propagate directly to the caller.
// Mobile: always runs inline.
c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
    std::function<void()> func) {
#ifndef C10_MOBILE
  auto future = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
  if (!in_parallel_region() && get_num_threads() > 1) {
    _get_intraop_pool().run(
      // Move `func` into the closure instead of copying it.
      [func = std::move(func), future]() {
        try {
          func();
        } catch (...) {
          future->setError(std::current_exception());
          return;
        }
        future->markCompleted();
      }
    );
  } else {
    func();
    future->markCompleted();
  }
  return future;
#else
  // TODO: caffe2::PThreadPool only provides a data-parallel API.
  // Task parallelism is not currently supported.
  auto future = c10::make_intrusive<c10::ivalue::Future>(c10::dynT<NoneType>());
  func();
  future->markCompleted();
  return future;
#endif // C10_MOBILE
}
} // namespace at
#endif