66#include < atomic>
77#include < condition_variable>
88#include < cstddef>
9+ #include < cstdlib>
10+ #include < exception>
911#include < mutex>
1012#include < thread>
13+ #include < type_traits>
1114#include < utility>
1215#include < vector>
1316
@@ -17,16 +20,16 @@ namespace math {
/**
 * Team (epoch) thread pool for low-overhead parallel regions.
 *
 * - Workers are created once.
 * - Caller participates as tid=0.
 * - parallel_region(n, fn): runs fn(tid) for tid in [0, n).
 * - Nested parallelism: if called from a worker thread, runs serial.
 * - set_num_threads(k) must be called before instance() to size the pool.
 */
2729class TeamThreadPool {
2830 public:
29- // Call this before first use of TeamThreadPool::instance()
31+ // Call before first instance() to control pool size.
32+ // Meaning: total participants INCLUDING caller (tid=0).
3033 static void set_num_threads (std::size_t n) noexcept {
3134 if (n < 1 ) n = 1 ;
3235 user_cap_ ().store (n, std::memory_order_release);
@@ -41,26 +44,25 @@ class TeamThreadPool {
4144 return pool;
4245 }
4346
44- TeamThreadPool (const TeamThreadPool&) = delete ;
45- TeamThreadPool& operator =(const TeamThreadPool&) = delete ;
46-
47- // Number of worker threads (excluding caller)
47+ // Worker threads (excluding caller)
4848 std::size_t worker_count () const noexcept { return workers_.size (); }
4949
50- // Total participants available = worker_count + 1 ( caller)
50+ // Total possible participants INCLUDING caller
5151 std::size_t team_size () const noexcept { return workers_.size () + 1 ; }
5252
5353 template <typename F>
5454 void parallel_region (std::size_t n, F&& fn) {
55- // std::cout << "#################### parallel_region, n = " << n << std::endl;
5655 if (n == 0 ) return ;
5756
58- // If called from a worker, run serial to avoid nested deadlocks.
59- // std::cout << "in_worker_ = " << in_worker_ << std::endl;
57+ // Nested parallelism guard: if already on a worker, run serial.
6058 if (in_worker_) {
6159 fn (std::size_t {0 });
6260 return ;
6361 }
62+
63+ // Only one active region at a time (this is required for a single shared epoch design).
64+ std::unique_lock<std::mutex> region_lock (region_m_);
65+
6466 const std::size_t max_team = team_size ();
6567 if (max_team == 1 ) {
6668 fn (std::size_t {0 });
@@ -72,11 +74,17 @@ class TeamThreadPool {
7274 return ;
7375 }
7476
75- // Stable storage for callable during this region
7677 using Fn = std::decay_t <F>;
77- Fn fn_copy = std::forward<F>(fn);
78+ Fn fn_copy = std::forward<F>(fn); // stable storage during this call
79+
80+ // Exception propagation: capture first exception from any participant.
81+ std::exception_ptr eptr = nullptr ;
82+ {
83+ std::lock_guard<std::mutex> lk (exc_m_);
84+ exc_ptr_ = &eptr;
85+ }
7886
79- // Publish region
87+ // Publish region state BEFORE bumping epoch.
8088 remaining_.store (n - 1 , std::memory_order_release); // workers only
8189 region_n_.store (n, std::memory_order_release);
8290 region_ctx_.store (static_cast <void *>(&fn_copy), std::memory_order_release);
@@ -92,104 +100,86 @@ class TeamThreadPool {
92100
93101 // Caller participates as tid=0
94102 in_worker_ = true ;
95- fn_copy (0 );
103+ try {
104+ fn_copy (0 );
105+ } catch (...) {
106+ std::lock_guard<std::mutex> lk (exc_m_);
107+ if (eptr == nullptr ) eptr = std::current_exception ();
108+ }
96109 in_worker_ = false ;
97110
98- // std::cout << "waiting for workers" << std::endl;
99111 // Wait for workers 1..n-1
100112 std::unique_lock<std::mutex> lk (done_m_);
101113 done_cv_.wait (lk, [&] {
102114 return remaining_.load (std::memory_order_acquire) == 0 ;
103115 });
104- // std::cout << "#################### done" << std::endl << std::endl;
105- }
106-
107- private:
108- // Function-local static avoids static init order fiasco.
109- static std::atomic<std::size_t >& user_cap_ () {
110- static std::atomic<std::size_t > cap{0 }; // 0 => "unset"
111- return cap;
112- }
113116
114- static std::size_t configured_cap_ (std::size_t hw) {
115- // priority: user cap > env var > hw
116- std::size_t cap = user_cap_ ().load (std::memory_order_acquire);
117- if (cap == 0 ) {
118- cap = env_num_threads_ (); // if you have STAN_NUM_THREADS support
119- }
120- if (cap == 0 ) cap = hw;
117+ // Clear region participation; not strictly necessary but good hygiene.
118+ region_n_.store (0 , std::memory_order_release);
121119
122- if (cap < 1 ) cap = 1 ;
123- if (cap > hw) cap = hw; // prevent oversubscription by default
124- return cap;
120+ // Rethrow exception (if any)
121+ if (eptr) std::rethrow_exception (eptr);
125122 }
126123
127-
124+ private:
128125 using call_fn_t = void (*)(void *, std::size_t );
129126
130127 template <typename Fn>
131128 static void call_impl (void * ctx, std::size_t tid) {
132129 (*static_cast <Fn*>(ctx))(tid);
133130 }
134131
135- static size_t env_num_threads_ () {
136- size_t num_threads = 1 ;
137- #ifdef STAN_THREADS
138- const char * env_stan_num_threads = std::getenv (" STAN_NUM_THREADS" );
139- if (env_stan_num_threads != nullptr ) {
140- try {
141- const int env_num_threads
142- = boost::lexical_cast<int >(env_stan_num_threads);
143- if (env_num_threads > 0 ) {
144- num_threads = env_num_threads;
145- } else if (env_num_threads == -1 ) {
146- num_threads = std::thread::hardware_concurrency ();
147- } else {
148- invalid_argument (" get_num_threads(int)" , " STAN_NUM_THREADS" ,
149- env_stan_num_threads,
150- " The STAN_NUM_THREADS environment variable is '" ,
151- " ' but it must be positive or -1" );
152- }
153- } catch (const boost::bad_lexical_cast&) {
154- invalid_argument (" get_num_threads(int)" , " STAN_NUM_THREADS" ,
155- env_stan_num_threads,
156- " The STAN_NUM_THREADS environment variable is '" ,
157- " ' but it must be a positive number or -1" );
158- }
159- }
160- #endif
161- return num_threads;
162- }
132+ // Function-local static avoids static initialization order issues.
133+ static std::atomic<std::size_t >& user_cap_ () {
134+ static std::atomic<std::size_t > cap{0 }; // 0 => unset
135+ return cap;
136+ }
163137
138+ static std::size_t env_num_threads_ () noexcept {
139+ const char * s = std::getenv (" STAN_NUM_THREADS" );
140+ if (!s || !*s) return 0 ;
141+ char * end = nullptr ;
142+ long v = std::strtol (s, &end, 10 );
143+ if (end == s || v <= 0 ) return 0 ;
144+ return static_cast <std::size_t >(v);
145+ }
164146
165- TeamThreadPool ()
166- : stop_(false ), epoch_(0 ), region_n_(0 ), region_ctx_(nullptr ),
167- region_call_ (nullptr ), remaining_(0 ) {
147+ static std::size_t configured_cap_ (std::size_t hw) noexcept {
148+ // Priority: explicit set_num_threads > STAN_NUM_THREADS > hw
149+ std::size_t cap = user_cap_ ().load (std::memory_order_acquire);
150+ if (cap == 0 ) cap = env_num_threads_ ();
151+ if (cap == 0 ) cap = hw;
168152
153+ if (cap < 1 ) cap = 1 ;
154+ if (cap > hw) cap = hw; // don’t oversubscribe by default
155+ return cap;
156+ }
157+
158+ TeamThreadPool ()
159+ : stop_(false ),
160+ epoch_ (0 ),
161+ region_n_(0 ),
162+ region_ctx_(nullptr ),
163+ region_call_(nullptr ),
164+ remaining_(0 ),
165+ exc_ptr_(nullptr ) {
169166 unsigned hw_u = std::thread::hardware_concurrency ();
170167 if (hw_u == 0 ) hw_u = 2 ;
171168 const std::size_t hw = static_cast <std::size_t >(hw_u);
172-
169+
170+ // Total participants includes caller.
173171 const std::size_t cap = configured_cap_ (hw);
174172 const std::size_t num_workers = (cap > 1 ) ? (cap - 1 ) : 0 ;
175173
176- std::cout << " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" << std::endl
177- << " hw = " << hw << std::endl
178- << " num_workers = " << num_workers << std::endl
179- << " cap = " << cap << std::endl
180- << std::endl << std::endl;
181-
182174 workers_.reserve (num_workers);
183- for (unsigned i = 0 ; i < num_workers; ++i) {
184- const std::size_t tid = static_cast <std:: size_t >( i + 1 ) ; // workers are 1..N
175+ for (std:: size_t i = 0 ; i < num_workers; ++i) {
176+ const std::size_t tid = i + 1 ; // workers are 1..num_workers
185177 workers_.emplace_back ([this , tid] {
186- // Per-worker AD tape initialized once
187178 static thread_local ChainableStack ad_tape;
188179 in_worker_ = true ;
189180
190181 std::size_t seen = epoch_.load (std::memory_order_acquire);
191182 for (;;) {
192- // Sleep until epoch changes or stop requested
193183 {
194184 std::unique_lock<std::mutex> lk (wake_m_);
195185 wake_cv_.wait (lk, [&] {
@@ -203,28 +193,41 @@ class TeamThreadPool {
203193 seen = e;
204194
205195 const std::size_t n = region_n_.load (std::memory_order_acquire);
206- if (tid >= n) {
207- continue ; // not participating this region
208- }
196+ if (tid >= n) continue ; // not participating this region
197+
198+ // Ensure we ALWAYS decrement remaining_ once for participating workers.
199+ struct DoneGuard {
200+ std::atomic<std::size_t >& rem;
201+ std::mutex& m;
202+ std::condition_variable& cv;
203+ bool active{true };
204+ ~DoneGuard () {
205+ if (!active) return ;
206+ if (rem.fetch_sub (1 , std::memory_order_acq_rel) == 1 ) {
207+ std::lock_guard<std::mutex> lk (m);
208+ cv.notify_one ();
209+ }
210+ }
211+ } guard{remaining_, done_m_, done_cv_};
209212
210213 void * ctx = region_ctx_.load (std::memory_order_acquire);
211214 call_fn_t call = region_call_.load (std::memory_order_acquire);
212- if (call) {
213- call (ctx, tid);
214- }
215+ if (!call) continue ;
215216
216- if (remaining_.fetch_sub (1 , std::memory_order_acq_rel) == 1 ) {
217- std::lock_guard<std::mutex> lk (done_m_);
218- done_cv_.notify_one ();
217+ try {
218+ call (ctx, tid);
219+ } catch (...) {
220+ std::lock_guard<std::mutex> lk (exc_m_);
221+ if (exc_ptr_ && *exc_ptr_ == nullptr ) {
222+ *exc_ptr_ = std::current_exception ();
223+ }
219224 }
220225 }
221226
222227 in_worker_ = false ;
223228 });
224229 }
225- std::cout << " done with constructor" << std::endl;
226230 }
227-
228231
229232 ~TeamThreadPool () {
230233 stop_.store (true , std::memory_order_release);
@@ -242,6 +245,9 @@ class TeamThreadPool {
242245 std::vector<std::thread> workers_;
243246 std::atomic<bool > stop_;
244247
248+ // Serialize regions (single shared-region design)
249+ std::mutex region_m_;
250+
245251 // Region publication
246252 std::atomic<std::size_t > epoch_;
247253 std::atomic<std::size_t > region_n_;
@@ -256,6 +262,10 @@ class TeamThreadPool {
256262 std::atomic<std::size_t > remaining_;
257263 std::mutex done_m_;
258264 std::condition_variable done_cv_;
265+
266+ // Exception plumbing
267+ std::mutex exc_m_;
268+ std::exception_ptr* exc_ptr_;
259269};
260270
261271} // namespace math
0 commit comments