forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ParallelThreadPoolNative.cpp
88 lines (74 loc) · 2.28 KB
/
ParallelThreadPoolNative.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#if AT_PARALLEL_OPENMP || AT_PARALLEL_NATIVE || AT_PARALLEL_NATIVE_TBB
#include <ATen/Parallel.h>
#include <ATen/PTThreadPool.h>
#include <ATen/ThreadLocalDebugInfo.h>
#include <atomic>
namespace at {
namespace {
// Sentinel states stored in num_interop_threads. Any user-supplied thread
// count is strictly positive, so negative values are free to act as markers.
constexpr int NOT_SET = -1;   // nothing requested yet and the pool is not built
constexpr int CONSUMED = -2;  // the pool has been built; its size is now fixed

// Inter-op thread count requested by the user. Allowed state transitions:
//   NOT_SET -> <positive count> -> CONSUMED   (explicit request, then pool creation)
//   NOT_SET -> CONSUMED                       (pool created without any request)
// Once CONSUMED, the thread pool exists and the value can no longer change.
std::atomic<int> num_interop_threads{NOT_SET};
// Returns the process-wide inter-op thread pool, building it on first use.
//
// The pool instance is deliberately not exposed; users go through at::launch
// and the get/set_num_interop_threads interface instead.
//
// Building the pool atomically swaps num_interop_threads to CONSUMED. Whatever
// it held before (a user-requested count, or the NOT_SET sentinel) is used as
// the requested pool size. NOTE(review): presumably the pool maps a negative
// size to its default thread count; confirm against PTThreadPool.
TaskThreadPoolBase& get_pool() {
  // Function-local static: the initializer runs exactly once, on the first call.
  static std::shared_ptr<TaskThreadPoolBase> shared_pool = [] {
    const int requested_size = num_interop_threads.exchange(CONSUMED);
    return ThreadPoolRegistry()->Create(
        "C10",
        /* device_id */ 0,
        /* pool_size */ requested_size,
        /* create_new */ true);
  }();
  return *shared_pool;
}
// Factory function for ThreadPoolRegistry: constructs the PTThreadPool that
// backs the inter-op pool requested by get_pool().
//
// device_id   must be 0; that is the only device id accepted for now.
// pool_size   forwarded unchanged to PTThreadPool. get_pool() may pass the
//             negative NOT_SET sentinel here when no count was requested.
//             NOTE(review): presumably PTThreadPool treats that as "use the
//             default size"; confirm.
// create_new  must be true; this factory always builds a fresh pool.
// Returns the newly constructed pool. Fails via TORCH_CHECK on invalid args.
std::shared_ptr<TaskThreadPoolBase> create_c10_threadpool(
    int device_id,
    int pool_size,
    bool create_new) {
  // For now, the only accepted device id is 0. Report the bad id so a
  // misconfigured registry lookup is diagnosable.
  TORCH_CHECK(device_id == 0,
      "C10 thread pool only supports device_id 0, got ", device_id);
  // This factory never reuses an existing pool.
  TORCH_CHECK(create_new,
      "C10 thread pool factory always creates a new pool; create_new must be true");
  return std::make_shared<PTThreadPool>(pool_size);
}
} // namespace
// Registers create_c10_threadpool with ThreadPoolRegistry under the name C10,
// which is the key get_pool() asks ThreadPoolRegistry()->Create() for.
// NOTE(review): exact registration mechanics live in the c10 Registry macros.
C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);
// Requests `nthreads` inter-op threads for the pool that get_pool() will build.
//
// Only the first successful call wins. The value is stored only while
// num_interop_threads still holds NOT_SET. A second call, or any call after
// the pool has been created (state CONSUMED), fails the check below.
// nthreads must be strictly positive.
void set_num_interop_threads(int nthreads) {
  TORCH_CHECK(nthreads > 0, "Expected positive number of threads");

  // compare_exchange_strong writes the observed value back into `expected`
  // on failure, so it must be a fresh local initialized to NOT_SET.
  int expected = NOT_SET;
  const bool stored =
      num_interop_threads.compare_exchange_strong(expected, nthreads);
  TORCH_CHECK(stored,
      "Error: cannot set number of interop threads after parallel work "
      "has started or set_num_interop_threads called");
}
// Reports the number of inter-op threads:
//   - TaskThreadPoolBase::defaultNumThreads() when nothing was requested and
//     the pool has not been built (this path does not build the pool);
//   - the explicitly requested count while it is still pending;
//   - otherwise the actual size of the already-built pool.
int get_num_interop_threads() {
  const int state = num_interop_threads.load();
  if (state == NOT_SET) {
    return TaskThreadPoolBase::defaultNumThreads();
  }
  if (state > 0) {
    return state;
  }
  // CONSUMED: the pool exists, so ask it for its real size.
  return get_pool().size();
}
// Submits `func` to the inter-op thread pool. When
// AT_EXPERIMENTAL_SINGLE_THREAD_POOL is enabled, it goes through
// intraop_launch instead.
//
// The launching thread's debug info (getThreadLocalDebugInfo()) is captured
// here, at submission time. It is installed with DebugInfoGuard around the
// call to `func`, so the task sees its launcher's debug info.
// NOTE(review): presumably the guard restores the worker's previous debug
// info on scope exit; confirm in ThreadLocalDebugInfo.h.
void launch(std::function<void()> func) {
  // std::bind hands its bound arguments to the callable as lvalues on each
  // invocation. Taking `f` by const reference avoids copying the user's
  // std::function (which may heap-allocate) every time the task runs.
  auto fn = std::bind([](
      const std::function<void()>& f,
      std::shared_ptr<ThreadLocalDebugInfoBase> info) {
        DebugInfoGuard guard(std::move(info));
        f();
      },
      std::move(func),
      getThreadLocalDebugInfo()
  );
  // `fn` is not used again, so transfer it instead of copying it.
#if AT_EXPERIMENTAL_SINGLE_THREAD_POOL
  intraop_launch(std::move(fn));
#else
  get_pool().run(std::move(fn));
#endif
}
} // namespace at
#endif