Skip to content

Commit 004c7fe

Browse files
committed
Free impl
1 parent 4197bfd commit 004c7fe

3 files changed

Lines changed: 29 additions & 9 deletions

File tree

stan/math/prim/constraint/sum_to_zero_free.hpp

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,30 +19,47 @@ namespace math {
1919
* sum of those elements.
2020
*
2121
* @tparam ColVec a column vector type
22-
* @param x Vector of length K.
22+
* @param z Vector of length K.
2323
* @return Free vector of length (K-1).
24-
* @throw std::domain_error if x does not sum to zero
24+
* @throw std::domain_error if z does not sum to zero
2525
*/
2626
template <typename Vec, require_eigen_vector_t<Vec>* = nullptr>
27-
inline plain_type_t<Vec> sum_to_zero_free(const Vec& x) {
28-
const auto& x_ref = to_ref(x);
27+
inline plain_type_t<Vec> sum_to_zero_free(const Vec& z) {
28+
const auto& z_ref = to_ref(z);
2929
check_sum_to_zero("stan::math::sum_to_zero_free", "sum_to_zero variable",
30-
x_ref);
30+
z_ref);
3131

32-
return x_ref.head(x_ref.size() - 1);
32+
const auto N = z.size() - 1;
33+
34+
plain_type_t<Vec> y = Eigen::VectorXd::Zero(N);
35+
if (unlikely(N == 0)) {
36+
return y;
37+
}
38+
39+
y.coeffRef(N - 1) = -z_ref(N) * sqrt(N * (N + 1)) / N;
40+
typename plain_type_t<Vec>::Scalar total(0);
41+
42+
for (int i = N - 2; i >= 0; --i) {
43+
double n = i + 1;
44+
auto w = y(i + 1) / sqrt((n + 1) * (n + 2));
45+
total += w;
46+
y.coeffRef(i) = (total - z_ref(i + 1)) * sqrt(n * (n + 1)) / n;
47+
}
48+
49+
return y;
3350
}
3451

3552
/**
3653
* Overload of `sum_to_zero_free()` to untransform each Eigen vector
3754
* in a standard vector.
3855
* @tparam T A standard vector with with a `value_type` which inherits from
3956
* `Eigen::MatrixBase` with compile time rows or columns equal to 1.
40-
* @param x The standard vector to untransform.
57+
* @param z The standard vector to untransform.
4158
*/
4259
template <typename T, require_std_vector_t<T>* = nullptr>
43-
auto sum_to_zero_free(const T& x) {
60+
auto sum_to_zero_free(const T& z) {
4461
return apply_vector_unary<T>::apply(
45-
x, [](auto&& v) { return sum_to_zero_free(v); });
62+
z, [](auto&& v) { return sum_to_zero_free(v); });
4663
}
4764

4865
} // namespace math

stan/math/rev/constraint/sum_to_zero_constrain.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <stan/math/rev/core/arena_matrix.hpp>
77
#include <stan/math/rev/fun/value_of.hpp>
88
#include <stan/math/prim/fun/Eigen.hpp>
9+
#include <stan/math/prim/constraint/sum_to_zero_constrain.hpp>
910
#include <cmath>
1011
#include <tuple>
1112
#include <vector>

test/unit/math/prim/constraint/sum_to_zero_transform_test.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ TEST(prob_transform, sum_to_zero_rt0) {
1111
std::vector<Matrix<double, Dynamic, 1>> x_vec{x, x, x};
1212
std::vector<Matrix<double, Dynamic, 1>> y_vec
1313
= stan::math::sum_to_zero_constrain<false>(x_vec, lp);
14+
EXPECT_NO_THROW(stan::math::check_sum_to_zero("checkSumToZero", "y", y_vec));
15+
1416
for (auto&& y_i : y_vec) {
1517
EXPECT_MATRIX_FLOAT_EQ(Eigen::VectorXd::Zero(5), y_i);
1618
}

0 commit comments

Comments
 (0)