Skip to content

Commit 99c269b

Browse files
committed
better implementation, still hangs
1 parent a124301 commit 99c269b

1 file changed

Lines changed: 101 additions & 91 deletions

File tree

stan/math/rev/core/team_thread_pool.hpp

Lines changed: 101 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
#include <atomic>
77
#include <condition_variable>
88
#include <cstddef>
9+
#include <cstdlib>
10+
#include <exception>
911
#include <mutex>
1012
#include <thread>
13+
#include <type_traits>
1114
#include <utility>
1215
#include <vector>
1316

@@ -17,16 +20,16 @@ namespace math {
1720
/**
1821
* Team (epoch) thread pool for low-overhead parallel regions.
1922
*
20-
* - Creates (hw-1) worker threads once.
21-
* - Caller participates with tid=0.
23+
* - Workers are created once.
24+
* - Caller participates as tid=0.
2225
* - parallel_region(n, fn): runs fn(tid) for tid in [0, n).
2326
* - Nested parallelism: if called from a worker thread, runs serial.
24-
*
25-
* Designed for reduce_sum/map_rect style internal parallelism.
27+
* - set_num_threads(k) must be called before instance() to size the pool.
2628
*/
2729
class TeamThreadPool {
2830
public:
29-
// Call this before first use of TeamThreadPool::instance()
31+
// Call before first instance() to control pool size.
32+
// Meaning: total participants INCLUDING caller (tid=0).
3033
static void set_num_threads(std::size_t n) noexcept {
3134
if (n < 1) n = 1;
3235
user_cap_().store(n, std::memory_order_release);
@@ -41,26 +44,25 @@ class TeamThreadPool {
4144
return pool;
4245
}
4346

44-
TeamThreadPool(const TeamThreadPool&) = delete;
45-
TeamThreadPool& operator=(const TeamThreadPool&) = delete;
46-
47-
// Number of worker threads (excluding caller)
47+
// Worker threads (excluding caller)
4848
std::size_t worker_count() const noexcept { return workers_.size(); }
4949

50-
// Total participants available = worker_count + 1 (caller)
50+
// Total possible participants INCLUDING caller
5151
std::size_t team_size() const noexcept { return workers_.size() + 1; }
5252

5353
template <typename F>
5454
void parallel_region(std::size_t n, F&& fn) {
55-
//std::cout << "#################### parallel_region, n = " << n << std::endl;
5655
if (n == 0) return;
5756

58-
// If called from a worker, run serial to avoid nested deadlocks.
59-
//std::cout << "in_worker_ = " << in_worker_ << std::endl;
57+
// Nested parallelism guard: if already on a worker, run serial.
6058
if (in_worker_) {
6159
fn(std::size_t{0});
6260
return;
6361
}
62+
63+
// Only one active region at a time (required by the single shared-epoch design).
64+
std::unique_lock<std::mutex> region_lock(region_m_);
65+
6466
const std::size_t max_team = team_size();
6567
if (max_team == 1) {
6668
fn(std::size_t{0});
@@ -72,11 +74,17 @@ class TeamThreadPool {
7274
return;
7375
}
7476

75-
// Stable storage for callable during this region
7677
using Fn = std::decay_t<F>;
77-
Fn fn_copy = std::forward<F>(fn);
78+
Fn fn_copy = std::forward<F>(fn); // stable storage during this call
79+
80+
// Exception propagation: capture first exception from any participant.
81+
std::exception_ptr eptr = nullptr;
82+
{
83+
std::lock_guard<std::mutex> lk(exc_m_);
84+
exc_ptr_ = &eptr;
85+
}
7886

79-
// Publish region
87+
// Publish region state BEFORE bumping epoch.
8088
remaining_.store(n - 1, std::memory_order_release); // workers only
8189
region_n_.store(n, std::memory_order_release);
8290
region_ctx_.store(static_cast<void*>(&fn_copy), std::memory_order_release);
@@ -92,104 +100,86 @@ class TeamThreadPool {
92100

93101
// Caller participates as tid=0
94102
in_worker_ = true;
95-
fn_copy(0);
103+
try {
104+
fn_copy(0);
105+
} catch (...) {
106+
std::lock_guard<std::mutex> lk(exc_m_);
107+
if (eptr == nullptr) eptr = std::current_exception();
108+
}
96109
in_worker_ = false;
97110

98-
//std::cout << "waiting for workers" << std::endl;
99111
// Wait for workers 1..n-1
100112
std::unique_lock<std::mutex> lk(done_m_);
101113
done_cv_.wait(lk, [&] {
102114
return remaining_.load(std::memory_order_acquire) == 0;
103115
});
104-
//std::cout << "#################### done" << std::endl << std::endl;
105-
}
106-
107-
private:
108-
// Function-local static avoids static init order fiasco.
109-
static std::atomic<std::size_t>& user_cap_() {
110-
static std::atomic<std::size_t> cap{0}; // 0 => "unset"
111-
return cap;
112-
}
113116

114-
static std::size_t configured_cap_(std::size_t hw) {
115-
// priority: user cap > env var > hw
116-
std::size_t cap = user_cap_().load(std::memory_order_acquire);
117-
if (cap == 0) {
118-
cap = env_num_threads_(); // if you have STAN_NUM_THREADS support
119-
}
120-
if (cap == 0) cap = hw;
117+
// Clear region participation; not strictly necessary but good hygiene.
118+
region_n_.store(0, std::memory_order_release);
121119

122-
if (cap < 1) cap = 1;
123-
if (cap > hw) cap = hw; // prevent oversubscription by default
124-
return cap;
120+
// Rethrow exception (if any)
121+
if (eptr) std::rethrow_exception(eptr);
125122
}
126123

127-
124+
private:
128125
using call_fn_t = void (*)(void*, std::size_t);
129126

130127
template <typename Fn>
131128
static void call_impl(void* ctx, std::size_t tid) {
132129
(*static_cast<Fn*>(ctx))(tid);
133130
}
134131

135-
static size_t env_num_threads_() {
136-
size_t num_threads = 1;
137-
#ifdef STAN_THREADS
138-
const char* env_stan_num_threads = std::getenv("STAN_NUM_THREADS");
139-
if (env_stan_num_threads != nullptr) {
140-
try {
141-
const int env_num_threads
142-
= boost::lexical_cast<int>(env_stan_num_threads);
143-
if (env_num_threads > 0) {
144-
num_threads = env_num_threads;
145-
} else if (env_num_threads == -1) {
146-
num_threads = std::thread::hardware_concurrency();
147-
} else {
148-
invalid_argument("get_num_threads(int)", "STAN_NUM_THREADS",
149-
env_stan_num_threads,
150-
"The STAN_NUM_THREADS environment variable is '",
151-
"' but it must be positive or -1");
152-
}
153-
} catch (const boost::bad_lexical_cast&) {
154-
invalid_argument("get_num_threads(int)", "STAN_NUM_THREADS",
155-
env_stan_num_threads,
156-
"The STAN_NUM_THREADS environment variable is '",
157-
"' but it must be a positive number or -1");
158-
}
159-
}
160-
#endif
161-
return num_threads;
162-
}
132+
// Function-local static avoids static initialization order issues.
133+
static std::atomic<std::size_t>& user_cap_() {
134+
static std::atomic<std::size_t> cap{0}; // 0 => unset
135+
return cap;
136+
}
163137

138+
static std::size_t env_num_threads_() noexcept {
139+
const char* s = std::getenv("STAN_NUM_THREADS");
140+
if (!s || !*s) return 0;
141+
char* end = nullptr;
142+
long v = std::strtol(s, &end, 10);
143+
if (end == s || v <= 0) return 0;
144+
return static_cast<std::size_t>(v);
145+
}
164146

165-
TeamThreadPool()
166-
: stop_(false), epoch_(0), region_n_(0), region_ctx_(nullptr),
167-
region_call_(nullptr), remaining_(0) {
147+
static std::size_t configured_cap_(std::size_t hw) noexcept {
148+
// Priority: explicit set_num_threads > STAN_NUM_THREADS > hw
149+
std::size_t cap = user_cap_().load(std::memory_order_acquire);
150+
if (cap == 0) cap = env_num_threads_();
151+
if (cap == 0) cap = hw;
168152

153+
if (cap < 1) cap = 1;
154+
if (cap > hw) cap = hw; // don't oversubscribe by default
155+
return cap;
156+
}
157+
158+
TeamThreadPool()
159+
: stop_(false),
160+
epoch_(0),
161+
region_n_(0),
162+
region_ctx_(nullptr),
163+
region_call_(nullptr),
164+
remaining_(0),
165+
exc_ptr_(nullptr) {
169166
unsigned hw_u = std::thread::hardware_concurrency();
170167
if (hw_u == 0) hw_u = 2;
171168
const std::size_t hw = static_cast<std::size_t>(hw_u);
172-
169+
170+
// Total participants includes caller.
173171
const std::size_t cap = configured_cap_(hw);
174172
const std::size_t num_workers = (cap > 1) ? (cap - 1) : 0;
175173

176-
std::cout << "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" << std::endl
177-
<< "hw = " << hw << std::endl
178-
<< "num_workers = " << num_workers << std::endl
179-
<< "cap = " << cap << std::endl
180-
<< std::endl << std::endl;
181-
182174
workers_.reserve(num_workers);
183-
for (unsigned i = 0; i < num_workers; ++i) {
184-
const std::size_t tid = static_cast<std::size_t>(i + 1); // workers are 1..N
175+
for (std::size_t i = 0; i < num_workers; ++i) {
176+
const std::size_t tid = i + 1; // workers are 1..num_workers
185177
workers_.emplace_back([this, tid] {
186-
// Per-worker AD tape initialized once
187178
static thread_local ChainableStack ad_tape;
188179
in_worker_ = true;
189180

190181
std::size_t seen = epoch_.load(std::memory_order_acquire);
191182
for (;;) {
192-
// Sleep until epoch changes or stop requested
193183
{
194184
std::unique_lock<std::mutex> lk(wake_m_);
195185
wake_cv_.wait(lk, [&] {
@@ -203,28 +193,41 @@ class TeamThreadPool {
203193
seen = e;
204194

205195
const std::size_t n = region_n_.load(std::memory_order_acquire);
206-
if (tid >= n) {
207-
continue; // not participating this region
208-
}
196+
if (tid >= n) continue; // not participating this region
197+
198+
// Ensure we ALWAYS decrement remaining_ once for participating workers.
199+
struct DoneGuard {
200+
std::atomic<std::size_t>& rem;
201+
std::mutex& m;
202+
std::condition_variable& cv;
203+
bool active{true};
204+
~DoneGuard() {
205+
if (!active) return;
206+
if (rem.fetch_sub(1, std::memory_order_acq_rel) == 1) {
207+
std::lock_guard<std::mutex> lk(m);
208+
cv.notify_one();
209+
}
210+
}
211+
} guard{remaining_, done_m_, done_cv_};
209212

210213
void* ctx = region_ctx_.load(std::memory_order_acquire);
211214
call_fn_t call = region_call_.load(std::memory_order_acquire);
212-
if (call) {
213-
call(ctx, tid);
214-
}
215+
if (!call) continue;
215216

216-
if (remaining_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
217-
std::lock_guard<std::mutex> lk(done_m_);
218-
done_cv_.notify_one();
217+
try {
218+
call(ctx, tid);
219+
} catch (...) {
220+
std::lock_guard<std::mutex> lk(exc_m_);
221+
if (exc_ptr_ && *exc_ptr_ == nullptr) {
222+
*exc_ptr_ = std::current_exception();
223+
}
219224
}
220225
}
221226

222227
in_worker_ = false;
223228
});
224229
}
225-
std::cout << "done with constructor" << std::endl;
226230
}
227-
228231

229232
~TeamThreadPool() {
230233
stop_.store(true, std::memory_order_release);
@@ -242,6 +245,9 @@ class TeamThreadPool {
242245
std::vector<std::thread> workers_;
243246
std::atomic<bool> stop_;
244247

248+
// Serialize regions (single shared-region design)
249+
std::mutex region_m_;
250+
245251
// Region publication
246252
std::atomic<std::size_t> epoch_;
247253
std::atomic<std::size_t> region_n_;
@@ -256,6 +262,10 @@ class TeamThreadPool {
256262
std::atomic<std::size_t> remaining_;
257263
std::mutex done_m_;
258264
std::condition_variable done_cv_;
265+
266+
// Exception plumbing
267+
std::mutex exc_m_;
268+
std::exception_ptr* exc_ptr_;
259269
};
260270

261271
} // namespace math

0 commit comments

Comments
 (0)