Skip to content

Commit a7a95fb

Browse files
committed
10x slower
1 parent 9f72cc3 commit a7a95fb

2 files changed

Lines changed: 278 additions & 323 deletions

File tree

Lines changed: 155 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,155 @@
1+
#ifndef STAN_MATH_REV_CORE_SIMPLE_THREAD_POOL_HPP
2+
#define STAN_MATH_REV_CORE_SIMPLE_THREAD_POOL_HPP
3+
4+
#include <stan/math/rev/core/chainablestack.hpp>

#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <exception>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
17+
18+
namespace stan {
19+
namespace math {
20+
21+
class SimpleThreadPool {
22+
public:
23+
static SimpleThreadPool& instance() {
24+
static SimpleThreadPool pool;
25+
return pool;
26+
}
27+
28+
SimpleThreadPool(const SimpleThreadPool&) = delete;
29+
SimpleThreadPool& operator=(const SimpleThreadPool&) = delete;
30+
31+
std::size_t thread_count() const noexcept { return workers_.size(); }
32+
33+
template <typename F, typename... Args>
34+
auto submit(F&& f, Args&&... args)
35+
-> std::future<std::invoke_result_t<F, Args...>> {
36+
using R = std::invoke_result_t<F, Args...>;
37+
38+
auto task_ptr = std::make_shared<std::packaged_task<R()>>(
39+
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
40+
41+
enqueue_([task_ptr] { (*task_ptr)(); });
42+
return task_ptr->get_future();
43+
}
44+
45+
template <typename F>
46+
void parallel_region(std::size_t n, F&& fn) {
47+
if (n == 0) return;
48+
49+
// Avoid nested parallelism deadlocks/oversubscription.
50+
if (in_worker_) {
51+
fn(std::size_t{0});
52+
return;
53+
}
54+
55+
const std::size_t tc = thread_count();
56+
if (tc == 0) {
57+
fn(std::size_t{0});
58+
return;
59+
}
60+
if (n > tc) n = tc;
61+
62+
using Fn = std::decay_t<F>;
63+
struct Shared {
64+
std::atomic<std::size_t> remaining;
65+
std::mutex m;
66+
std::condition_variable cv;
67+
Fn fn;
68+
Shared(std::size_t n_, Fn&& f_) : remaining(n_), fn(std::move(f_)) {}
69+
};
70+
71+
auto shared = std::make_shared<Shared>(n, Fn(std::forward<F>(fn)));
72+
73+
for (std::size_t tid = 0; tid < n; ++tid) {
74+
enqueue_([shared, tid] {
75+
shared->fn(tid);
76+
if (shared->remaining.fetch_sub(1, std::memory_order_acq_rel) == 1) {
77+
std::lock_guard<std::mutex> lk(shared->m);
78+
shared->cv.notify_one();
79+
}
80+
});
81+
}
82+
83+
std::unique_lock<std::mutex> lk(shared->m);
84+
shared->cv.wait(lk, [&] {
85+
return shared->remaining.load(std::memory_order_acquire) == 0;
86+
});
87+
}
88+
89+
private:
90+
SimpleThreadPool() : done_(false) {
91+
unsigned hw = std::thread::hardware_concurrency();
92+
if (hw == 0) hw = 2;
93+
const unsigned num_threads = hw;
94+
95+
workers_.reserve(num_threads);
96+
for (unsigned i = 0; i < num_threads; ++i) {
97+
workers_.emplace_back([this] {
98+
// Per-worker AD tape (TLS) initialized once.
99+
static thread_local ChainableStack ad_tape;
100+
101+
for (;;) {
102+
std::function<void()> task;
103+
{
104+
std::unique_lock<std::mutex> lock(mtx_);
105+
cv_.wait(lock, [&] { return done_ || !tasks_.empty(); });
106+
if (done_ && tasks_.empty()) return;
107+
task = std::move(tasks_.front());
108+
tasks_.pop();
109+
}
110+
111+
WorkerScope scope; // sets in_worker_ for all tasks
112+
task();
113+
}
114+
});
115+
}
116+
}
117+
118+
~SimpleThreadPool() {
119+
{
120+
std::lock_guard<std::mutex> lock(mtx_);
121+
done_ = true;
122+
}
123+
cv_.notify_all();
124+
for (auto& th : workers_) {
125+
if (th.joinable()) th.join();
126+
}
127+
}
128+
129+
void enqueue_(std::function<void()> task) {
130+
{
131+
std::lock_guard<std::mutex> lock(mtx_);
132+
tasks_.emplace(std::move(task));
133+
}
134+
cv_.notify_one();
135+
}
136+
137+
struct WorkerScope {
138+
WorkerScope() : prev_(in_worker_) { in_worker_ = true; }
139+
~WorkerScope() { in_worker_ = prev_; }
140+
bool prev_;
141+
};
142+
143+
static inline thread_local bool in_worker_ = false;
144+
145+
std::vector<std::thread> workers_;
146+
std::queue<std::function<void()>> tasks_;
147+
std::mutex mtx_;
148+
std::condition_variable cv_;
149+
bool done_;
150+
};
151+
152+
} // namespace math
153+
} // namespace stan
154+
155+
#endif

0 commit comments

Comments
 (0)