Skip to content

Commit 79a00dd

Browse files
committed
map_rect
1 parent e4a9a63 commit 79a00dd

1 file changed

Lines changed: 20 additions & 33 deletions

File tree

stan/math/rev/functor/map_rect_concurrent.hpp

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
#include <stan/math/prim/functor/map_rect_reduce.hpp>
88
#include <stan/math/prim/functor/map_rect_combine.hpp>
99
#include <stan/math/rev/core/chainablestack.hpp>
10+
#include <stan/math/rev/core/team_thread_pool.hpp>
1011

11-
#include <tbb/parallel_for.h>
12-
#include <tbb/blocked_range.h>
12+
//#include <tbb/parallel_for.h>
13+
//#include <tbb/blocked_range.h>
1314

1415
#include <algorithm>
1516
#include <numeric>
@@ -34,7 +35,7 @@ map_rect_concurrent(
3435
= map_rect_reduce<F, scalar_type_t<T_shared_param>, T_job_param>;
3536
using CombineF = map_rect_combine<F, T_shared_param, T_job_param>;
3637

37-
const int num_jobs = job_params.size();
38+
const std::size_t num_jobs = job_params.size();
3839
const vector_d shared_params_dbl = value_of(shared_params);
3940
std::vector<matrix_d> job_output(num_jobs);
4041
std::vector<int> world_f_out(num_jobs, 0);
@@ -48,39 +49,25 @@ map_rect_concurrent(
4849
};
4950

5051
#ifdef STAN_THREADS
51-
std::cout << "********************************************************************************" << std::endl;
52-
if (num_jobs > 1) {
53-
// simple chunked threading over [0, num_jobs)
54-
unsigned hw_threads = std::thread::hardware_concurrency();
55-
if (hw_threads == 0) {
56-
hw_threads = 2; // arbitrary but > 0
57-
}
58-
59-
const unsigned max_threads
60-
= static_cast<unsigned>(std::min<std::size_t>(hw_threads, num_jobs));
61-
std::cout << "max_threads = " << max_threads << std::endl;
62-
std::vector<std::thread> threads;
63-
threads.reserve(max_threads);
64-
65-
const std::size_t chunk
66-
= (num_jobs + max_threads - 1) / max_threads; // ceil
67-
68-
for (unsigned t = 0; t < max_threads; ++t) {
69-
const std::size_t start = t * chunk;
70-
if (start >= num_jobs) break;
71-
const std::size_t end
72-
= std::min<std::size_t>(start + chunk, num_jobs);
52+
auto& pool = stan::math::TeamThreadPool::instance();
7353

74-
threads.emplace_back([&, start, end] {
75-
execute_chunk(start, end);
76-
});
77-
}
54+
// The total number of participants includes the caller (tid=0).
55+
const std::size_t max_team = pool.team_size();
56+
const std::size_t n = std::min<std::size_t>(max_team,
57+
num_jobs == 0 ? 1u
58+
: num_jobs);
7859

79-
for (auto& th : threads) {
80-
th.join();
81-
}
82-
} else {
60+
if (n <= 1 || num_jobs <= 1) {
8361
execute_chunk(0, num_jobs);
62+
} else {
63+
pool.parallel_region(n, [&](std::size_t tid) {
64+
const std::size_t nj = num_jobs;
65+
const std::size_t b0 = (nj * tid) / n;
66+
const std::size_t b1 = (nj * (tid + 1)) / n;
67+
if (b0 < b1) {
68+
execute_chunk(b0, b1);
69+
}
70+
});
8471
}
8572
#else
8673
execute_chunk(0, num_jobs);

0 commit comments

Comments
 (0)