// MountainRangeThreaded.hpp
#pragma once
#include <cstdlib>   // std::getenv
#include <cstring>
#include <charconv>
#include <string>
#include <vector>
#include <array>
#include <ranges>
#include <thread>
#include <semaphore>
#include <atomic>
#include <barrier>
#include "MountainRange.hpp"

namespace {
    // Create a vector of threads, each of which repeatedly calls F(tid) in a loop until F(tid) returns false
    auto looping_threadpool(auto thread_count, auto F) {
        std::vector<std::jthread> threads;
        threads.reserve(thread_count);
        for (size_t tid=0; tid<thread_count; tid++) {
            threads.emplace_back([F, tid]{
                while (F(tid));
            });
        }
        return threads;
    }
}
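
// Usage sketch for looping_threadpool (hypothetical, for illustration only):
// distribute 100 units of work across 4 threads; each callback returning false
// ends that thread's loop, and std::jthread joins automatically on destruction.
//
//     std::atomic<int> remaining{100};
//     auto workers = looping_threadpool(4, [&](auto /*tid*/){
//         return remaining.fetch_sub(1) > 0;  // each true return claims one unit of work
//     });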

class MountainRangeThreaded: public MountainRange {
    // Threading-related members
    bool continue_iteration;                // used to tell the looping threadpools to terminate at the end of the simulation
    const size_type nthreads;
    std::barrier<> ds_barrier, step_barrier;
    std::vector<std::jthread> ds_workers, step_workers;
    std::atomic<value_type> ds_aggregator;  // used to reduce dsteepness contributions from each thread
    value_type iter_dt;                     // used to distribute dt to each thread

    // Determine which cells a given thread is in charge of
    auto this_thread_cell_range(auto tid) {
        return mr::split_range(cells, tid, nthreads);
    }
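    // For illustration (assumed semantics of mr::split_range, which is defined
    // elsewhere): with cells == 10 and nthreads == 4, threads would receive
    // contiguous, roughly even [first, last) ranges such as {0,3}, {3,6}, {6,8}, {8,10}.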

public:
    // Help message showing that SOLVER_NUM_THREADS controls the thread count
    inline static const std::string help_message =
            "Set the environment variable SOLVER_NUM_THREADS to a positive integer to set thread count (default 1).";

    // Run the base constructor, then build the threading infrastructure
    MountainRangeThreaded(auto &&...args): MountainRange(args...), // https://tinyurl.com/byusc-parpack
            continue_iteration{true},
            nthreads{[]{ // https://tinyurl.com/byusc-lambdai
                size_type nthreads = 1;
                auto nthreads_str = std::getenv("SOLVER_NUM_THREADS");
                if (nthreads_str != nullptr) std::from_chars(nthreads_str, nthreads_str+std::strlen(nthreads_str), nthreads);
                return nthreads;
            }()},
            ds_barrier(nthreads+1),   // worker threads plus main thread
            step_barrier(nthreads+1), // worker threads plus main thread
            ds_workers(looping_threadpool(nthreads, [this](auto tid){ // https://tinyurl.com/byusc-lambda
                ds_barrier.arrive_and_wait();  // wait for the main thread to request an iteration
                if (!continue_iteration) return false;
                auto [first, last] = this_thread_cell_range(tid);
                // Skip the fixed boundary cells at either end of the whole range
                auto gfirst = tid==0 ? 1 : first;
                auto glast = tid==nthreads-1 ? last-1 : last;
                // Accumulate locally, then add to the shared aggregator once to limit atomic traffic
                value_type ds_local = 0;
                for (size_t i=gfirst; i<glast; i++) ds_local += ds_cell(i);
                ds_aggregator += ds_local;
                ds_barrier.arrive_and_wait();  // signal the main thread that this iteration is done
                return true;
            })),
            step_workers(looping_threadpool(nthreads, [this](auto tid){ // https://tinyurl.com/byusc-lambda
                step_barrier.arrive_and_wait();  // wait for the main thread to request an iteration
                if (!continue_iteration) return false;
                auto [first, last] = this_thread_cell_range(tid);
                auto gfirst = tid==0 ? 1 : first;              // skip the fixed boundary cells
                auto glast = tid==nthreads-1 ? last-1 : last;  // for the g update
                for (size_t i=first; i<last; i++) update_h_cell(i, iter_dt);
                step_barrier.arrive_and_wait();  // h has to be completely updated before the g update can start
                for (size_t i=gfirst; i<glast; i++) update_g_cell(i);
                step_barrier.arrive_and_wait();  // signal the main thread that this iteration is done
                return true;
            })) {
        // Initialize g
        step(0);
    }
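
    // Note on the handshake: each arrive_and_wait() in dsteepness() and step()
    // below pairs with one in the corresponding worker loop above, so the main
    // thread both releases the workers and waits for their results at the same
    // barriers; the barriers also make plain writes like iter_dt and
    // continue_iteration visible to the workers on wakeup.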

    // Destructor just tells the threads to exit
    ~MountainRangeThreaded() {
        continue_iteration = false;      // set before the barriers so workers see it on wakeup
        ds_barrier.arrive_and_wait();    // signal ds_workers to exit
        step_barrier.arrive_and_wait();  // signal step_workers to exit
    }

    // Steepness derivative
    value_type dsteepness() override {
        // Reset the reduction destination
        ds_aggregator = 0;
        // Launch the workers
        ds_barrier.arrive_and_wait();
        // Wait for the workers to finish this iteration
        ds_barrier.arrive_and_wait();
        return ds_aggregator;
    }

    // Iterate from t to t+dt in one step
    value_type step(value_type dt) override {
        // Let the threads know what the time step for this iteration is
        iter_dt = dt;
        // Signal the workers to update h
        step_barrier.arrive_and_wait();
        // Signal the workers to update g
        step_barrier.arrive_and_wait();
        // Wait for the workers to finish this iteration
        step_barrier.arrive_and_wait();
        // Enforce the boundary condition
        g[0] = g[1];
        g[cells-1] = g[cells-2];
        // Advance t and return it
        t += dt;
        return t;
    }
};
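
// Minimal driver sketch (hypothetical; assumes the MountainRange base class can
// be constructed from a filename and exposes solve() and write() routines that
// use the step() and dsteepness() overrides above; none of that is shown in
// this file):
//
//     #include "MountainRangeThreaded.hpp"
//     int main() {
//         MountainRangeThreaded m("initial.wo");  // hypothetical input file
//         m.solve();                              // iterate until steepness levels off
//         m.write("final.wo");                    // hypothetical output routine
//         return 0;
//     }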