Skip to content

Commit e132643

Browse files
committed
Basic rev impl
1 parent 004c7fe commit e132643

3 files changed

Lines changed: 42 additions & 16 deletions

File tree

stan/math/prim/constraint/sum_to_zero_constrain.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,14 @@ inline plain_type_t<Vec> sum_to_zero_constrain(const Vec& y) {
3333

3434
auto&& y_ref = to_ref(y);
3535

36-
typename plain_type_t<Vec>::Scalar total(0);
37-
for (int i = N - 1; i >= 0; --i) {
38-
double n = i + 1;
39-
auto w = y_ref(i) * inv_sqrt(n * (n + 1));
40-
total += w;
36+
typename plain_type_t<Vec>::Scalar sum_w(0);
37+
for (int i = N; i > 0; --i) {
38+
double n = i;
39+
auto w = y_ref(i-1) * inv_sqrt(n * (n + 1));
40+
sum_w += w;
4141

42-
z.coeffRef(i) += total;
43-
z.coeffRef(i + 1) -= w * n;
42+
z.coeffRef(i-1) += sum_w;
43+
z.coeffRef(i) -= w * n;
4444
}
4545

4646
return z;

stan/math/prim/constraint/sum_to_zero_free.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,14 @@ inline plain_type_t<Vec> sum_to_zero_free(const Vec& z) {
3737
}
3838

3939
y.coeffRef(N - 1) = -z_ref(N) * sqrt(N * (N + 1)) / N;
40-
typename plain_type_t<Vec>::Scalar total(0);
40+
41+
typename plain_type_t<Vec>::Scalar sum_w(0);
4142

4243
for (int i = N - 2; i >= 0; --i) {
4344
double n = i + 1;
4445
auto w = y(i + 1) / sqrt((n + 1) * (n + 2));
45-
total += w;
46-
y.coeffRef(i) = (total - z_ref(i + 1)) * sqrt(n * (n + 1)) / n;
46+
sum_w += w;
47+
y.coeffRef(i) = (sum_w - z_ref(i + 1)) * sqrt(n * (n + 1)) / n;
4748
}
4849

4950
return y;

stan/math/rev/constraint/sum_to_zero_constrain.hpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <stan/math/rev/fun/value_of.hpp>
88
#include <stan/math/prim/fun/Eigen.hpp>
99
#include <stan/math/prim/constraint/sum_to_zero_constrain.hpp>
10+
#include <stan/math/prim/fun/linspaced_vector.hpp>
11+
#include <stan/math/prim/fun/cumulative_sum.hpp>
1012
#include <cmath>
1113
#include <tuple>
1214
#include <vector>
@@ -29,17 +31,40 @@ namespace math {
2931
template <typename T, require_rev_col_vector_t<T>* = nullptr>
3032
inline auto sum_to_zero_constrain(const T& y) {
3133
using ret_type = plain_type_t<T>;
32-
const auto N = y.size();
33-
if (unlikely(N == 0)) {
34+
if (unlikely(y.size() == 0)) {
3435
return arena_t<ret_type>(Eigen::VectorXd{{0}});
3536
}
3637
auto arena_y = to_arena(y);
37-
arena_t<ret_type> arena_x = sum_to_zero_constrain(arena_y.val());
38+
arena_t<ret_type> arena_z = sum_to_zero_constrain(arena_y.val());
3839

39-
reverse_pass_callback([arena_y, arena_x]() mutable {
40-
// TODO
40+
reverse_pass_callback([arena_y, arena_z]() mutable {
41+
const auto N = arena_y.size();
42+
43+
Eigen::VectorXd ns = linspaced_vector(N, 1, N );
44+
45+
/*
46+
u3.adj += z.adj[1:N-1]
47+
v.adj -= z.adj[2:N]
48+
w.adj += v.adj .* ns
49+
u2.adj += reverse(u3.adj)
50+
u1.adj += cumulative_sum(u2.adj)
51+
w.adj += reverse(u1.adj)
52+
y.adj += w.adj ./ sqrt(ns .* (ns + 1))
53+
*/
54+
55+
Eigen::VectorXd u_adj = arena_z.adj_op().head(N).eval();
56+
Eigen::VectorXd v_adj = -arena_z.adj_op().tail(N).eval();
57+
58+
Eigen::VectorXd w_from_v = v_adj.array() * ns.array();
59+
60+
Eigen::VectorXd w_from_u = cumulative_sum(u_adj);
61+
62+
Eigen::VectorXd w_adj = (w_from_v.array() + w_from_u.array());
63+
64+
arena_y.adj() += (w_adj.array() / sqrt(ns.array() * (ns.array() + 1))).matrix();
4165
});
42-
return arena_x;
66+
67+
return arena_z;
4368
}
4469

4570
/**

0 commit comments

Comments
 (0)