+#ifndef STAN_MATH_REV_CORE_TEAM_THREAD_POOL_HPP
+#define STAN_MATH_REV_CORE_TEAM_THREAD_POOL_HPP
+
+#include <stan/math/rev/core/chainablestack.hpp>
+
+#include <atomic>
+#include <condition_variable>
+#include <cstddef>
+#include <mutex>
+#include <thread>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+namespace stan {
+namespace math {
+
+/**
+ * Team (epoch) thread pool for low-overhead parallel regions.
+ *
+ * - Spawns hardware_concurrency() - 1 worker threads once, on first use.
+ * - The calling thread participates as tid = 0.
+ * - parallel_region(n, fn) runs fn(tid) once for each tid in [0, n).
+ * - Nested parallelism: a call made from inside a region runs serially
+ *   on the calling thread to avoid deadlock.
+ * - fn must not throw: an escaping exception would unwind past the
+ *   region's synchronization while workers may still be running fn.
+ *
+ * Designed for reduce_sum/map_rect style internal parallelism.
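+ *
+ * Example (illustrative sketch; accumulate_slice is a hypothetical
+ * per-thread task, not part of this header):
+ * @code
+ * auto& pool = TeamThreadPool::instance();
+ * std::vector<double> partial(pool.team_size(), 0.0);
+ * pool.parallel_region(partial.size(), [&](std::size_t tid) {
+ *   partial[tid] = accumulate_slice(tid);  // each tid runs exactly once
+ * });
+ * @endcode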
+ */
+class TeamThreadPool {
+ public:
+  static TeamThreadPool& instance() {
+    static TeamThreadPool pool;
+    return pool;
+  }
+
+  TeamThreadPool(const TeamThreadPool&) = delete;
+  TeamThreadPool& operator=(const TeamThreadPool&) = delete;
+
+  // Number of worker threads (excluding the caller).
+  std::size_t worker_count() const noexcept { return workers_.size(); }
+
+  // Total participants available = worker_count() + 1 (the caller).
+  std::size_t team_size() const noexcept { return workers_.size() + 1; }
+
+  template <typename F>
+  void parallel_region(std::size_t n, F&& fn) {
+    if (n == 0) return;
+
+    // If called from inside a region, run serially to avoid nested deadlocks.
+    if (in_worker_) {
+      fn(std::size_t{0});
+      return;
+    }
+
+    const std::size_t max_team = team_size();
+    if (max_team == 1) {
+      fn(std::size_t{0});
+      return;
+    }
+    if (n > max_team) n = max_team;
+    if (n == 1) {
+      fn(std::size_t{0});
+      return;
+    }
+
+    // Stable storage for the callable during this region.
+    using Fn = std::decay_t<F>;
+    Fn fn_copy = std::forward<F>(fn);
+
+    // Publish the region: workers with tid in [1, n) will participate.
+    remaining_.store(n - 1, std::memory_order_release);  // workers only
+    region_n_.store(n, std::memory_order_release);
+    region_ctx_.store(static_cast<void*>(&fn_copy), std::memory_order_release);
+    region_call_.store(&call_impl<Fn>, std::memory_order_release);
+
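+    // The epoch bump is the publication point: a worker's acquire load
+    // of epoch_ pairs with this release increment, so the region fields
+    // stored above are visible before the worker dereferences them.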
+    epoch_.fetch_add(1, std::memory_order_acq_rel);
+
+    // Wake workers.
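+    // The empty lock/unlock of wake_m_ prevents a lost wakeup for a
+    // worker sitting between its predicate check and its cv wait.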
+    {
+      std::lock_guard<std::mutex> lk(wake_m_);
+    }
+    wake_cv_.notify_all();
+
+    // The caller participates as tid = 0.
+    in_worker_ = true;
+    fn_copy(std::size_t{0});
+    in_worker_ = false;
+
+    // Wait for workers 1..n-1 to finish.
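+    // fn_copy dies at scope exit, so we must not return until every
+    // participating worker is done calling it.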
+    std::unique_lock<std::mutex> lk(done_m_);
+    done_cv_.wait(lk, [&] {
+      return remaining_.load(std::memory_order_acquire) == 0;
+    });
+  }
+
+ private:
+  using call_fn_t = void (*)(void*, std::size_t);
+
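+  // Type-erased trampoline: parallel_region stores &fn_copy as a void*
+  // together with this instantiation, keeping std::function off the
+  // hot path.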
+  template <typename Fn>
+  static void call_impl(void* ctx, std::size_t tid) {
+    (*static_cast<Fn*>(ctx))(tid);
+  }
+
+  TeamThreadPool()
+      : stop_(false), epoch_(0), region_n_(0), region_ctx_(nullptr),
+        region_call_(nullptr), remaining_(0) {
+    unsigned hw = std::thread::hardware_concurrency();
+    if (hw == 0) hw = 2;  // hardware_concurrency() may report unknown as 0
+
+    // hw - 1 worker threads; the caller is the +1 participant. On a
+    // single-core machine this is zero workers and regions run serially.
+    const unsigned num_workers = (hw > 1) ? (hw - 1) : 0;
+
+    workers_.reserve(num_workers);
+    for (unsigned i = 0; i < num_workers; ++i) {
+      const std::size_t tid = static_cast<std::size_t>(i + 1);  // workers are 1..N
+      workers_.emplace_back([this, tid] {
+        // Per-worker AD tape, initialized once per thread (with
+        // STAN_THREADS this gives each worker its own tape instance).
+        static thread_local ChainableStack ad_tape;
+        in_worker_ = true;
+
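+        // Snapshot the epoch; a change in epoch_ signals a new region.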
+        std::size_t seen = epoch_.load(std::memory_order_acquire);
+        for (;;) {
+          // Sleep until the epoch changes or a stop is requested.
+          {
+            std::unique_lock<std::mutex> lk(wake_m_);
+            wake_cv_.wait(lk, [&] {
+              return stop_.load(std::memory_order_acquire)
+                     || epoch_.load(std::memory_order_acquire) != seen;
+            });
+          }
+          if (stop_.load(std::memory_order_acquire)) break;
+
+          const std::size_t e = epoch_.load(std::memory_order_acquire);
+          seen = e;
+
+          const std::size_t n = region_n_.load(std::memory_order_acquire);
+          if (tid >= n) {
+            continue;  // not participating in this region
+          }
+
+          void* ctx = region_ctx_.load(std::memory_order_acquire);
+          call_fn_t call = region_call_.load(std::memory_order_acquire);
+          if (call) {
+            call(ctx, tid);
+          }
+
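+          // The last participating worker flips remaining_ to zero and
+          // wakes the caller; taking done_m_ first pairs with the
+          // caller's predicated wait and avoids a lost wakeup.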
+          if (remaining_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
+            std::lock_guard<std::mutex> lk(done_m_);
+            done_cv_.notify_one();
+          }
+        }
+
+        in_worker_ = false;
+      });
+    }
+  }
+
+  ~TeamThreadPool() {
+    stop_.store(true, std::memory_order_release);
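+    // Empty critical section: taking wake_m_ orders the stop flag with
+    // any worker between its predicate check and its cv wait.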
+    {
+      std::lock_guard<std::mutex> lk(wake_m_);
+    }
+    wake_cv_.notify_all();
+    for (auto& t : workers_) {
+      if (t.joinable()) t.join();
+    }
+  }
+
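+  // True on worker threads for their whole lifetime, and on the caller
+  // while it executes its own share of a region; gates nested calls.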
+  static inline thread_local bool in_worker_ = false;
+
+  std::vector<std::thread> workers_;
+  std::atomic<bool> stop_;
+
+  // Region publication
+  std::atomic<std::size_t> epoch_;
+  std::atomic<std::size_t> region_n_;
+  std::atomic<void*> region_ctx_;
+  std::atomic<call_fn_t> region_call_;
+
+  // Worker wake-up
+  std::mutex wake_m_;
+  std::condition_variable wake_cv_;
+
+  // Completion tracking
+  std::atomic<std::size_t> remaining_;
+  std::mutex done_m_;
+  std::condition_variable done_cv_;
+};
+
+}  // namespace math
+}  // namespace stan
+
+#endif  // STAN_MATH_REV_CORE_TEAM_THREAD_POOL_HPP