77#include < stan/math/rev/fun/value_of.hpp>
88#include < stan/math/prim/fun/Eigen.hpp>
99#include < stan/math/prim/constraint/sum_to_zero_constrain.hpp>
10+ #include < stan/math/prim/fun/linspaced_vector.hpp>
11+ #include < stan/math/prim/fun/cumulative_sum.hpp>
1012#include < cmath>
1113#include < tuple>
1214#include < vector>
@@ -29,17 +31,40 @@ namespace math {
2931template <typename T, require_rev_col_vector_t <T>* = nullptr >
3032inline auto sum_to_zero_constrain (const T& y) {
3133 using ret_type = plain_type_t <T>;
32- const auto N = y.size ();
33- if (unlikely (N == 0 )) {
34+ if (unlikely (y.size () == 0 )) {
3435 return arena_t <ret_type>(Eigen::VectorXd{{0 }});
3536 }
3637 auto arena_y = to_arena (y);
37- arena_t <ret_type> arena_x = sum_to_zero_constrain (arena_y.val ());
38+ arena_t <ret_type> arena_z = sum_to_zero_constrain (arena_y.val ());
3839
39- reverse_pass_callback ([arena_y, arena_x]() mutable {
40- // TODO
40+ reverse_pass_callback ([arena_y, arena_z]() mutable {
41+ const auto N = arena_y.size ();
42+
43+ Eigen::VectorXd ns = linspaced_vector (N, 1 , N );
44+
45+ /*
46+ u3.adj += z.adj[1:N-1]
47+ v.adj -= z.adj[2:N]
48+ w.adj += v.adj .* ns
49+ u2.adj += reverse(u3.adj)
50+ u1.adj += cumulative_sum(u2.adj)
51+ w.adj += reverse(u1.adj)
52+ y.adj += w.adj ./ sqrt(ns .* (ns + 1))
53+ */
54+
55+ Eigen::VectorXd u_adj = arena_z.adj_op ().head (N).eval ();
56+ Eigen::VectorXd v_adj = -arena_z.adj_op ().tail (N).eval ();
57+
58+ Eigen::VectorXd w_from_v = v_adj.array () * ns.array ();
59+
60+ Eigen::VectorXd w_from_u = cumulative_sum (u_adj);
61+
62+ Eigen::VectorXd w_adj = (w_from_v.array () + w_from_u.array ());
63+
64+ arena_y.adj () += (w_adj.array () / sqrt (ns.array () * (ns.array () + 1 ))).matrix ();
4165 });
42- return arena_x;
66+
67+ return arena_z;
4368}
4469
4570/* *
0 commit comments