Skip to content

Commit 0258923

Browse files
committed
updates to team_thread_pool.hpp
1 parent 1526101 commit 0258923

1 file changed

Lines changed: 38 additions & 174 deletions

File tree

stan/math/rev/core/team_thread_pool.hpp

Lines changed: 38 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include <atomic>
77
#include <condition_variable>
88
#include <cstddef>
9-
#include <cstdio> // fprintf, fflush
109
#include <cstdlib> // getenv, strtol
1110
#include <exception> // exception_ptr
1211
#include <mutex>
@@ -15,34 +14,21 @@
1514
#include <utility>
1615
#include <vector>
1716

18-
// Debug mode: periodic dumps while waiting.
19-
// #define STAN_TEAM_POOL_DEBUG_WAIT 1
20-
21-
#if defined(STAN_TEAM_POOL_DEBUG_WAIT)
22-
#include <chrono>
23-
#endif
24-
2517
namespace stan {
2618
namespace math {
2719

2820
/**
29-
* Team (epoch) thread pool for low-overhead parallel regions.
21+
* TeamThreadPool
3022
*
31-
* Logical tids:
32-
* - caller thread: tid=0
33-
* - worker threads: tid=1..cap-1
23+
* - Fixed set of worker threads created once.
24+
* - Caller participates as logical tid=0.
25+
* - Worker threads have stable logical tids 1..(cap-1).
26+
* - parallel_region(n, fn): runs fn(tid) for tid in [0, n), exactly once each.
3427
*
35-
* Key correctness points:
36-
* - Single shared "region" state => serialize parallel_region() with region_m_.
37-
* - Wake generation (wake_gen_) protected by wake_m_ to prevent missed wakeups.
38-
* - Startup barrier ensures all workers are "armed" on wake_gen_ before first use,
39-
* preventing late-start workers from missing the first region.
40-
*
41-
* Debug fields per logical tid:
42-
* - alive[tid] : 1 once the worker thread started (0 means never started / died early)
43-
* - seen[tid] : last epoch observed by that worker (0 means never saw any region)
44-
* - exec[tid] : epoch currently executing (0 means not executing region work)
45-
* - dec[tid] : count of decrements performed by that worker
28+
* Notes:
29+
* - Nested parallel_region calls from a worker run serially to avoid deadlock.
30+
* - Uses an epoch counter + condition_variable to wake workers per region.
31+
* - Startup barrier ensures all workers are waiting before the first region launch.
4632
*/
4733
class TeamThreadPool {
4834
public:
@@ -64,59 +50,21 @@ class TeamThreadPool {
6450
std::size_t worker_count() const noexcept { return workers_.size(); }
6551
std::size_t team_size() const noexcept { return workers_.size() + 1; }
6652

67-
void dump_state(const char* tag = "TeamThreadPool") const {
68-
const auto epoch = epoch_.load(std::memory_order_acquire);
69-
const auto n = region_n_.load(std::memory_order_acquire);
70-
const auto rem = remaining_.load(std::memory_order_acquire);
71-
72-
std::fprintf(stderr,
73-
"\n[%s] epoch=%zu region_n=%zu remaining=%zu team_size=%zu wake_gen=%zu ready=%zu/%zu\n",
74-
tag, epoch, n, rem, team_size(), wake_gen_snapshot_(),
75-
ready_count_.load(std::memory_order_acquire),
76-
workers_.size());
77-
std::fprintf(stderr, " tid | alive | seen | exec | dec | note\n");
78-
std::fprintf(stderr, " ----+-------+-------+-------+-----+-----------------------------\n");
79-
80-
for (std::size_t tid = 1; tid < worker_state_size_; ++tid) {
81-
const unsigned alive = worker_alive_[tid].load(std::memory_order_acquire);
82-
const std::size_t seen = worker_seen_epoch_[tid].load(std::memory_order_acquire);
83-
const std::size_t exec = worker_exec_epoch_[tid].load(std::memory_order_acquire);
84-
const std::size_t dec = worker_decrement_count_[tid].load(std::memory_order_acquire);
85-
86-
const char* note = "";
87-
if (!alive) {
88-
note = "NOT ALIVE";
89-
} else if (tid < n) {
90-
if (seen < epoch) note = "participating but hasn't seen epoch";
91-
else if (exec == epoch) note = "executing";
92-
else if (exec != 0) note = "executing (old epoch?)";
93-
else note = "idle/finished";
94-
} else {
95-
if (seen == epoch) note = "saw epoch but not participating";
96-
else note = "not participating";
97-
}
98-
99-
std::fprintf(stderr, " %3zu | %u | %5zu | %5zu | %3zu | %s\n",
100-
tid, alive, seen, exec, dec, note);
101-
}
102-
std::fflush(stderr);
103-
}
104-
10553
template <typename F>
10654
void parallel_region(std::size_t n, F&& fn) {
10755
if (n == 0) return;
10856

109-
// Nested parallelism guard.
57+
// Prevent nested parallelism from deadlocking the pool.
11058
if (in_worker_) {
11159
fn(std::size_t{0});
11260
return;
11361
}
11462

115-
// Single shared region state => serialize launches.
63+
// Only one active region at a time (shared region state).
11664
std::unique_lock<std::mutex> region_lock(region_m_);
11765

11866
const std::size_t max_team = team_size();
119-
if (max_team == 1) {
67+
if (max_team <= 1) {
12068
fn(std::size_t{0});
12169
return;
12270
}
@@ -137,27 +85,23 @@ class TeamThreadPool {
13785
}
13886

13987
// Publish region state BEFORE bumping epoch.
140-
remaining_.store(n - 1, std::memory_order_release);
88+
remaining_.store(n - 1, std::memory_order_release); // workers only
14189
region_n_.store(n, std::memory_order_release);
14290
region_ctx_.store(static_cast<void*>(&fn_copy), std::memory_order_release);
14391
region_call_.store(&call_impl<Fn>, std::memory_order_release);
14492

93+
// Bump epoch to start the region, then wake workers.
14594
const std::size_t new_epoch =
14695
epoch_.fetch_add(1, std::memory_order_acq_rel) + 1;
14796

148-
// std::fprintf(stderr,
149-
// "\n[TeamThreadPool(launch)] epoch=%zu n=%zu expected_workers=%zu team_size=%zu\n",
150-
// new_epoch, n, n - 1, team_size());
151-
// std::fflush(stderr);
152-
153-
// Wake workers using wake generation (prevents missed wakeups).
15497
{
15598
std::lock_guard<std::mutex> lk(wake_m_);
156-
++wake_gen_;
99+
// epoch_ already updated; the mutex pairs with the cv wait.
100+
(void)new_epoch;
157101
}
158102
wake_cv_.notify_all();
159103

160-
// Caller participates (tid=0).
104+
// Caller participates as tid=0.
161105
in_worker_ = true;
162106
try {
163107
fn_copy(0);
@@ -169,25 +113,9 @@ class TeamThreadPool {
169113

170114
// Wait for workers 1..n-1.
171115
std::unique_lock<std::mutex> lk(done_m_);
172-
#if defined(STAN_TEAM_POOL_DEBUG_WAIT)
173-
auto last_dump = std::chrono::steady_clock::now();
174-
while (remaining_.load(std::memory_order_acquire) != 0) {
175-
done_cv_.wait_for(lk, std::chrono::milliseconds(250));
176-
const auto now = std::chrono::steady_clock::now();
177-
if (now - last_dump > std::chrono::seconds(2)
178-
&& remaining_.load(std::memory_order_acquire) != 0) {
179-
std::fprintf(stderr,
180-
"[TeamThreadPool] waiting too long for epoch=%zu (remaining=%zu)\n",
181-
new_epoch, remaining_.load(std::memory_order_acquire));
182-
dump_state("TeamThreadPool(wait)");
183-
last_dump = now;
184-
}
185-
}
186-
#else
187116
done_cv_.wait(lk, [&] {
188117
return remaining_.load(std::memory_order_acquire) == 0;
189118
});
190-
#endif
191119

192120
// Hygiene.
193121
region_n_.store(0, std::memory_order_release);
@@ -204,7 +132,7 @@ class TeamThreadPool {
204132
}
205133

206134
static std::atomic<std::size_t>& user_cap_() {
207-
static std::atomic<std::size_t> cap{0};
135+
static std::atomic<std::size_t> cap{0}; // 0 => unset
208136
return cap;
209137
}
210138

@@ -226,11 +154,6 @@ class TeamThreadPool {
226154
return cap;
227155
}
228156

229-
std::size_t wake_gen_snapshot_() const {
230-
std::lock_guard<std::mutex> lk(wake_m_);
231-
return wake_gen_;
232-
}
233-
234157
TeamThreadPool()
235158
: stop_(false),
236159
epoch_(0),
@@ -239,7 +162,6 @@ class TeamThreadPool {
239162
region_call_(nullptr),
240163
remaining_(0),
241164
exc_ptr_(nullptr),
242-
wake_gen_(0),
243165
ready_count_(0) {
244166
unsigned hw_u = std::thread::hardware_concurrency();
245167
if (hw_u == 0) hw_u = 2;
@@ -248,94 +170,52 @@ class TeamThreadPool {
248170
const std::size_t cap = configured_cap_(hw);
249171
const std::size_t num_workers = (cap > 1) ? (cap - 1) : 0;
250172

251-
worker_state_size_ = cap;
252-
253-
// raw arrays so atomics aren't moved
254-
worker_alive_.reset(new std::atomic<unsigned>[cap]);
255-
worker_seen_epoch_.reset(new std::atomic<std::size_t>[cap]);
256-
worker_exec_epoch_.reset(new std::atomic<std::size_t>[cap]);
257-
worker_decrement_count_.reset(new std::atomic<std::size_t>[cap]);
258-
259-
for (std::size_t i = 0; i < cap; ++i) {
260-
worker_alive_[i].store(0u, std::memory_order_relaxed);
261-
worker_seen_epoch_[i].store(0, std::memory_order_relaxed);
262-
worker_exec_epoch_[i].store(0, std::memory_order_relaxed);
263-
worker_decrement_count_[i].store(0, std::memory_order_relaxed);
264-
}
265-
266-
std::fprintf(stderr,
267-
"[TeamThreadPool(ctor)] cap=%zu (workers=%zu) hw=%zu\n",
268-
cap, num_workers, hw);
269-
std::fflush(stderr);
270-
271173
workers_.reserve(num_workers);
272174
for (std::size_t i = 0; i < num_workers; ++i) {
273-
const std::size_t tid = i + 1;
175+
const std::size_t tid = i + 1; // workers are 1..num_workers
274176
workers_.emplace_back([this, tid] {
177+
// Per-worker AD tape initialized once.
275178
static thread_local ChainableStack ad_tape;
276-
in_worker_ = true;
277-
278-
if (tid < worker_state_size_) {
279-
worker_alive_[tid].store(1u, std::memory_order_release);
280-
}
179+
(void)ad_tape;
281180

282-
// "Arm" this worker on the current wake generation so it can't miss
283-
// the first region wake.
284-
std::size_t local_gen;
285-
{
286-
std::unique_lock<std::mutex> lk(wake_m_);
287-
local_gen = wake_gen_;
288-
}
181+
in_worker_ = true;
289182

290-
// Signal readiness AFTER arming.
291-
ready_count_.fetch_add(1, std::memory_order_acq_rel);
183+
// Startup barrier: ensure each worker has entered the wait loop once.
292184
{
293-
std::lock_guard<std::mutex> lk(ready_m_);
185+
std::lock_guard<std::mutex> lk(wake_m_);
186+
ready_count_.fetch_add(1, std::memory_order_acq_rel);
294187
}
295188
ready_cv_.notify_one();
296189

190+
std::size_t seen_epoch = epoch_.load(std::memory_order_acquire);
191+
297192
for (;;) {
298-
// Wait for wake_gen_ to change (or stop_).
193+
// Wait for a new epoch (or stop).
299194
{
300195
std::unique_lock<std::mutex> lk(wake_m_);
301196
wake_cv_.wait(lk, [&] {
302197
return stop_.load(std::memory_order_acquire)
303-
|| wake_gen_ != local_gen;
198+
|| epoch_.load(std::memory_order_acquire) != seen_epoch;
304199
});
305200
if (stop_.load(std::memory_order_acquire)) break;
306-
local_gen = wake_gen_;
307-
}
308-
309-
// Observe epoch and region parameters after wake.
310-
const std::size_t e = epoch_.load(std::memory_order_acquire);
311-
312-
if (tid < worker_state_size_) {
313-
worker_seen_epoch_[tid].store(e, std::memory_order_release);
201+
seen_epoch = epoch_.load(std::memory_order_acquire);
314202
}
315203

316204
const std::size_t n = region_n_.load(std::memory_order_acquire);
317-
if (tid >= n) {
318-
continue;
319-
}
205+
if (tid >= n) continue; // not participating this region
320206

321207
// Always decrement once for participating workers.
322208
struct DoneGuard {
323209
std::atomic<std::size_t>& rem;
324210
std::mutex& m;
325211
std::condition_variable& cv;
326-
std::atomic<std::size_t>& dec_count;
327212
~DoneGuard() {
328-
dec_count.fetch_add(1, std::memory_order_relaxed);
329213
if (rem.fetch_sub(1, std::memory_order_acq_rel) == 1) {
330214
std::lock_guard<std::mutex> lk(m);
331215
cv.notify_one();
332216
}
333217
}
334-
} guard{remaining_, done_m_, done_cv_, worker_decrement_count_[tid]};
335-
336-
if (tid < worker_state_size_) {
337-
worker_exec_epoch_[tid].store(e, std::memory_order_release);
338-
}
218+
} guard{remaining_, done_m_, done_cv_};
339219

340220
void* ctx = region_ctx_.load(std::memory_order_acquire);
341221
call_fn_t call = region_call_.load(std::memory_order_acquire);
@@ -350,34 +230,27 @@ class TeamThreadPool {
350230
}
351231
}
352232
}
353-
354-
if (tid < worker_state_size_) {
355-
worker_exec_epoch_[tid].store(0, std::memory_order_release);
356-
}
357233
}
358234

359235
in_worker_ = false;
360236
});
361237
}
362238

363-
// Startup barrier: ensure all workers are armed and waiting-ready.
239+
// Wait for all workers to reach the wait loop once before returning.
364240
{
365-
std::unique_lock<std::mutex> lk(ready_m_);
241+
std::unique_lock<std::mutex> lk(wake_m_);
366242
ready_cv_.wait(lk, [&] {
367243
return ready_count_.load(std::memory_order_acquire) == workers_.size();
368244
});
369245
}
370-
std::fprintf(stderr, "[TeamThreadPool(ctor)] all workers ready: %zu\n",
371-
workers_.size());
372-
std::fflush(stderr);
373246
}
374247

375248
~TeamThreadPool() {
376249
stop_.store(true, std::memory_order_release);
377250
{
378-
// bump wake_gen_ so workers wake and see stop_
379251
std::lock_guard<std::mutex> lk(wake_m_);
380-
++wake_gen_;
252+
// bump epoch to ensure wake predicate flips
253+
epoch_.fetch_add(1, std::memory_order_acq_rel);
381254
}
382255
wake_cv_.notify_all();
383256

@@ -400,13 +273,11 @@ class TeamThreadPool {
400273
std::atomic<void*> region_ctx_;
401274
std::atomic<call_fn_t> region_call_;
402275

403-
// Wake workers (wake_gen_ is protected by wake_m_).
404-
mutable std::mutex wake_m_;
276+
// Wake workers.
277+
std::mutex wake_m_;
405278
std::condition_variable wake_cv_;
406-
std::size_t wake_gen_;
407279

408280
// Startup barrier.
409-
std::mutex ready_m_;
410281
std::condition_variable ready_cv_;
411282
std::atomic<std::size_t> ready_count_;
412283

@@ -418,13 +289,6 @@ class TeamThreadPool {
418289
// Exceptions.
419290
std::mutex exc_m_;
420291
std::exception_ptr* exc_ptr_;
421-
422-
// Debug state (arrays of atomics to avoid std::vector<atomic<...>> move issues).
423-
std::size_t worker_state_size_{0};
424-
std::unique_ptr<std::atomic<unsigned>[]> worker_alive_;
425-
std::unique_ptr<std::atomic<std::size_t>[]> worker_seen_epoch_;
426-
std::unique_ptr<std::atomic<std::size_t>[]> worker_exec_epoch_;
427-
std::unique_ptr<std::atomic<std::size_t>[]> worker_decrement_count_;
428292
};
429293

430294
} // namespace math

0 commit comments

Comments
 (0)