Skip to content

Commit a124301

Browse files
committed
almost working version
1 parent 5e78da0 commit a124301

3 files changed

Lines changed: 89 additions & 8 deletions

File tree

stan/math/rev/core.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
#include <stan/math/rev/core/std_isnan.hpp>
6464
#include <stan/math/rev/core/std_numeric_limits.hpp>
6565
#include <stan/math/rev/core/stored_gradient_vari.hpp>
66+
#include <stan/math/rev/core/team_thread_pool.hpp>
6667
#include <stan/math/rev/core/typedefs.hpp>
6768
#include <stan/math/rev/core/var.hpp>
6869
#include <stan/math/rev/core/vari.hpp>

stan/math/rev/core/team_thread_pool.hpp

Lines changed: 80 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ namespace math {
2626
*/
2727
class TeamThreadPool {
2828
public:
29+
// Call this before first use of TeamThreadPool::instance()
30+
static void set_num_threads(std::size_t n) noexcept {
31+
if (n < 1) n = 1;
32+
user_cap_().store(n, std::memory_order_release);
33+
}
34+
35+
static std::size_t get_num_threads() noexcept {
36+
return user_cap_().load(std::memory_order_acquire);
37+
}
38+
2939
static TeamThreadPool& instance() {
3040
static TeamThreadPool pool;
3141
return pool;
@@ -42,14 +52,15 @@ class TeamThreadPool {
4252

4353
template <typename F>
4454
void parallel_region(std::size_t n, F&& fn) {
55+
//std::cout << "#################### parallel_region, n = " << n << std::endl;
4556
if (n == 0) return;
4657

4758
// If called from a worker, run serial to avoid nested deadlocks.
59+
//std::cout << "in_worker_ = " << in_worker_ << std::endl;
4860
if (in_worker_) {
4961
fn(std::size_t{0});
5062
return;
5163
}
52-
5364
const std::size_t max_team = team_size();
5465
if (max_team == 1) {
5566
fn(std::size_t{0});
@@ -84,30 +95,90 @@ class TeamThreadPool {
8495
fn_copy(0);
8596
in_worker_ = false;
8697

98+
//std::cout << "waiting for workers" << std::endl;
8799
// Wait for workers 1..n-1
88100
std::unique_lock<std::mutex> lk(done_m_);
89101
done_cv_.wait(lk, [&] {
90102
return remaining_.load(std::memory_order_acquire) == 0;
91103
});
104+
//std::cout << "#################### done" << std::endl << std::endl;
105+
}
106+
107+
private:
108+
// Function-local static avoids static init order fiasco.
109+
static std::atomic<std::size_t>& user_cap_() {
110+
static std::atomic<std::size_t> cap{0}; // 0 => "unset"
111+
return cap;
112+
}
113+
114+
static std::size_t configured_cap_(std::size_t hw) {
115+
// priority: user cap > env var > hw
116+
std::size_t cap = user_cap_().load(std::memory_order_acquire);
117+
if (cap == 0) {
118+
cap = env_num_threads_(); // if you have STAN_NUM_THREADS support
119+
}
120+
if (cap == 0) cap = hw;
121+
122+
if (cap < 1) cap = 1;
123+
if (cap > hw) cap = hw; // prevent oversubscription by default
124+
return cap;
92125
}
93126

94-
private:
127+
95128
using call_fn_t = void (*)(void*, std::size_t);
96129

97130
template <typename Fn>
98131
static void call_impl(void* ctx, std::size_t tid) {
99132
(*static_cast<Fn*>(ctx))(tid);
100133
}
101134

135+
static size_t env_num_threads_() {
136+
size_t num_threads = 1;
137+
#ifdef STAN_THREADS
138+
const char* env_stan_num_threads = std::getenv("STAN_NUM_THREADS");
139+
if (env_stan_num_threads != nullptr) {
140+
try {
141+
const int env_num_threads
142+
= boost::lexical_cast<int>(env_stan_num_threads);
143+
if (env_num_threads > 0) {
144+
num_threads = env_num_threads;
145+
} else if (env_num_threads == -1) {
146+
num_threads = std::thread::hardware_concurrency();
147+
} else {
148+
invalid_argument("get_num_threads(int)", "STAN_NUM_THREADS",
149+
env_stan_num_threads,
150+
"The STAN_NUM_THREADS environment variable is '",
151+
"' but it must be positive or -1");
152+
}
153+
} catch (const boost::bad_lexical_cast&) {
154+
invalid_argument("get_num_threads(int)", "STAN_NUM_THREADS",
155+
env_stan_num_threads,
156+
"The STAN_NUM_THREADS environment variable is '",
157+
"' but it must be a positive number or -1");
158+
}
159+
}
160+
#endif
161+
return num_threads;
162+
}
163+
164+
102165
TeamThreadPool()
103166
: stop_(false), epoch_(0), region_n_(0), region_ctx_(nullptr),
104167
region_call_(nullptr), remaining_(0) {
105-
unsigned hw = std::thread::hardware_concurrency();
106-
if (hw == 0) hw = 2;
107-
108-
// hw-1 worker threads; caller is +1 participant.
109-
const unsigned num_workers = (hw > 1) ? (hw - 1) : 1;
110168

169+
unsigned hw_u = std::thread::hardware_concurrency();
170+
if (hw_u == 0) hw_u = 2;
171+
const std::size_t hw = static_cast<std::size_t>(hw_u);
172+
173+
const std::size_t cap = configured_cap_(hw);
174+
const std::size_t num_workers = (cap > 1) ? (cap - 1) : 0;
175+
176+
std::cout << "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" << std::endl
177+
<< "hw = " << hw << std::endl
178+
<< "num_workers = " << num_workers << std::endl
179+
<< "cap = " << cap << std::endl
180+
<< std::endl << std::endl;
181+
111182
workers_.reserve(num_workers);
112183
for (unsigned i = 0; i < num_workers; ++i) {
113184
const std::size_t tid = static_cast<std::size_t>(i + 1); // workers are 1..N
@@ -151,7 +222,9 @@ class TeamThreadPool {
151222
in_worker_ = false;
152223
});
153224
}
225+
std::cout << "done with constructor" << std::endl;
154226
}
227+
155228

156229
~TeamThreadPool() {
157230
stop_.store(true, std::memory_order_release);

stan/math/rev/functor/reduce_sum.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,14 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType, Ve
216216
num_vars_per_term, num_vars_shared_terms, partials,
217217
vmapped, msgs, args...));
218218
}
219-
219+
/*
220+
std::cout << "--------------------------------------------------------------------------------" << std::endl
221+
<< "worker count = " << pool.worker_count() << std::endl
222+
<< "team size = " << pool.team_size() << std::endl
223+
<< "gs = " << gs << std::endl
224+
<< std::endl << std::endl;
225+
*/
226+
220227
// Static partition: each participant gets a contiguous block once
221228
pool.parallel_region(n, [&](std::size_t tid) {
222229
const std::size_t b0 = (num_terms * tid) / n;

0 commit comments

Comments
 (0)