|
6 | 6 | #include <stan/math/fwd/meta.hpp> |
7 | 7 | #include <stan/math/fwd/fun/sum.hpp> |
8 | 8 | #include <stan/math/prim/fun/accumulator.hpp> |
| 9 | + |
9 | 10 | #include <vector> |
10 | 11 | #include <type_traits> |
11 | 12 |
|
| 13 | +namespace stan { |
| 14 | +namespace math { |
| 15 | +template <typename T, typename> |
| 16 | +class accumulator; |
| 17 | +/** |
| 18 | + * Class to accumulate values and eventually return their sum. If |
| 19 | + * no values are ever added, the return value is 0. |
| 20 | + * |
| 21 | + * This class is useful for speeding up autodiff of long sums |
| 22 | + * because it uses the <code>sum()</code> operation (either from |
| 23 | + * <code>stan::math</code> or one defined by argument-dependent lookup. |
| 24 | + * |
| 25 | + * @tparam T Type of scalar added |
| 26 | + */ |
| 27 | +template <typename T> |
| 28 | +class accumulator<T, require_fvar_t<T>> { |
| 29 | + private: |
| 30 | + std::vector<T> buf_; |
| 31 | + |
| 32 | + public: |
| 33 | + /** |
| 34 | + * Add the specified arithmetic type value to the buffer after |
| 35 | + * static casting it to the class type <code>T</code>. |
| 36 | + * |
| 37 | + * <p>See the std library doc for <code>std::is_arithmetic</code> |
| 38 | + * for information on what counts as an arithmetic type. |
| 39 | + * |
| 40 | + * @tparam S Type of argument |
| 41 | + * @param x Value to add |
| 42 | + */ |
| 43 | + template <typename S, typename = require_stan_scalar_t<S>> |
| 44 | + inline void add(S x) { |
| 45 | + buf_.push_back(x); |
| 46 | + } |
| 47 | + |
| 48 | + /** |
| 49 | + * Add each entry in the specified matrix, vector, or row vector |
| 50 | + * of values to the buffer. |
| 51 | + * |
| 52 | + * @tparam S type of the matrix |
| 53 | + * @param m Matrix of values to add |
| 54 | + */ |
| 55 | + template <typename S, require_matrix_t<S>* = nullptr> |
| 56 | + inline void add(const S& m) { |
| 57 | + buf_.push_back(stan::math::sum(m)); |
| 58 | + } |
| 59 | + |
| 60 | + /** |
| 61 | + * Recursively add each entry in the specified standard vector |
| 62 | + * to the buffer. This will allow vectors of primitives, |
| 63 | + * autodiff variables to be added; if the vector entries |
| 64 | + * are collections, their elements are recursively added. |
| 65 | + * |
| 66 | + * @tparam S Type of value to recursively add. |
| 67 | + * @param xs Vector of entries to add |
| 68 | + */ |
| 69 | + template <typename S> |
| 70 | + inline void add(const std::vector<S>& xs) { |
| 71 | + for (size_t i = 0; i < xs.size(); ++i) { |
| 72 | + this->add(xs[i]); |
| 73 | + } |
| 74 | + } |
| 75 | + |
| 76 | +#ifdef STAN_OPENCL |
| 77 | + |
| 78 | + /** |
| 79 | + * Sum each entry and then push to the buffer. |
| 80 | + * @tparam S A Type inheriting from `matrix_cl_base` |
| 81 | + * @param xs An OpenCL matrix |
| 82 | + */ |
| 83 | + template <typename S, |
| 84 | + require_all_kernel_expressions_and_none_scalar_t<S>* = nullptr> |
| 85 | + inline void add(const S& xs) { |
| 86 | + buf_.push_back(stan::math::sum(xs)); |
| 87 | + } |
| 88 | + |
| 89 | +#endif |
| 90 | + |
| 91 | + /** |
| 92 | + * Return the sum of the accumulated values. |
| 93 | + * |
| 94 | + * @return Sum of accumulated values. |
| 95 | + */ |
| 96 | + inline T sum() const { return stan::math::sum(buf_); } |
| 97 | +}; |
| 98 | + |
| 99 | +} // namespace math |
| 100 | +} // namespace stan |
| 101 | + |
12 | 102 | #endif |
0 commit comments