 #include <atomic>
 #include <condition_variable>
 #include <cstddef>
-#include <cstdio>   // fprintf, fflush
 #include <cstdlib>  // getenv, strtol
 #include <exception>  // exception_ptr
 #include <mutex>
 #include <utility>
 #include <vector>
 
-// Debug mode: periodic dumps while waiting.
-// #define STAN_TEAM_POOL_DEBUG_WAIT 1
-
-#if defined(STAN_TEAM_POOL_DEBUG_WAIT)
-#include <chrono>
-#endif
-
 namespace stan {
 namespace math {
 
 /**
- * Team (epoch) thread pool for low-overhead parallel regions.
+ * TeamThreadPool
  *
- * Logical tids:
- * - caller thread: tid=0
- * - worker threads: tid=1..cap-1
+ * - Fixed set of worker threads created once.
+ * - Caller participates as logical tid=0.
+ * - Worker threads have stable logical tids 1..(cap-1).
+ * - parallel_region(n, fn): runs fn(tid) for tid in [0, n), exactly once each.
  *
- * Key correctness points:
- * - Single shared "region" state => serialize parallel_region() with region_m_.
- * - Wake generation (wake_gen_) protected by wake_m_ to prevent missed wakeups.
- * - Startup barrier ensures all workers are "armed" on wake_gen_ before first use,
- *   preventing late-start workers from missing the first region.
- *
- * Debug fields per logical tid:
- * - alive[tid] : 1 once the worker thread started (0 means never started / died early)
- * - seen[tid]  : last epoch observed by that worker (0 means never saw any region)
- * - exec[tid]  : epoch currently executing (0 means not executing region work)
- * - dec[tid]   : count of decrements performed by that worker
+ * Notes:
+ * - Nested parallel_region calls from a worker run serially to avoid deadlock.
+ * - Uses an epoch counter + condition_variable to wake workers per region.
+ * - Startup barrier ensures all workers are waiting before the first region launch.
  */
 class TeamThreadPool {
  public:
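A minimal usage sketch of the parallel_region API described above (not part of this commit; the include path, team_pool() accessor, and heavy_work() task are hypothetical stand-ins, since the diff does not show how callers obtain the pool):

    #include <cstddef>
    #include <cstdio>
    #include <vector>
    #include "team_thread_pool.hpp"  // hypothetical path to this header

    // Hypothetical accessor; the commit does not show how the singleton is exposed.
    stan::math::TeamThreadPool& team_pool();

    // Hypothetical per-thread task.
    double heavy_work(std::size_t tid) { return 2.0 * static_cast<double>(tid); }

    void example() {
      stan::math::TeamThreadPool& pool = team_pool();
      const std::size_t n = pool.team_size();  // caller (tid 0) + workers
      std::vector<double> partial(n, 0.0);
      // fn(tid) runs exactly once for each tid in [0, n); the caller runs tid 0.
      pool.parallel_region(n, [&](std::size_t tid) { partial[tid] = heavy_work(tid); });
      double total = 0.0;
      for (double v : partial) total += v;
      std::printf("total = %f\n", total);
    }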
@@ -64,59 +50,21 @@ class TeamThreadPool {
   std::size_t worker_count() const noexcept { return workers_.size(); }
   std::size_t team_size() const noexcept { return workers_.size() + 1; }
 
-  void dump_state(const char* tag = "TeamThreadPool") const {
-    const auto epoch = epoch_.load(std::memory_order_acquire);
-    const auto n = region_n_.load(std::memory_order_acquire);
-    const auto rem = remaining_.load(std::memory_order_acquire);
-
-    std::fprintf(stderr,
-                 "\n[%s] epoch=%zu region_n=%zu remaining=%zu team_size=%zu wake_gen=%zu ready=%zu/%zu\n",
-                 tag, epoch, n, rem, team_size(), wake_gen_snapshot_(),
-                 ready_count_.load(std::memory_order_acquire),
-                 workers_.size());
-    std::fprintf(stderr, " tid | alive | seen | exec | dec | note\n");
-    std::fprintf(stderr, " ----+-------+-------+-------+-----+-----------------------------\n");
-
-    for (std::size_t tid = 1; tid < worker_state_size_; ++tid) {
-      const unsigned alive = worker_alive_[tid].load(std::memory_order_acquire);
-      const std::size_t seen = worker_seen_epoch_[tid].load(std::memory_order_acquire);
-      const std::size_t exec = worker_exec_epoch_[tid].load(std::memory_order_acquire);
-      const std::size_t dec = worker_decrement_count_[tid].load(std::memory_order_acquire);
-
-      const char* note = "";
-      if (!alive) {
-        note = "NOT ALIVE";
-      } else if (tid < n) {
-        if (seen < epoch) note = "participating but hasn't seen epoch";
-        else if (exec == epoch) note = "executing";
-        else if (exec != 0) note = "executing (old epoch?)";
-        else note = "idle/finished";
-      } else {
-        if (seen == epoch) note = "saw epoch but not participating";
-        else note = "not participating";
-      }
-
-      std::fprintf(stderr, " %3zu | %u | %5zu | %5zu | %3zu | %s\n",
-                   tid, alive, seen, exec, dec, note);
-    }
-    std::fflush(stderr);
-  }
-
   template <typename F>
   void parallel_region(std::size_t n, F&& fn) {
     if (n == 0) return;
 
-    // Nested parallelism guard.
+    // Prevent nested parallelism from deadlocking the pool.
     if (in_worker_) {
       fn(std::size_t{0});
       return;
     }
 
-    // Single shared region state => serialize launches.
+    // Only one active region at a time (shared region state).
     std::unique_lock<std::mutex> region_lock(region_m_);
 
     const std::size_t max_team = team_size();
-    if (max_team == 1) {
+    if (max_team <= 1) {
       fn(std::size_t{0});
       return;
     }
@@ -137,27 +85,23 @@ class TeamThreadPool {
     }
 
     // Publish region state BEFORE bumping epoch.
-    remaining_.store(n - 1, std::memory_order_release);
+    remaining_.store(n - 1, std::memory_order_release);  // workers only
     region_n_.store(n, std::memory_order_release);
     region_ctx_.store(static_cast<void*>(&fn_copy), std::memory_order_release);
     region_call_.store(&call_impl<Fn>, std::memory_order_release);
 
+    // Bump epoch to start the region, then wake workers.
     const std::size_t new_epoch =
         epoch_.fetch_add(1, std::memory_order_acq_rel) + 1;
 
-    // std::fprintf(stderr,
-    //     "\n[TeamThreadPool(launch)] epoch=%zu n=%zu expected_workers=%zu team_size=%zu\n",
-    //     new_epoch, n, n - 1, team_size());
-    // std::fflush(stderr);
-
-    // Wake workers using wake generation (prevents missed wakeups).
     {
       std::lock_guard<std::mutex> lk(wake_m_);
-      ++wake_gen_;
+      // epoch_ already updated; the mutex pairs with the cv wait.
+      (void)new_epoch;
    }
     wake_cv_.notify_all();
 
-    // Caller participates (tid=0).
+    // Caller participates as tid=0.
     in_worker_ = true;
     try {
       fn_copy(0);
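The empty lock/unlock of wake_m_ above is load-bearing: a worker evaluates the wait predicate while holding wake_m_, then atomically unlocks and sleeps, so a notifier that takes the mutex (even briefly) after bumping the epoch cannot fire notify_all() inside that window. A self-contained sketch of just this wake pattern, with demo-only names and nothing taken from this file:

    #include <atomic>
    #include <condition_variable>
    #include <cstddef>
    #include <cstdio>
    #include <mutex>
    #include <thread>

    std::atomic<std::size_t> epoch{0};
    std::mutex m;
    std::condition_variable cv;

    int main() {
      // Baseline captured before the waiter starts, mirroring the pool's
      // startup barrier.
      const std::size_t seen = epoch.load(std::memory_order_acquire);
      std::thread waiter([seen] {
        std::unique_lock<std::mutex> lk(m);
        // Predicate re-checked under the lock, so a notify that arrives
        // before the wait blocks cannot be lost.
        cv.wait(lk, [&] { return epoch.load(std::memory_order_acquire) != seen; });
        std::printf("woke on epoch %zu\n", epoch.load(std::memory_order_acquire));
      });
      epoch.fetch_add(1, std::memory_order_acq_rel);  // publish the new epoch
      { std::lock_guard<std::mutex> lk(m); }  // pairs with the waiter's lock
      cv.notify_all();
      waiter.join();
      return 0;
    }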
@@ -169,25 +113,9 @@ class TeamThreadPool {
 
     // Wait for workers 1..n-1.
     std::unique_lock<std::mutex> lk(done_m_);
-#if defined(STAN_TEAM_POOL_DEBUG_WAIT)
-    auto last_dump = std::chrono::steady_clock::now();
-    while (remaining_.load(std::memory_order_acquire) != 0) {
-      done_cv_.wait_for(lk, std::chrono::milliseconds(250));
-      const auto now = std::chrono::steady_clock::now();
-      if (now - last_dump > std::chrono::seconds(2)
-          && remaining_.load(std::memory_order_acquire) != 0) {
-        std::fprintf(stderr,
-                     "[TeamThreadPool] waiting too long for epoch=%zu (remaining=%zu)\n",
-                     new_epoch, remaining_.load(std::memory_order_acquire));
-        dump_state("TeamThreadPool(wait)");
-        last_dump = now;
-      }
-    }
-#else
     done_cv_.wait(lk, [&] {
       return remaining_.load(std::memory_order_acquire) == 0;
     });
-#endif
 
     // Hygiene.
     region_n_.store(0, std::memory_order_release);
@@ -204,7 +132,7 @@ class TeamThreadPool {
   }
 
   static std::atomic<std::size_t>& user_cap_() {
-    static std::atomic<std::size_t> cap{0};
+    static std::atomic<std::size_t> cap{0};  // 0 => unset
     return cap;
   }
 
@@ -226,11 +154,6 @@ class TeamThreadPool {
     return cap;
   }
 
-  std::size_t wake_gen_snapshot_() const {
-    std::lock_guard<std::mutex> lk(wake_m_);
-    return wake_gen_;
-  }
-
   TeamThreadPool()
       : stop_(false),
         epoch_(0),
@@ -239,7 +162,6 @@ class TeamThreadPool {
         region_call_(nullptr),
         remaining_(0),
         exc_ptr_(nullptr),
-        wake_gen_(0),
         ready_count_(0) {
     unsigned hw_u = std::thread::hardware_concurrency();
     if (hw_u == 0) hw_u = 2;
@@ -248,94 +170,52 @@ class TeamThreadPool {
     const std::size_t cap = configured_cap_(hw);
     const std::size_t num_workers = (cap > 1) ? (cap - 1) : 0;
 
-    worker_state_size_ = cap;
-
-    // raw arrays so atomics aren't moved
-    worker_alive_.reset(new std::atomic<unsigned>[cap]);
-    worker_seen_epoch_.reset(new std::atomic<std::size_t>[cap]);
-    worker_exec_epoch_.reset(new std::atomic<std::size_t>[cap]);
-    worker_decrement_count_.reset(new std::atomic<std::size_t>[cap]);
-
-    for (std::size_t i = 0; i < cap; ++i) {
-      worker_alive_[i].store(0u, std::memory_order_relaxed);
-      worker_seen_epoch_[i].store(0, std::memory_order_relaxed);
-      worker_exec_epoch_[i].store(0, std::memory_order_relaxed);
-      worker_decrement_count_[i].store(0, std::memory_order_relaxed);
-    }
-
-    std::fprintf(stderr,
-                 "[TeamThreadPool(ctor)] cap=%zu (workers=%zu) hw=%zu\n",
-                 cap, num_workers, hw);
-    std::fflush(stderr);
-
     workers_.reserve(num_workers);
     for (std::size_t i = 0; i < num_workers; ++i) {
-      const std::size_t tid = i + 1;
+      const std::size_t tid = i + 1;  // workers are 1..num_workers
       workers_.emplace_back([this, tid] {
+        // Per-worker AD tape initialized once.
         static thread_local ChainableStack ad_tape;
-        in_worker_ = true;
-
-        if (tid < worker_state_size_) {
-          worker_alive_[tid].store(1u, std::memory_order_release);
-        }
+        (void)ad_tape;
 
-        // "Arm" this worker on the current wake generation so it can't miss
-        // the first region wake.
-        std::size_t local_gen;
-        {
-          std::unique_lock<std::mutex> lk(wake_m_);
-          local_gen = wake_gen_;
-        }
+        in_worker_ = true;
 
-        // Signal readiness AFTER arming.
-        ready_count_.fetch_add(1, std::memory_order_acq_rel);
+        // Startup barrier: ensure each worker has entered the wait loop once.
+        // Capture the baseline epoch under the same lock, so a region launched
+        // right after the barrier cannot slip past this worker unseen.
+        std::size_t seen_epoch;
         {
-          std::lock_guard<std::mutex> lk(ready_m_);
+          std::lock_guard<std::mutex> lk(wake_m_);
+          ready_count_.fetch_add(1, std::memory_order_acq_rel);
+          seen_epoch = epoch_.load(std::memory_order_acquire);
         }
         ready_cv_.notify_one();
 
         for (;;) {
-          // Wait for wake_gen_ to change (or stop_).
+          // Wait for a new epoch (or stop).
           {
             std::unique_lock<std::mutex> lk(wake_m_);
             wake_cv_.wait(lk, [&] {
               return stop_.load(std::memory_order_acquire)
-                     || wake_gen_ != local_gen;
+                     || epoch_.load(std::memory_order_acquire) != seen_epoch;
             });
             if (stop_.load(std::memory_order_acquire)) break;
-            local_gen = wake_gen_;
-          }
-
-          // Observe epoch and region parameters after wake.
-          const std::size_t e = epoch_.load(std::memory_order_acquire);
-
-          if (tid < worker_state_size_) {
-            worker_seen_epoch_[tid].store(e, std::memory_order_release);
+            seen_epoch = epoch_.load(std::memory_order_acquire);
           }
 
           const std::size_t n = region_n_.load(std::memory_order_acquire);
-          if (tid >= n) {
-            continue;
-          }
+          if (tid >= n) continue;  // not participating this region
 
           // Always decrement once for participating workers.
           struct DoneGuard {
             std::atomic<std::size_t>& rem;
             std::mutex& m;
             std::condition_variable& cv;
-            std::atomic<std::size_t>& dec_count;
             ~DoneGuard() {
-              dec_count.fetch_add(1, std::memory_order_relaxed);
               if (rem.fetch_sub(1, std::memory_order_acq_rel) == 1) {
                 std::lock_guard<std::mutex> lk(m);
                 cv.notify_one();
               }
             }
-          } guard{remaining_, done_m_, done_cv_, worker_decrement_count_[tid]};
-
-          if (tid < worker_state_size_) {
-            worker_exec_epoch_[tid].store(e, std::memory_order_release);
-          }
+          } guard{remaining_, done_m_, done_cv_};
 
           void* ctx = region_ctx_.load(std::memory_order_acquire);
           call_fn_t call = region_call_.load(std::memory_order_acquire);
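The DoneGuard above is a destructor-driven countdown: each participating worker decrements remaining_ exactly once, even if the region body throws. A standalone sketch of the same idiom (demo-only names, nothing taken from this file):

    #include <atomic>
    #include <condition_variable>
    #include <cstddef>
    #include <cstdio>
    #include <mutex>
    #include <stdexcept>
    #include <thread>
    #include <vector>

    int main() {
      constexpr std::size_t kTasks = 4;
      std::atomic<std::size_t> remaining{kTasks};
      std::mutex done_m;
      std::condition_variable done_cv;

      std::vector<std::thread> team;
      for (std::size_t tid = 0; tid < kTasks; ++tid) {
        team.emplace_back([&, tid] {
          struct DoneGuard {
            std::atomic<std::size_t>& rem;
            std::mutex& m;
            std::condition_variable& cv;
            ~DoneGuard() {
              // Last finisher signals the waiter; the lock pairs with its wait.
              if (rem.fetch_sub(1, std::memory_order_acq_rel) == 1) {
                std::lock_guard<std::mutex> lk(m);
                cv.notify_one();
              }
            }
          } guard{remaining, done_m, done_cv};
          try {
            if (tid == 2) throw std::runtime_error("simulated task failure");
          } catch (...) {
            // Swallowed for the demo; the guard still decrements on scope exit.
          }
        });
      }

      std::unique_lock<std::mutex> lk(done_m);
      done_cv.wait(lk, [&] {
        return remaining.load(std::memory_order_acquire) == 0;
      });
      lk.unlock();
      for (auto& t : team) t.join();
      std::printf("all %zu tasks accounted for\n", kTasks);
      return 0;
    }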
@@ -350,34 +230,27 @@ class TeamThreadPool {
               }
             }
           }
-
-          if (tid < worker_state_size_) {
-            worker_exec_epoch_[tid].store(0, std::memory_order_release);
-          }
         }
 
         in_worker_ = false;
       });
     }
 
-    // Startup barrier: ensure all workers are armed and waiting-ready.
+    // Wait for all workers to reach the wait loop once before returning.
     {
-      std::unique_lock<std::mutex> lk(ready_m_);
+      std::unique_lock<std::mutex> lk(wake_m_);
       ready_cv_.wait(lk, [&] {
         return ready_count_.load(std::memory_order_acquire) == workers_.size();
       });
     }
-    std::fprintf(stderr, "[TeamThreadPool(ctor)] all workers ready: %zu\n",
-                 workers_.size());
-    std::fflush(stderr);
   }
 
   ~TeamThreadPool() {
     stop_.store(true, std::memory_order_release);
     {
-      // bump wake_gen_ so workers wake and see stop_
       std::lock_guard<std::mutex> lk(wake_m_);
-      ++wake_gen_;
+      // bump epoch to ensure wake predicate flips
+      epoch_.fetch_add(1, std::memory_order_acq_rel);
     }
     wake_cv_.notify_all();
 
@@ -400,13 +273,11 @@ class TeamThreadPool {
   std::atomic<void*> region_ctx_;
   std::atomic<call_fn_t> region_call_;
 
-  // Wake workers (wake_gen_ is protected by wake_m_).
-  mutable std::mutex wake_m_;
+  // Wake workers.
+  std::mutex wake_m_;
   std::condition_variable wake_cv_;
-  std::size_t wake_gen_;
 
   // Startup barrier.
-  std::mutex ready_m_;
   std::condition_variable ready_cv_;
   std::atomic<std::size_t> ready_count_;
 
@@ -418,13 +289,6 @@ class TeamThreadPool {
   // Exceptions.
   std::mutex exc_m_;
   std::exception_ptr* exc_ptr_;
-
-  // Debug state (arrays of atomics to avoid std::vector<atomic<...>> move issues).
-  std::size_t worker_state_size_{0};
-  std::unique_ptr<std::atomic<unsigned>[]> worker_alive_;
-  std::unique_ptr<std::atomic<std::size_t>[]> worker_seen_epoch_;
-  std::unique_ptr<std::atomic<std::size_t>[]> worker_exec_epoch_;
-  std::unique_ptr<std::atomic<std::size_t>[]> worker_decrement_count_;
 };
 
 }  // namespace math