Skip to content

Commit d328e0b

Browse files
committed
add accumulator for sum
1 parent b31f4ee commit d328e0b

3 files changed

Lines changed: 100 additions & 8 deletions

File tree

stan/math/fwd/fun/accumulator.hpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,98 @@
66
#include <stan/math/fwd/meta.hpp>
77
#include <stan/math/fwd/fun/sum.hpp>
88
#include <stan/math/prim/fun/accumulator.hpp>
9+
910
#include <vector>
1011
#include <type_traits>
1112

13+
namespace stan {
14+
namespace math {
15+
template <typename T, typename>
16+
class accumulator;
17+
/**
18+
* Class to accumulate values and eventually return their sum. If
19+
* no values are ever added, the return value is 0.
20+
*
21+
* This class is useful for speeding up autodiff of long sums
22+
* because it uses the <code>sum()</code> operation (either from
23+
* <code>stan::math</code> or one defined by argument-dependent lookup.
24+
*
25+
* @tparam T Type of scalar added
26+
*/
27+
template <typename T>
28+
class accumulator<T, require_fvar_t<T>> {
29+
private:
30+
std::vector<T> buf_;
31+
32+
public:
33+
/**
34+
* Add the specified arithmetic type value to the buffer after
35+
* static casting it to the class type <code>T</code>.
36+
*
37+
* <p>See the std library doc for <code>std::is_arithmetic</code>
38+
* for information on what counts as an arithmetic type.
39+
*
40+
* @tparam S Type of argument
41+
* @param x Value to add
42+
*/
43+
template <typename S, typename = require_stan_scalar_t<S>>
44+
inline void add(S x) {
45+
buf_.push_back(x);
46+
}
47+
48+
/**
49+
* Add each entry in the specified matrix, vector, or row vector
50+
* of values to the buffer.
51+
*
52+
* @tparam S type of the matrix
53+
* @param m Matrix of values to add
54+
*/
55+
template <typename S, require_matrix_t<S>* = nullptr>
56+
inline void add(const S& m) {
57+
buf_.push_back(stan::math::sum(m));
58+
}
59+
60+
/**
61+
* Recursively add each entry in the specified standard vector
62+
* to the buffer. This will allow vectors of primitives,
63+
* autodiff variables to be added; if the vector entries
64+
* are collections, their elements are recursively added.
65+
*
66+
* @tparam S Type of value to recursively add.
67+
* @param xs Vector of entries to add
68+
*/
69+
template <typename S>
70+
inline void add(const std::vector<S>& xs) {
71+
for (size_t i = 0; i < xs.size(); ++i) {
72+
this->add(xs[i]);
73+
}
74+
}
75+
76+
#ifdef STAN_OPENCL
77+
78+
/**
79+
* Sum each entry and then push to the buffer.
80+
* @tparam S A Type inheriting from `matrix_cl_base`
81+
* @param xs An OpenCL matrix
82+
*/
83+
template <typename S,
84+
require_all_kernel_expressions_and_none_scalar_t<S>* = nullptr>
85+
inline void add(const S& xs) {
86+
buf_.push_back(stan::math::sum(xs));
87+
}
88+
89+
#endif
90+
91+
/**
92+
* Return the sum of the accumulated values.
93+
*
94+
* @return Sum of accumulated values.
95+
*/
96+
inline T sum() const { return stan::math::sum(buf_); }
97+
};
98+
99+
} // namespace math
100+
} // namespace stan
101+
102+
12103
#endif

stan/math/fwd/fun/sum.hpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
#ifndef STAN_MATH_FWD_FUN_SUM_HPP
22
#define STAN_MATH_FWD_FUN_SUM_HPP
33

4+
#include <stan/math/fwd/core.hpp>
5+
#include <stan/math/fwd/meta.hpp>
46
#include <stan/math/prim/meta.hpp>
57
#include <stan/math/prim/fun/Eigen.hpp>
68
#include <stan/math/prim/fun/sum.hpp>
7-
#include <stan/math/fwd/core.hpp>
89
#include <vector>
910

1011
namespace stan {
@@ -18,18 +19,18 @@ namespace math {
1819
* @param m Vector.
1920
* @return Sum of vector entries.
2021
*/
21-
template <typename T>
22-
inline fvar<T> sum(const std::vector<fvar<T>>& m) {
22+
template <typename T, require_fvar_t<T>* = nullptr>
23+
inline auto sum(const std::vector<T>& m) {
2324
if (m.size() == 0) {
24-
return 0.0;
25+
return T(0.0);
2526
}
26-
std::vector<T> vals(m.size());
27-
std::vector<T> tans(m.size());
27+
std::vector<partials_type_t<T>> vals(m.size());
28+
std::vector<partials_type_t<T>> tans(m.size());
2829
for (size_t i = 0; i < m.size(); ++i) {
2930
vals[i] = m[i].val();
3031
tans[i] = m[i].d();
3132
}
32-
return fvar<T>(sum(vals), sum(tans));
33+
return T(sum(vals), sum(tans));
3334
}
3435

3536
/**

stan/math/prim/fun/sum.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ inline T sum(T&& m) {
2929
* @param m Standard vector to sum.
3030
* @return Sum of elements.
3131
*/
32-
template <typename T, require_not_var_t<T>* = nullptr>
32+
template <typename T, require_not_autodiff_t<T>* = nullptr>
3333
inline T sum(const std::vector<T>& m) {
3434
return std::accumulate(m.begin(), m.end(), T{0});
3535
}

0 commit comments

Comments
 (0)