Skip to content

Commit b494757

Browse files
committed
Reverse mode from @spinkney
1 parent 859f2c8 commit b494757

3 files changed

Lines changed: 48 additions & 15 deletions

File tree

stan/math/prim/constraint/sum_to_zero_constrain.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ inline plain_type_t<Vec> sum_to_zero_constrain(const Vec& y) {
5050
value_type_t<Vec> sum_w(0);
5151
for (int i = N; i > 0; --i) {
5252
double n = static_cast<double>(i);
53-
auto w = y_ref(i - 1) * inv_sqrt(n * (n + 1));
53+
auto w = y_ref.coeff(i - 1) * inv_sqrt(n * (n + 1));
5454
sum_w += w;
5555

5656
z.coeffRef(i - 1) += sum_w;
@@ -88,24 +88,24 @@ inline plain_type_t<Mat> sum_to_zero_constrain(const Mat& x) {
8888
for (int j = M - 1; j >= 0; --j) {
8989
value_type_t<Mat> ax_previous(0);
9090

91-
double a_j = 1.0 / std::sqrt((j + 1.0) * (j + 2.0));
91+
double a_j = inv_sqrt((j + 1.0) * (j + 2.0));
9292
double b_j = (j + 1.0) * a_j;
9393

9494
for (int i = N - 1; i >= 0; --i) {
95-
double a_i = 1.0 / std::sqrt((i + 1.0) * (i + 2.0));
95+
double a_i = inv_sqrt((i + 1.0) * (i + 2.0));
9696
double b_i = (i + 1.0) * a_i;
9797

98-
auto b_i_x = b_i * x_ref(i, j) - ax_previous;
98+
auto b_i_x = b_i * x_ref.coeff(i, j) - ax_previous;
9999

100-
Z(i, j) = (b_j * b_i_x) - beta(i);
101-
beta(i) += a_j * b_i_x;
100+
Z.coeffRef(i, j) = (b_j * b_i_x) - beta.coeff(i);
101+
beta.coeffRef(i) += a_j * b_i_x;
102102

103-
Z(N, j) -= Z(i, j);
104-
Z(i, M) -= Z(i, j);
103+
Z.coeffRef(N, j) -= Z.coeff(i, j);
104+
Z.coeffRef(i, M) -= Z.coeff(i, j);
105105

106-
ax_previous += a_i * x_ref(i, j);
106+
ax_previous += a_i * x_ref.coeff(i, j);
107107
}
108-
Z(N, M) -= Z(N, j);
108+
Z.coeffRef(N, M) -= Z.coeff(N, j);
109109
}
110110

111111
return Z;

stan/math/rev/constraint/sum_to_zero_constrain.hpp

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <stan/math/rev/core/reverse_pass_callback.hpp>
66
#include <stan/math/rev/core/arena_matrix.hpp>
77
#include <stan/math/prim/fun/Eigen.hpp>
8+
#include <stan/math/prim/fun/inv_sqrt.hpp>
89
#include <stan/math/prim/fun/sqrt.hpp>
910
#include <stan/math/prim/constraint/sum_to_zero_constrain.hpp>
1011
#include <cmath>
@@ -55,14 +56,14 @@ inline auto sum_to_zero_constrain(T&& y) {
5556
double n = static_cast<double>(i + 1);
5657

5758
// adjoint of the reverse cumulative sum computed in the forward mode
58-
sum_u_adj += arena_z.adj()(i);
59+
sum_u_adj += arena_z.adj().coeff(i);
5960

6061
// adjoint of the offset subtraction
61-
double v_adj = -arena_z.adj()(i + 1) * n;
62+
double v_adj = -arena_z.adj().coeff(i + 1) * n;
6263

6364
double w_adj = v_adj + sum_u_adj;
6465

65-
arena_y.adj()(i) += w_adj / sqrt(n * (n + 1));
66+
arena_y.adj().coeffRef(i) += w_adj / sqrt(n * (n + 1));
6667
}
6768
});
6869

@@ -90,8 +91,35 @@ inline auto sum_to_zero_constrain(T&& x) {
9091
arena_t<ret_type> arena_z = sum_to_zero_constrain(arena_x.val());
9192

9293
reverse_pass_callback([arena_x, arena_z]() mutable {
93-
const auto N = arena_x.rows();
94-
const auto M = arena_x.cols();
94+
const auto Nf = arena_x.val().rows();
95+
const auto Mf = arena_x.val().cols();
96+
const auto N = Nf + 1;
97+
const auto M = Mf + 1;
98+
const auto s = std::max(Nf, Mf);
99+
100+
Eigen::VectorXd d_beta = Eigen::VectorXd::Zero(Nf);
101+
102+
for (int j = 0; j < Mf; ++j) {
103+
double a_j = inv_sqrt((j + 1.0) * (j + 2.0));
104+
double b_j = (j + 1.0) * a_j;
105+
106+
double d_ax = 0.0;
107+
108+
for (int i = 0; i < Nf; ++i) {
109+
double a_i = inv_sqrt((i + 1.0) * (i + 2.0));
110+
double b_i = (i + 1.0) * a_i;
111+
112+
double dY = arena_z.adj().coeff(i, j) - arena_z.adj().coeff(Nf, j)
113+
+ arena_z.adj().coeff(Nf, Mf) - arena_z.adj().coeff(i, Mf);
114+
double dI_from_beta = a_j * d_beta.coeff(i);
115+
d_beta.coeffRef(i) += -dY;
116+
117+
double dI_from_alpha = b_j * dY;
118+
double dI = dI_from_alpha + dI_from_beta;
119+
arena_x.adj().coeffRef(i, j) += b_i * dI + a_i * d_ax;
120+
d_ax -= dI;
121+
}
122+
}
95123
});
96124

97125
return arena_z;

test/unit/math/mix/constraint/sum_to_zero_constrain_test.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ TEST(MathMixMatFun, sum_to_zeroTransform) {
5555
Eigen::VectorXd v5(5);
5656
v5 << 1, -3, 2, 0, -1;
5757
sum_to_zero_constrain_test::expect_sum_to_zero_transform(v5);
58+
}
5859

60+
TEST(MathMixMatFun, sum_to_zero_matrixTransform) {
5961
Eigen::MatrixXd m0_0(0, 0);
6062
sum_to_zero_constrain_test::expect_sum_to_zero_transform(m0_0);
6163

@@ -71,4 +73,7 @@ TEST(MathMixMatFun, sum_to_zeroTransform) {
7173
m3_4 << 1, 2, -3, 4, 5, 6, -7, 8, 9, -10, 11, -12;
7274

7375
sum_to_zero_constrain_test::expect_sum_to_zero_transform(m3_4);
76+
77+
Eigen::MatrixXd m4_3 = m3_4.transpose();
78+
sum_to_zero_constrain_test::expect_sum_to_zero_transform(m4_3);
7479
}

0 commit comments

Comments
 (0)