@@ -26,6 +26,16 @@ namespace math {
2626 */
2727class TeamThreadPool {
2828 public:
29+ // Call this before first use of TeamThreadPool::instance()
30+ static void set_num_threads (std::size_t n) noexcept {
31+ if (n < 1 ) n = 1 ;
32+ user_cap_ ().store (n, std::memory_order_release);
33+ }
34+
35+ static std::size_t get_num_threads () noexcept {
36+ return user_cap_ ().load (std::memory_order_acquire);
37+ }
38+
2939 static TeamThreadPool& instance () {
3040 static TeamThreadPool pool;
3141 return pool;
@@ -42,14 +52,15 @@ class TeamThreadPool {
4252
4353 template <typename F>
4454 void parallel_region (std::size_t n, F&& fn) {
55+ // std::cout << "#################### parallel_region, n = " << n << std::endl;
4556 if (n == 0 ) return ;
4657
4758 // If called from a worker, run serial to avoid nested deadlocks.
59+ // std::cout << "in_worker_ = " << in_worker_ << std::endl;
4860 if (in_worker_) {
4961 fn (std::size_t {0 });
5062 return ;
5163 }
52-
5364 const std::size_t max_team = team_size ();
5465 if (max_team == 1 ) {
5566 fn (std::size_t {0 });
@@ -84,30 +95,90 @@ class TeamThreadPool {
8495 fn_copy (0 );
8596 in_worker_ = false ;
8697
98+ // std::cout << "waiting for workers" << std::endl;
8799 // Wait for workers 1..n-1
88100 std::unique_lock<std::mutex> lk (done_m_);
89101 done_cv_.wait (lk, [&] {
90102 return remaining_.load (std::memory_order_acquire) == 0 ;
91103 });
104+ // std::cout << "#################### done" << std::endl << std::endl;
105+ }
106+
107+ private:
108+ // Function-local static avoids static init order fiasco.
109+ static std::atomic<std::size_t >& user_cap_ () {
110+ static std::atomic<std::size_t > cap{0 }; // 0 => "unset"
111+ return cap;
112+ }
113+
114+ static std::size_t configured_cap_ (std::size_t hw) {
115+ // priority: user cap > env var > hw
116+ std::size_t cap = user_cap_ ().load (std::memory_order_acquire);
117+ if (cap == 0 ) {
118+ cap = env_num_threads_ (); // if you have STAN_NUM_THREADS support
119+ }
120+ if (cap == 0 ) cap = hw;
121+
122+ if (cap < 1 ) cap = 1 ;
123+ if (cap > hw) cap = hw; // prevent oversubscription by default
124+ return cap;
92125 }
93126
94- private:
127+
95128 using call_fn_t = void (*)(void *, std::size_t );
96129
97130 template <typename Fn>
98131 static void call_impl (void * ctx, std::size_t tid) {
99132 (*static_cast <Fn*>(ctx))(tid);
100133 }
101134
135+ static size_t env_num_threads_ () {
136+ size_t num_threads = 1 ;
137+ #ifdef STAN_THREADS
138+ const char * env_stan_num_threads = std::getenv (" STAN_NUM_THREADS" );
139+ if (env_stan_num_threads != nullptr ) {
140+ try {
141+ const int env_num_threads
142+ = boost::lexical_cast<int >(env_stan_num_threads);
143+ if (env_num_threads > 0 ) {
144+ num_threads = env_num_threads;
145+ } else if (env_num_threads == -1 ) {
146+ num_threads = std::thread::hardware_concurrency ();
147+ } else {
148+ invalid_argument (" get_num_threads(int)" , " STAN_NUM_THREADS" ,
149+ env_stan_num_threads,
150+ " The STAN_NUM_THREADS environment variable is '" ,
151+ " ' but it must be positive or -1" );
152+ }
153+ } catch (const boost::bad_lexical_cast&) {
154+ invalid_argument (" get_num_threads(int)" , " STAN_NUM_THREADS" ,
155+ env_stan_num_threads,
156+ " The STAN_NUM_THREADS environment variable is '" ,
157+ " ' but it must be a positive number or -1" );
158+ }
159+ }
160+ #endif
161+ return num_threads;
162+ }
163+
164+
102165 TeamThreadPool ()
103166 : stop_(false ), epoch_(0 ), region_n_(0 ), region_ctx_(nullptr ),
104167 region_call_ (nullptr ), remaining_(0 ) {
105- unsigned hw = std::thread::hardware_concurrency ();
106- if (hw == 0 ) hw = 2 ;
107-
108- // hw-1 worker threads; caller is +1 participant.
109- const unsigned num_workers = (hw > 1 ) ? (hw - 1 ) : 1 ;
110168
169+ unsigned hw_u = std::thread::hardware_concurrency ();
170+ if (hw_u == 0 ) hw_u = 2 ;
171+ const std::size_t hw = static_cast <std::size_t >(hw_u);
172+
173+ const std::size_t cap = configured_cap_ (hw);
174+ const std::size_t num_workers = (cap > 1 ) ? (cap - 1 ) : 0 ;
175+
176+ std::cout << " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" << std::endl
177+ << " hw = " << hw << std::endl
178+ << " num_workers = " << num_workers << std::endl
179+ << " cap = " << cap << std::endl
180+ << std::endl << std::endl;
181+
111182 workers_.reserve (num_workers);
112183 for (unsigned i = 0 ; i < num_workers; ++i) {
113184 const std::size_t tid = static_cast <std::size_t >(i + 1 ); // workers are 1..N
@@ -151,7 +222,9 @@ class TeamThreadPool {
151222 in_worker_ = false ;
152223 });
153224 }
225+ std::cout << " done with constructor" << std::endl;
154226 }
227+
155228
156229 ~TeamThreadPool () {
157230 stop_.store (true , std::memory_order_release);
0 commit comments