Skip to content

Commit d49a254

Browse files
committed
Clean up sum_to_zero_free
1 parent c00cd6e commit d49a254

1 file changed

Lines changed: 10 additions & 9 deletions

File tree

stan/math/prim/constraint/sum_to_zero_free.hpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <stan/math/prim/err.hpp>
66
#include <stan/math/prim/fun/Eigen.hpp>
77
#include <stan/math/prim/fun/to_ref.hpp>
8+
#include <stan/math/prim/fun/inv_sqrt.hpp>
89
#include <stan/math/prim/fun/sqrt.hpp>
910
#include <stan/math/prim/functor/apply_vector_unary.hpp>
1011
#include <cmath>
@@ -47,15 +48,15 @@ inline plain_type_t<Vec> sum_to_zero_free(const Vec& z) {
4748
return y;
4849
}
4950

50-
y.coeffRef(N - 1) = -z_ref(N) * sqrt(N * (N + 1)) / N;
51+
y.coeffRef(N - 1) = -z_ref.coeff(N) * sqrt(N * (N + 1)) / N;
5152

5253
value_type_t<Vec> sum_w(0);
5354

5455
for (int i = N - 2; i >= 0; --i) {
5556
double n = static_cast<double>(i + 1);
56-
auto w = y(i + 1) / sqrt((n + 1) * (n + 2));
57+
auto w = y.coeff(i + 1) / sqrt((n + 1) * (n + 2));
5758
sum_w += w;
58-
y.coeffRef(i) = (sum_w - z_ref(i + 1)) * sqrt(n * (n + 1)) / n;
59+
y.coeffRef(i) = (sum_w - z_ref.coeff(i + 1)) * sqrt(n * (n + 1)) / n;
5960
}
6061

6162
return y;
@@ -88,18 +89,18 @@ inline plain_type_t<Mat> sum_to_zero_free(const Mat& z) {
8889
for (int j = M - 1; j >= 0; --j) {
8990
value_type_t<Mat> ax_previous(0);
9091

91-
double a_j = 1.0 / std::sqrt((j + 1.0) * (j + 2.0));
92+
double a_j = inv_sqrt((j + 1.0) * (j + 2.0));
9293
double b_j = (j + 1.0) * a_j;
9394

9495
for (int i = N - 1; i >= 0; --i) {
95-
double a_i = 1.0 / std::sqrt((i + 1.0) * (i + 2.0));
96+
double a_i = inv_sqrt((i + 1.0) * (i + 2.0));
9697
double b_i = (i + 1.0) * a_i;
9798

98-
auto alpha_plus_beta = z_ref(i, j) + beta(i);
99+
auto alpha_plus_beta = z_ref.coeff(i, j) + beta.coeff(i);
99100

100-
x(i, j) = (alpha_plus_beta + b_j * ax_previous) / (b_j * b_i);
101-
beta(i) += a_j * (b_i * x(i, j) - ax_previous);
102-
ax_previous += a_i * x(i, j);
101+
x.coeffRef(i, j) = (alpha_plus_beta + b_j * ax_previous) / (b_j * b_i);
102+
beta.coeffRef(i) += a_j * (b_i * x.coeff(i, j) - ax_previous);
103+
ax_previous += a_i * x.coeff(i, j);
103104
}
104105
}
105106

0 commit comments

Comments
 (0)