#include <stan/math/prim/functor.hpp>
#include <stan/math/rev/core.hpp>

+#include <thread>
+
#include <tbb/task_arena.h>
#include <tbb/parallel_reduce.h>
#include <tbb/blocked_range.h>
@@ -74,13 +76,79 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
   * to zero since the newly created reducer is used to accumulate
   * an independent partial sum.
   */
+  /*
  recursive_reducer(recursive_reducer& other, tbb::split)
      : num_vars_per_term_(other.num_vars_per_term_),
        num_vars_shared_terms_(other.num_vars_shared_terms_),
        sliced_partials_(other.sliced_partials_),
        vmapped_(other.vmapped_),
        args_tuple_(other.args_tuple_) {}
+  */
+
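+  /**
+   * Compute, using nested autodiff, the value and Jacobian of
+   * `ReduceFunction` called over the terms `[begin, end)` and accumulate
+   * them into `sum_`, `args_adjoints_`, and the shared `sliced_partials_`
+   * buffer.
+   *
+   * @param begin index of the first term in this block (inclusive)
+   * @param end index one past the last term in this block (exclusive)
+   */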
+  inline void operator()(std::size_t begin, std::size_t end) {
+    if (begin == end) {
+      return;
+    }
+
+    if (args_adjoints_.size() == 0) {
+      args_adjoints_ = Eigen::VectorXd::Zero(num_vars_shared_terms_);
+    }

+    // local deep copy of the shared arguments, held on this reducer's
+    // local AD stack
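+    // (the copies keep the shared vars off the main autodiff stack, so the
+    // nested grad() call below never writes adjoints owned by other threads)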
+    if (!local_args_tuple_scope_.args_tuple_holder_) {
+      local_args_tuple_scope_.stack_.execute([&]() {
+        math::apply(
+            [&](auto&&... args) {
+              local_args_tuple_scope_.args_tuple_holder_ =
+                  std::make_unique<typename scoped_args_tuple::args_tuple_t>(
+                      deep_copy_vars(args)...);
+            },
+            args_tuple_);
+      });
+    } else {
+      // set adjoints of shared arguments to zero
+      local_args_tuple_scope_.stack_.execute([] { set_zero_all_adjoints(); });
+    }
+
+    auto& args_tuple_local = *(local_args_tuple_scope_.args_tuple_holder_);
+
+    // Initialize nested autodiff stack
+    const nested_rev_autodiff begin_nest;
+
+    // Create nested autodiff copies of the sliced arguments that do not
+    // point back to the main autodiff stack
+    std::decay_t<Vec> local_sub_slice;
+    local_sub_slice.reserve(end - begin);
+    for (std::size_t i = begin; i < end; ++i) {
+      local_sub_slice.emplace_back(deep_copy_vars(vmapped_[i]));
+    }
+
+    // Perform calculation
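+    // (ReduceFunction receives the local slice plus the inclusive index
+    // range [begin, end - 1] of the terms it covers)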
+    var sub_sum_v = math::apply(
+        [&](auto&&... args) {
+          return ReduceFunction()(local_sub_slice, begin, end - 1, &msgs_,
+                                  args...);
+        },
+        args_tuple_local);
+
+    // Compute Jacobian
+    sub_sum_v.grad();
+
+    // accumulate value
+    sum_ += sub_sum_v.val();
+
+    // accumulate adjoints of sliced arguments
+    accumulate_adjoints(sliced_partials_ + begin * num_vars_per_term_,
+                        std::move(local_sub_slice));
+
+    // accumulate adjoints of shared arguments
+    math::apply(
+        [&](auto&&... args) {
+          accumulate_adjoints(args_adjoints_.data(), args...);
+        },
+        args_tuple_local);
+  }
+
  /**
   * Compute, using nested autodiff, the value and Jacobian of
   * `ReduceFunction` called over the range defined by r and accumulate those
@@ -94,7 +162,7 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
   *
   * @param r Range over which to compute reduce_sum
   */
-  inline void operator()(const tbb::blocked_range<size_t>& r) {
+  /* inline void operator()(const tbb::blocked_range<size_t>& r) {
    if (r.empty()) {
      return;
    }
@@ -163,7 +231,7 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
        },
        args_tuple_local);
  }
-
+  */
  /**
   * Join reducers. Accumulate the value (sum_) and Jacobian (args_adjoints_)
   * of the other reducer.
@@ -221,6 +289,101 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
   * @param args Shared arguments used in every sum term
   * @return Summation of all terms
   */
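+  // Note: auto_partitioning and grainsize are ignored in this std::thread
+  // variant; the work is split into one contiguous block per thread.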
+  inline var operator()(Vec&& vmapped, bool /* auto_partitioning */,
+                        int /* grainsize */, std::ostream* msgs,
+                        Args&&... args) const {
+    if (vmapped.empty()) {
+      return var(0.0);
+    }
+
+    const std::size_t num_terms = vmapped.size();
+    const std::size_t num_vars_per_term = count_vars(vmapped[0]);
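+    // (assumes every sliced term holds the same number of vars as vmapped[0])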
+    const std::size_t num_vars_sliced_terms = num_terms * num_vars_per_term;
+    const std::size_t num_vars_shared_terms = count_vars(args...);
+
+    vari** varis
+        = ChainableStack::instance_->memalloc_.alloc_array<vari*>(
+            num_vars_sliced_terms + num_vars_shared_terms);
+    double* partials
+        = ChainableStack::instance_->memalloc_.alloc_array<double>(
+            num_vars_sliced_terms + num_vars_shared_terms);
+
+    save_varis(varis, vmapped);
+    save_varis(varis + num_vars_sliced_terms, args...);
+
+    for (std::size_t i = 0; i < num_vars_sliced_terms; ++i) {
+      partials[i] = 0.0;
+    }
+
+    // --- simple std::thread parallelism ---
+
+    // how many threads to use
+    const unsigned hw = std::thread::hardware_concurrency();
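+    // (hardware_concurrency() may return 0 when it cannot be determined;
+    // fall back to 2 threads in that case)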
+    const std::size_t max_threads = hw == 0 ? 2 : hw;
+    const std::size_t num_threads
+        = std::min<std::size_t>(max_threads, num_terms);
+
+    // each thread gets its own reducer, but they all share the same partials
+    // buffer (sliced_partials_) and write to disjoint regions of it
+    std::vector<std::unique_ptr<recursive_reducer>> workers;
+    workers.reserve(num_threads);
+
+    std::vector<std::thread> threads;
+    threads.reserve(num_threads);
+
+    std::size_t block_begin = 0;
+    for (std::size_t t = 0; t < num_threads; ++t) {
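+      // contiguous block [block_begin, block_end) of roughly num_terms /
+      // num_threads terms; the last thread absorbs any remainder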
+      std::size_t block_end = (t + 1 == num_threads)
+                                  ? num_terms
+                                  : (num_terms * (t + 1)) / num_threads;
+
+      // construct reducer for this thread
+      workers.emplace_back(std::make_unique<recursive_reducer>(
+          num_vars_per_term, num_vars_shared_terms, partials,
+          vmapped, args...));
+
+      auto* wptr = workers.back().get();
+
+      threads.emplace_back([wptr, block_begin, block_end]() {
+        // each worker thread needs its own AD tape
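+        // (constructing a ChainableStack here is intended to initialize this
+        // thread's thread_local tape; assumes STAN_THREADS is defined)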
+        static thread_local ChainableStack ad_tape;
+        wptr->operator()(block_begin, block_end);
+      });
+
+      block_begin = block_end;
+    }
+
+    for (auto& th : threads) {
+      th.join();
+    }
+
+    // aggregate results
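+    // (sliced-term adjoints were already written directly into partials by
+    // each worker; only the sum, shared adjoints, and messages remain to merge)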
+    double total_sum = 0.0;
+    Eigen::VectorXd shared_adjoints
+        = Eigen::VectorXd::Zero(num_vars_shared_terms);
+    std::stringstream all_msgs;
+
+    for (auto& w : workers) {
+      total_sum += w->sum_;
+      if (w->args_adjoints_.size() != 0) {
+        shared_adjoints += w->args_adjoints_;
+      }
+      all_msgs << w->msgs_.str();
+    }
+
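+    // copy the accumulated shared-argument adjoints into the tail of the
+    // partials buffer, after the sliced-term adjoints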
+    for (std::size_t i = 0; i < num_vars_shared_terms; ++i) {
+      partials[num_vars_sliced_terms + i] = shared_adjoints.coeff(i);
+    }
+
+    if (msgs) {
+      *msgs << all_msgs.str();
+    }
+
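+    // return a single var on the main stack; its chain() adds
+    // partials[i] (scaled by the result's adjoint) to each varis[i]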
+    return var(new precomputed_gradients_vari(
+        total_sum, num_vars_sliced_terms + num_vars_shared_terms, varis,
+        partials));
+  }
+  /*
  inline var operator()(Vec&& vmapped, bool auto_partitioning, int grainsize,
                        std::ostream* msgs, Args&&... args) const {
    if (vmapped.empty()) {
@@ -278,6 +441,7 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
        worker.sum_, num_vars_sliced_terms + num_vars_shared_terms, varis,
        partials));
  }
+  */
};
}  // namespace internal