|
5 | 5 | #include <stan/math/rev/core/reverse_pass_callback.hpp> |
6 | 6 | #include <stan/math/rev/core/arena_matrix.hpp> |
7 | 7 | #include <stan/math/prim/fun/Eigen.hpp> |
| 8 | +#include <stan/math/prim/fun/inv_sqrt.hpp> |
8 | 9 | #include <stan/math/prim/fun/sqrt.hpp> |
9 | 10 | #include <stan/math/prim/constraint/sum_to_zero_constrain.hpp> |
10 | 11 | #include <cmath> |
@@ -55,14 +56,14 @@ inline auto sum_to_zero_constrain(T&& y) { |
55 | 56 | double n = static_cast<double>(i + 1); |
56 | 57 |
|
57 | 58 | // adjoint of the reverse cumulative sum computed in the forward mode |
58 | | - sum_u_adj += arena_z.adj()(i); |
| 59 | + sum_u_adj += arena_z.adj().coeff(i); |
59 | 60 |
|
60 | 61 | // adjoint of the offset subtraction |
61 | | - double v_adj = -arena_z.adj()(i + 1) * n; |
| 62 | + double v_adj = -arena_z.adj().coeff(i + 1) * n; |
62 | 63 |
|
63 | 64 | double w_adj = v_adj + sum_u_adj; |
64 | 65 |
|
65 | | - arena_y.adj()(i) += w_adj / sqrt(n * (n + 1)); |
| 66 | + arena_y.adj().coeffRef(i) += w_adj / sqrt(n * (n + 1)); |
66 | 67 | } |
67 | 68 | }); |
68 | 69 |
|
@@ -90,8 +91,35 @@ inline auto sum_to_zero_constrain(T&& x) { |
90 | 91 | arena_t<ret_type> arena_z = sum_to_zero_constrain(arena_x.val()); |
91 | 92 |
|
92 | 93 | reverse_pass_callback([arena_x, arena_z]() mutable { |
93 | | - const auto N = arena_x.rows(); |
94 | | - const auto M = arena_x.cols(); |
| 94 | + const auto Nf = arena_x.val().rows(); |
| 95 | + const auto Mf = arena_x.val().cols(); |
| 96 | + const auto N = Nf + 1; |
| 97 | + const auto M = Mf + 1; |
| 98 | + const auto s = std::max(Nf, Mf); |
| 99 | + |
| 100 | + Eigen::VectorXd d_beta = Eigen::VectorXd::Zero(Nf); |
| 101 | + |
| 102 | + for (int j = 0; j < Mf; ++j) { |
| 103 | + double a_j = inv_sqrt((j + 1.0) * (j + 2.0)); |
| 104 | + double b_j = (j + 1.0) * a_j; |
| 105 | + |
| 106 | + double d_ax = 0.0; |
| 107 | + |
| 108 | + for (int i = 0; i < Nf; ++i) { |
| 109 | + double a_i = inv_sqrt((i + 1.0) * (i + 2.0)); |
| 110 | + double b_i = (i + 1.0) * a_i; |
| 111 | + |
| 112 | + double dY = arena_z.adj().coeff(i, j) - arena_z.adj().coeff(Nf, j) |
| 113 | + + arena_z.adj().coeff(Nf, Mf) - arena_z.adj().coeff(i, Mf); |
| 114 | + double dI_from_beta = a_j * d_beta.coeff(i); |
| 115 | + d_beta.coeffRef(i) += -dY; |
| 116 | + |
| 117 | + double dI_from_alpha = b_j * dY; |
| 118 | + double dI = dI_from_alpha + dI_from_beta; |
| 119 | + arena_x.adj().coeffRef(i, j) += b_i * dI + a_i * d_ax; |
| 120 | + d_ax -= dI; |
| 121 | + } |
| 122 | + } |
95 | 123 | }); |
96 | 124 |
|
97 | 125 | return arena_z; |
|
0 commit comments