Commit 5e78da0 (1 parent: a7a95fb)

2 files changed: 315 additions & 113 deletions

New file: 191 additions & 0 deletions

@@ -0,0 +1,191 @@
#ifndef STAN_MATH_REV_CORE_TEAM_THREAD_POOL_HPP
#define STAN_MATH_REV_CORE_TEAM_THREAD_POOL_HPP

#include <stan/math/rev/core/chainablestack.hpp>

#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
namespace stan {
namespace math {

/**
 * Team (epoch) thread pool for low-overhead parallel regions.
 *
 * - Creates (hw-1) worker threads once.
 * - Caller participates with tid=0.
 * - parallel_region(n, fn): runs fn(tid) for tid in [0, n).
 * - Nested parallelism: if called from a worker thread, runs serially.
 *
 * Designed for reduce_sum/map_rect style internal parallelism.
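 *
 * A minimal sketch of a call site (names in the body are illustrative,
 * not part of this commit):
 * @code
 * auto& pool = stan::math::TeamThreadPool::instance();
 * std::vector<double> partial(pool.team_size(), 0.0);
 * pool.parallel_region(partial.size(), [&](std::size_t tid) {
 *   partial[tid] = work_for_chunk(tid);  // hypothetical per-chunk work
 * });
 * @endcode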
 */
class TeamThreadPool {
 public:
  static TeamThreadPool& instance() {
    static TeamThreadPool pool;
    return pool;
  }

  TeamThreadPool(const TeamThreadPool&) = delete;
  TeamThreadPool& operator=(const TeamThreadPool&) = delete;

  // Number of worker threads (excluding the caller)
  std::size_t worker_count() const noexcept { return workers_.size(); }

  // Total participants available = worker_count + 1 (caller)
  std::size_t team_size() const noexcept { return workers_.size() + 1; }

  template <typename F>
  void parallel_region(std::size_t n, F&& fn) {
    if (n == 0) return;

    // If called from a worker, run serially to avoid nested deadlocks.
    if (in_worker_) {
      fn(std::size_t{0});
      return;
    }

    const std::size_t max_team = team_size();
    if (max_team == 1) {
      fn(std::size_t{0});
      return;
    }
    if (n > max_team) n = max_team;
    if (n == 1) {
      fn(std::size_t{0});
      return;
    }

    // Stable storage for the callable during this region
    using Fn = std::decay_t<F>;
    Fn fn_copy = std::forward<F>(fn);

    // Publish the region; the epoch bump below is what workers observe.
    remaining_.store(n - 1, std::memory_order_release);  // workers only
    region_n_.store(n, std::memory_order_release);
    region_ctx_.store(static_cast<void*>(&fn_copy), std::memory_order_release);
    region_call_.store(&call_impl<Fn>, std::memory_order_release);

    epoch_.fetch_add(1, std::memory_order_acq_rel);

    // Wake workers. The empty critical section pairs with the workers'
    // predicate check under wake_m_, so the notify cannot be lost.
    {
      std::lock_guard<std::mutex> lk(wake_m_);
    }
    wake_cv_.notify_all();

    // Caller participates as tid=0
    in_worker_ = true;
    fn_copy(std::size_t{0});
    in_worker_ = false;

    // Wait for workers 1..n-1 to finish this region
    std::unique_lock<std::mutex> lk(done_m_);
    done_cv_.wait(lk, [&] {
      return remaining_.load(std::memory_order_acquire) == 0;
    });
  }
 private:
  using call_fn_t = void (*)(void*, std::size_t);

  template <typename Fn>
  static void call_impl(void* ctx, std::size_t tid) {
    (*static_cast<Fn*>(ctx))(tid);
  }

  TeamThreadPool()
      : stop_(false), epoch_(0), region_n_(0), region_ctx_(nullptr),
        region_call_(nullptr), remaining_(0) {
    unsigned hw = std::thread::hardware_concurrency();
    if (hw == 0) hw = 2;

    // hw-1 worker threads; the caller is the +1 participant.
    const unsigned num_workers = (hw > 1) ? (hw - 1) : 1;

    workers_.reserve(num_workers);
    for (unsigned i = 0; i < num_workers; ++i) {
      const std::size_t tid = static_cast<std::size_t>(i + 1);  // workers are 1..N
      workers_.emplace_back([this, tid] {
        // Per-worker AD tape initialized once
        static thread_local ChainableStack ad_tape;
        in_worker_ = true;

        std::size_t seen = epoch_.load(std::memory_order_acquire);
        for (;;) {
          // Sleep until the epoch changes or stop is requested
          {
            std::unique_lock<std::mutex> lk(wake_m_);
            wake_cv_.wait(lk, [&] {
              return stop_.load(std::memory_order_acquire)
                  || epoch_.load(std::memory_order_acquire) != seen;
            });
          }
          if (stop_.load(std::memory_order_acquire)) break;

          const std::size_t e = epoch_.load(std::memory_order_acquire);
          seen = e;

          const std::size_t n = region_n_.load(std::memory_order_acquire);
          if (tid >= n) {
            continue;  // not participating in this region
          }

          void* ctx = region_ctx_.load(std::memory_order_acquire);
          call_fn_t call = region_call_.load(std::memory_order_acquire);
          if (call) {
            call(ctx, tid);
          }

          if (remaining_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
            std::lock_guard<std::mutex> lk(done_m_);
            done_cv_.notify_one();
          }
        }

        in_worker_ = false;
      });
    }
  }

  ~TeamThreadPool() {
    stop_.store(true, std::memory_order_release);
    {
      std::lock_guard<std::mutex> lk(wake_m_);
    }
    wake_cv_.notify_all();
    for (auto& t : workers_) {
      if (t.joinable()) t.join();
    }
  }

  static inline thread_local bool in_worker_ = false;

  std::vector<std::thread> workers_;
  std::atomic<bool> stop_;

  // Region publication
  std::atomic<std::size_t> epoch_;
  std::atomic<std::size_t> region_n_;
  std::atomic<void*> region_ctx_;
  std::atomic<call_fn_t> region_call_;

  // Worker wake
  std::mutex wake_m_;
  std::condition_variable wake_cv_;

  // Completion
  std::atomic<std::size_t> remaining_;
  std::mutex done_m_;
  std::condition_variable done_cv_;
};

}  // namespace math
}  // namespace stan

#endif
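The pool is a process-wide singleton obtained through instance(). As a quick check of the API, here is a minimal sketch of a reduce_sum-style caller; the include path is inferred from the header guard, and parallel_sum is an illustrative helper, not part of this commit:

#include <stan/math/rev/core/team_thread_pool.hpp>

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

double parallel_sum(const std::vector<double>& x) {
  auto& pool = stan::math::TeamThreadPool::instance();
  const std::size_t teams = pool.team_size();
  std::vector<double> partial(teams, 0.0);
  pool.parallel_region(teams, [&](std::size_t tid) {
    // Each participant (caller = tid 0, workers = 1..teams-1) sums its
    // own contiguous chunk of x into a private slot; no locking needed.
    const std::size_t chunk = (x.size() + teams - 1) / teams;
    const std::size_t begin = tid * chunk;
    const std::size_t end = std::min(x.size(), begin + chunk);
    for (std::size_t i = begin; i < end; ++i) partial[tid] += x[i];
  });
  return std::accumulate(partial.begin(), partial.end(), 0.0);
}

Note that parallel_region clamps n to team_size(), so passing teams here means every participant, including the caller, runs exactly one chunk; a nested call from inside the lambda would fall back to serial execution.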
