Skip to content

Commit 9cf06f1

Browse files
committed
Loop version
1 parent e132643 commit 9cf06f1

1 file changed

Lines changed: 9 additions & 18 deletions

File tree

stan/math/rev/constraint/sum_to_zero_constrain.hpp

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,28 +40,19 @@ inline auto sum_to_zero_constrain(const T& y) {
4040
reverse_pass_callback([arena_y, arena_z]() mutable {
4141
const auto N = arena_y.size();
4242

43-
Eigen::VectorXd ns = linspaced_vector(N, 1, N );
43+
double sum_u_adj = 0;
44+
for (int i = 0; i < N; ++i) {
45+
double n = i + 1;
4446

45-
/*
46-
u3.adj += z.adj[1:N-1]
47-
v.adj -= z.adj[2:N]
48-
w.adj += v.adj .* ns
49-
u2.adj += reverse(u3.adj)
50-
u1.adj += cumulative_sum(u2.adj)
51-
w.adj += reverse(u1.adj)
52-
y.adj += w.adj ./ sqrt(ns .* (ns + 1))
53-
*/
47+
double u_adj = arena_z.adj()(i);
48+
sum_u_adj += u_adj;
5449

55-
Eigen::VectorXd u_adj = arena_z.adj_op().head(N).eval();
56-
Eigen::VectorXd v_adj = -arena_z.adj_op().tail(N).eval();
50+
double v_adj = -arena_z.adj()(i + 1);
5751

58-
Eigen::VectorXd w_from_v = v_adj.array() * ns.array();
52+
double w = (v_adj * n) + sum_u_adj;
5953

60-
Eigen::VectorXd w_from_u = cumulative_sum(u_adj);
61-
62-
Eigen::VectorXd w_adj = (w_from_v.array() + w_from_u.array());
63-
64-
arena_y.adj() += (w_adj.array() / sqrt(ns.array() * (ns.array() + 1))).matrix();
54+
arena_y.adj()(i) += w / sqrt(n * (n + 1));
55+
}
6556
});
6657

6758
return arena_z;

0 commit comments

Comments
 (0)