77#include < stan/math/prim/functor/map_rect_reduce.hpp>
88#include < stan/math/prim/functor/map_rect_combine.hpp>
99#include < stan/math/rev/core/chainablestack.hpp>
10+ #include < stan/math/rev/core/team_thread_pool.hpp>
1011
11- #include < tbb/parallel_for.h>
12- #include < tbb/blocked_range.h>
12+ // #include <tbb/parallel_for.h>
13+ // #include <tbb/blocked_range.h>
1314
1415#include < algorithm>
1516#include < numeric>
@@ -34,7 +35,7 @@ map_rect_concurrent(
3435 = map_rect_reduce<F, scalar_type_t <T_shared_param>, T_job_param>;
3536 using CombineF = map_rect_combine<F, T_shared_param, T_job_param>;
3637
37- const int num_jobs = job_params.size ();
38+ const std:: size_t num_jobs = job_params.size ();
3839 const vector_d shared_params_dbl = value_of (shared_params);
3940 std::vector<matrix_d> job_output (num_jobs);
4041 std::vector<int > world_f_out (num_jobs, 0 );
@@ -48,39 +49,25 @@ map_rect_concurrent(
4849 };
4950
5051#ifdef STAN_THREADS
51- std::cout << " ********************************************************************************" << std::endl;
52- if (num_jobs > 1 ) {
53- // simple chunked threading over [0, num_jobs)
54- unsigned hw_threads = std::thread::hardware_concurrency ();
55- if (hw_threads == 0 ) {
56- hw_threads = 2 ; // arbitrary but > 0
57- }
58-
59- const unsigned max_threads
60- = static_cast <unsigned >(std::min<std::size_t >(hw_threads, num_jobs));
61- std::cout << " max_threads = " << max_threads << std::endl;
62- std::vector<std::thread> threads;
63- threads.reserve (max_threads);
64-
65- const std::size_t chunk
66- = (num_jobs + max_threads - 1 ) / max_threads; // ceil
67-
68- for (unsigned t = 0 ; t < max_threads; ++t) {
69- const std::size_t start = t * chunk;
70- if (start >= num_jobs) break ;
71- const std::size_t end
72- = std::min<std::size_t >(start + chunk, num_jobs);
52+ auto & pool = stan::math::TeamThreadPool::instance ();
7353
74- threads.emplace_back ([&, start, end] {
75- execute_chunk (start, end);
76- });
77- }
54+ // The total participant count includes the caller (tid=0).
55+ const std::size_t max_team = pool.team_size ();
56+ const std::size_t n = std::min<std::size_t >(max_team,
57+ num_jobs == 0 ? 1u
58+ : num_jobs);
7859
79- for (auto & th : threads) {
80- th.join ();
81- }
82- } else {
60+ if (n <= 1 || num_jobs <= 1 ) {
8361 execute_chunk (0 , num_jobs);
62+ } else {
63+ pool.parallel_region (n, [&](std::size_t tid) {
64+ const std::size_t nj = num_jobs;
65+ const std::size_t b0 = (nj * tid) / n;
66+ const std::size_t b1 = (nj * (tid + 1 )) / n;
67+ if (b0 < b1) {
68+ execute_chunk (b0, b1);
69+ }
70+ });
8471 }
8572#else
8673 execute_chunk (0 , num_jobs);
0 commit comments