Skip to content

Commit 9f72cc3

Browse files
committed
wip: reduce_sum
1 parent 5df6fc5 commit 9f72cc3

3 files changed

Lines changed: 203 additions & 15 deletions

File tree

make/compiler_flags

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ endif
342342
# Sets up CXXFLAGS_THREADS to use threading
343343

344344
ifdef STAN_THREADS
  # -pthread is needed in addition to -DSTAN_THREADS because the
  # std::thread-based backend links against libpthread on POSIX toolchains
  CXXFLAGS_THREADS ?= -DSTAN_THREADS -pthread
endif
347347

348348
################################################################################

stan/math/rev/functor/map_rect_concurrent.hpp

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <tbb/blocked_range.h>
1313

1414
#include <algorithm>
15+
#include <numeric>
16+
#include <thread>
1517
#include <vector>
1618

1719
namespace stan {
@@ -46,18 +48,40 @@ map_rect_concurrent(
4648
};
4749

4850
#ifdef STAN_THREADS
49-
// we must use task isolation as described here:
50-
// https://software.intel.com/content/www/us/en/develop/documentation/tbb-documentation/top/intel-threading-building-blocks-developer-guide/task-isolation.html
51-
// this is to ensure that the thread local AD tape ressource is
52-
// not being modified from a different task which may happen
53-
// whenever this function is being used itself in a parallel
54-
// context (like running multiple chains for Stan)
55-
tbb::this_task_arena::isolate([&] {
56-
tbb::parallel_for(tbb::blocked_range<std::size_t>(0, num_jobs),
57-
[&](const tbb::blocked_range<size_t>& r) {
58-
execute_chunk(r.begin(), r.end());
59-
});
60-
});
51+
std::cout << "********************************************************************************" << std::endl;
52+
if (num_jobs > 1) {
53+
// simple chunked threading over [0, num_jobs)
54+
unsigned hw_threads = std::thread::hardware_concurrency();
55+
if (hw_threads == 0) {
56+
hw_threads = 2; // arbitrary but > 0
57+
}
58+
59+
const unsigned max_threads
60+
= static_cast<unsigned>(std::min<std::size_t>(hw_threads, num_jobs));
61+
std::cout << "max_threads = " << max_threads << std::endl;
62+
std::vector<std::thread> threads;
63+
threads.reserve(max_threads);
64+
65+
const std::size_t chunk
66+
= (num_jobs + max_threads - 1) / max_threads; // ceil
67+
68+
for (unsigned t = 0; t < max_threads; ++t) {
69+
const std::size_t start = t * chunk;
70+
if (start >= num_jobs) break;
71+
const std::size_t end
72+
= std::min<std::size_t>(start + chunk, num_jobs);
73+
74+
threads.emplace_back([&, start, end] {
75+
execute_chunk(start, end);
76+
});
77+
}
78+
79+
for (auto& th : threads) {
80+
th.join();
81+
}
82+
} else {
83+
execute_chunk(0, num_jobs);
84+
}
6185
#else
6286
execute_chunk(0, num_jobs);
6387
#endif

stan/math/rev/functor/reduce_sum.hpp

Lines changed: 166 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <stan/math/prim/functor.hpp>
66
#include <stan/math/rev/core.hpp>
77

8+
#include <thread>
9+
810
#include <tbb/task_arena.h>
911
#include <tbb/parallel_reduce.h>
1012
#include <tbb/blocked_range.h>
@@ -74,13 +76,79 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
7476
* to zero since the newly created reducer is used to accumulate
7577
* an independent partial sum.
7678
*/
79+
/*
7780
recursive_reducer(recursive_reducer& other, tbb::split)
7881
: num_vars_per_term_(other.num_vars_per_term_),
7982
num_vars_shared_terms_(other.num_vars_shared_terms_),
8083
sliced_partials_(other.sliced_partials_),
8184
vmapped_(other.vmapped_),
8285
args_tuple_(other.args_tuple_) {}
86+
*/
87+
88+
    /**
     * Compute, with nested autodiff, the value and gradient of
     * `ReduceFunction` over the half-open term range [begin, end) and
     * accumulate the results into this reducer: value into `sum_`,
     * sliced-argument adjoints into `sliced_partials_`, shared-argument
     * adjoints into `args_adjoints_`.
     *
     * NOTE(review): must run on a thread whose AD tape is distinct from
     * the main tape when called concurrently — confirm the caller
     * installs a per-thread ChainableStack before invoking this.
     *
     * @param begin first term index (inclusive)
     * @param end one-past-last term index (exclusive)
     */
    inline void operator()(std::size_t begin, std::size_t end) {
      if (begin == end) {
        return;
      }

      // Lazily allocate the shared-argument adjoint accumulator on first use.
      if (args_adjoints_.size() == 0) {
        args_adjoints_ = Eigen::VectorXd::Zero(num_vars_shared_terms_);
      }

      // local copy of shared arguments in a local stack
      if (!local_args_tuple_scope_.args_tuple_holder_) {
        // First call on this reducer: deep-copy the shared arguments onto
        // the scoped stack so their varis do not reference the main tape.
        local_args_tuple_scope_.stack_.execute([&]() {
          math::apply(
              [&](auto&&... args) {
                local_args_tuple_scope_.args_tuple_holder_
                    = std::make_unique<typename scoped_args_tuple::args_tuple_t>(
                        deep_copy_vars(args)...);
              },
              args_tuple_);
        });
      } else {
        // set adjoints of shared arguments to zero
        local_args_tuple_scope_.stack_.execute([] { set_zero_all_adjoints(); });
      }

      auto& args_tuple_local = *(local_args_tuple_scope_.args_tuple_holder_);

      // Initialize nested autodiff stack
      const nested_rev_autodiff begin_nest;

      // Create nested autodiff copies of sliced argument that do not point
      // back to main autodiff stack
      std::decay_t<Vec> local_sub_slice;
      local_sub_slice.reserve(end - begin);
      for (std::size_t i = begin; i < end; ++i) {
        local_sub_slice.emplace_back(deep_copy_vars(vmapped_[i]));
      }

      // Perform calculation; ReduceFunction takes an inclusive index range,
      // hence `end - 1`.
      var sub_sum_v = math::apply(
          [&](auto&&... args) {
            return ReduceFunction()(local_sub_slice, begin, end - 1, &msgs_,
                                    args...);
          },
          args_tuple_local);

      // Compute Jacobian
      sub_sum_v.grad();

      // accumulate value
      sum_ += sub_sum_v.val();

      // accumulate adjoints of sliced_arguments
      accumulate_adjoints(sliced_partials_ + begin * num_vars_per_term_,
                          std::move(local_sub_slice));

      // accumulate adjoints of shared_arguments
      math::apply(
          [&](auto&&... args) {
            accumulate_adjoints(args_adjoints_.data(), args...);
          },
          args_tuple_local);
    }
151+
84152
/**
85153
* Compute, using nested autodiff, the value and Jacobian of
86154
* `ReduceFunction` called over the range defined by r and accumulate those
@@ -94,7 +162,7 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
94162
*
95163
* @param r Range over which to compute reduce_sum
96164
*/
97-
inline void operator()(const tbb::blocked_range<size_t>& r) {
165+
/* inline void operator()(const tbb::blocked_range<size_t>& r) {
98166
if (r.empty()) {
99167
return;
100168
}
@@ -163,7 +231,7 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
163231
},
164232
args_tuple_local);
165233
}
166-
234+
*/
167235
/**
168236
 * Join reducers. Accumulate the value (sum_) and Jacobian (args_adjoints_)
169237
* of the other reducer.
@@ -221,6 +289,101 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
221289
* @param args Shared arguments used in every sum term
222290
* @return Summation of all terms
223291
*/
292+
    /**
     * Sum the terms in `vmapped` by splitting [0, num_terms) into contiguous
     * blocks, one per std::thread worker, each driving its own
     * recursive_reducer; all workers share the `partials` buffer but write to
     * disjoint sliced regions. Per-worker sums, shared-argument adjoints, and
     * messages are aggregated on join.
     *
     * NOTE(review): auto_partitioning and grainsize are ignored by this
     * std::thread-based implementation — work is split evenly across up to
     * hardware_concurrency() threads.
     *
     * @param vmapped sliced arguments, one element per term
     * @param msgs stream to which per-worker messages are forwarded
     * @param args shared arguments used in every sum term
     * @return var holding the total sum, with gradients wired to all inputs
     *   via precomputed_gradients_vari
     */
    inline var operator()(Vec&& vmapped, bool /*auto_partitioning*/,
                          int /*grainsize*/, std::ostream* msgs,
                          Args&&... args) const {
      if (vmapped.empty()) {
        return var(0.0);
      }

      const std::size_t num_terms = vmapped.size();
      // assumes every term carries the same var count as vmapped[0]
      const std::size_t num_vars_per_term = count_vars(vmapped[0]);
      const std::size_t num_vars_sliced_terms = num_terms * num_vars_per_term;
      const std::size_t num_vars_shared_terms = count_vars(args...);

      // arena-allocated operand/partial buffers for the final vari
      vari** varis
          = ChainableStack::instance_->memalloc_.alloc_array<vari*>(
              num_vars_sliced_terms + num_vars_shared_terms);
      double* partials
          = ChainableStack::instance_->memalloc_.alloc_array<double>(
              num_vars_sliced_terms + num_vars_shared_terms);

      save_varis(varis, vmapped);
      save_varis(varis + num_vars_sliced_terms, args...);

      // sliced partials are accumulated into directly by the workers,
      // so zero them up front
      for (std::size_t i = 0; i < num_vars_sliced_terms; ++i) {
        partials[i] = 0.0;
      }

      // --- simple std::thread parallelism ---

      // how many threads to use
      const unsigned hw = std::thread::hardware_concurrency();
      const std::size_t max_threads = hw == 0 ? 2 : hw;  // hw may be 0 (unknown)
      const std::size_t num_threads
          = std::min<std::size_t>(max_threads, num_terms);

      // each thread gets its own reducer, but they all share the same partials
      // buffer (sliced_partials_) and write to disjoint regions
      std::vector<std::unique_ptr<recursive_reducer>> workers;
      workers.reserve(num_threads);

      std::vector<std::thread> threads;
      threads.reserve(num_threads);

      std::size_t block_begin = 0;
      for (std::size_t t = 0; t < num_threads; ++t) {
        // even split; the last block absorbs the rounding remainder
        std::size_t block_end
            = (t + 1 == num_threads)
                  ? num_terms
                  : (num_terms * (t + 1)) / num_threads;

        // construct reducer for this thread
        workers.emplace_back(std::make_unique<recursive_reducer>(
            num_vars_per_term, num_vars_shared_terms, partials,
            vmapped, args...));

        // raw pointer stays valid: the reducer is heap-owned by `workers`
        auto* wptr = workers.back().get();

        threads.emplace_back([wptr, block_begin, block_end]() {
          // each worker thread needs its own AD tape
          // NOTE(review): relies on constructing a thread_local
          // ChainableStack to install a per-thread tape — confirm this
          // matches how STAN_THREADS initializes worker tapes
          static thread_local ChainableStack ad_tape;
          wptr->operator()(block_begin, block_end);
        });

        block_begin = block_end;
      }

      for (auto& th : threads) {
        th.join();
      }

      // aggregate results
      double total_sum = 0.0;
      Eigen::VectorXd shared_adjoints
          = Eigen::VectorXd::Zero(num_vars_shared_terms);
      std::stringstream all_msgs;

      for (auto& w : workers) {
        total_sum += w->sum_;
        if (w->args_adjoints_.size() != 0) {
          shared_adjoints += w->args_adjoints_;
        }
        all_msgs << w->msgs_.str();
      }

      // shared-argument partials are the adjoints summed across workers
      for (std::size_t i = 0; i < num_vars_shared_terms; ++i) {
        partials[num_vars_sliced_terms + i] = shared_adjoints.coeff(i);
      }

      if (msgs) {
        *msgs << all_msgs.str();
      }

      return var(
          new precomputed_gradients_vari(total_sum,
                                         num_vars_sliced_terms + num_vars_shared_terms,
                                         varis, partials));
    }
386+
/*
224387
inline var operator()(Vec&& vmapped, bool auto_partitioning, int grainsize,
225388
std::ostream* msgs, Args&&... args) const {
226389
if (vmapped.empty()) {
@@ -278,6 +441,7 @@ struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
278441
worker.sum_, num_vars_sliced_terms + num_vars_shared_terms, varis,
279442
partials));
280443
}
444+
*/
281445
};
282446
} // namespace internal
283447

0 commit comments

Comments
 (0)