@@ -19,30 +19,47 @@ namespace math {
1919 * sum of those elements.
2020 *
2121 * @tparam ColVec a column vector type
22- * @param x Vector of length K.
22+ * @param z Vector of length K.
2323 * @return Free vector of length (K-1).
24- * @throw std::domain_error if x does not sum to zero
24+ * @throw std::domain_error if z does not sum to zero
2525 */
2626template <typename Vec, require_eigen_vector_t <Vec>* = nullptr >
27- inline plain_type_t <Vec> sum_to_zero_free (const Vec& x ) {
28- const auto & x_ref = to_ref (x );
27+ inline plain_type_t <Vec> sum_to_zero_free (const Vec& z ) {
28+ const auto & z_ref = to_ref (z );
2929 check_sum_to_zero (" stan::math::sum_to_zero_free" , " sum_to_zero variable" ,
30- x_ref );
30+ z_ref );
3131
32- return x_ref.head (x_ref.size () - 1 );
32+ const auto N = z.size () - 1 ;
33+
34+ plain_type_t <Vec> y = Eigen::VectorXd::Zero (N);
35+ if (unlikely (N == 0 )) {
36+ return y;
37+ }
38+
39+ y.coeffRef (N - 1 ) = -z_ref (N) * sqrt (N * (N + 1 )) / N;
40+ typename plain_type_t <Vec>::Scalar total (0 );
41+
42+ for (int i = N - 2 ; i >= 0 ; --i) {
43+ double n = i + 1 ;
44+ auto w = y (i + 1 ) / sqrt ((n + 1 ) * (n + 2 ));
45+ total += w;
46+ y.coeffRef (i) = (total - z_ref (i + 1 )) * sqrt (n * (n + 1 )) / n;
47+ }
48+
49+ return y;
3350}
3451
3552/* *
3653 * Overload of `sum_to_zero_free()` to untransform each Eigen vector
3754 * in a standard vector.
3855 * @tparam T A standard vector with with a `value_type` which inherits from
3956 * `Eigen::MatrixBase` with compile time rows or columns equal to 1.
40- * @param x The standard vector to untransform.
57+ * @param z The standard vector to untransform.
4158 */
4259template <typename T, require_std_vector_t <T>* = nullptr >
43- auto sum_to_zero_free (const T& x ) {
60+ auto sum_to_zero_free (const T& z ) {
4461 return apply_vector_unary<T>::apply (
45- x , [](auto && v) { return sum_to_zero_free (v); });
62+ z , [](auto && v) { return sum_to_zero_free (v); });
4663}
4764
4865} // namespace math
0 commit comments