33
44#include < stan/math/prim/meta.hpp>
55#include < stan/math/prim/fun/Eigen.hpp>
6- #include < stan/math/prim/fun/inv_logit .hpp>
6+ #include < stan/math/prim/fun/constants .hpp>
77#include < stan/math/prim/fun/log.hpp>
8- #include < stan/math/prim/fun/log1p_exp.hpp>
9- #include < stan/math/prim/fun/logit.hpp>
8+ #include < stan/math/prim/fun/fmax.hpp>
9+ #include < stan/math/prim/fun/exp.hpp>
10+ #include < stan/math/prim/fun/inv_sqrt.hpp>
1011#include < cmath>
1112
1213namespace stan {
@@ -18,7 +19,11 @@ namespace math {
1819 * to 0 that sum to 1. A vector with (K-1) unconstrained values
1920 * will produce a simplex of size K.
2021 *
21- * The transform is based on a centered stick-breaking process.
22+ * The simplex transform is defined using the inverse of the
23+ * isometric log ratio (ILR) transform. This code is equivalent to
24+ * `softmax(sum_to_zero_constrain(y))`, but is more efficient and
25+ * stable if computed this way thanks to the use of the online
26+ * softmax algorithm courtesy of https://arxiv.org/abs/1805.02867.
2227 *
2328 * @tparam Vec type of the vector
2429 * @param y Free vector input of dimensionality K - 1.
@@ -27,29 +32,54 @@ namespace math {
2732template <typename Vec, require_eigen_vector_t <Vec>* = nullptr ,
2833 require_not_st_var<Vec>* = nullptr >
2934inline plain_type_t <Vec> simplex_constrain (const Vec& y) {
30- // cut & paste simplex_constrain(Eigen::Matrix, T) w/o Jacobian
31- using std::log;
3235 using T = value_type_t <Vec>;
36+ const auto N = y.size ();
37+
38+ plain_type_t <Vec> z = Eigen::VectorXd::Zero (N + 1 );
39+ if (unlikely (N == 0 )) {
40+ z.coeffRef (0 ) = 1 ;
41+ return z;
42+ }
43+
44+ auto && y_ref = to_ref (y);
45+ T sum_w (0 );
46+
47+ T d (0 ); // sum of exponentials
48+ T max_val (0 );
49+ T max_val_old (negative_infinity ());
3350
34- int Km1 = y.size ();
35- plain_type_t <Vec> x (Km1 + 1 );
36- T stick_len (1.0 );
37- for (Eigen::Index k = 0 ; k < Km1; ++k) {
38- T z_k = inv_logit (y.coeff (k) - log (Km1 - k));
39- x.coeffRef (k) = stick_len * z_k;
40- stick_len -= x.coeff (k);
51+ for (int i = N; i > 0 ; --i) {
52+ double n = static_cast <double >(i);
53+ auto w = y_ref (i - 1 ) * inv_sqrt (n * (n + 1 ));
54+ sum_w += w;
55+
56+ z.coeffRef (i - 1 ) += sum_w;
57+ z.coeffRef (i) -= w * n;
58+
59+ max_val = fmax (max_val_old, z.coeff (i));
60+ d = d * exp (max_val_old - max_val) + exp (z.coeff (i) - max_val);
61+ max_val_old = max_val;
4162 }
42- x.coeffRef (Km1) = stick_len;
43- return x;
63+
64+ // above loop doesn't reach i==0
65+ max_val = fmax (max_val_old, z.coeff (0 ));
66+ d = d * exp (max_val_old - max_val) + exp (z.coeff (0 ) - max_val);
67+
68+ z.array () = (z.array () - max_val).exp () / d;
69+
70+ return z;
4471}
4572
4673/* *
4774 * Return the simplex corresponding to the specified free vector
4875 * and increment the specified log probability reference with
4976 * the log absolute Jacobian determinant of the transform.
5077 *
51- * The simplex transform is defined through a centered
52- * stick-breaking process.
78+ * The simplex transform is defined using the inverse of the
79+ * isometric log ratio (ILR) transform. This code is equivalent to
80+ * `softmax(sum_to_zero_constrain(y))`, but is more efficient and
81+ * stable if computed this way thanks to the use of the online
82+ * softmax algorithm courtesy of https://arxiv.org/abs/1805.02867.
5383 *
5484 * @tparam Vec type of the vector
5585 * @tparam Lp A scalar type for the lp argument. The scalar type of Vec should
@@ -62,26 +92,46 @@ template <typename Vec, typename Lp, require_eigen_vector_t<Vec>* = nullptr,
6292 require_not_st_var<Vec>* = nullptr ,
6393 require_convertible_t <value_type_t <Vec>, Lp>* = nullptr >
6494inline plain_type_t <Vec> simplex_constrain (const Vec& y, Lp& lp) {
65- using Eigen::Dynamic;
66- using Eigen::Matrix;
6795 using std::log;
6896 using T = value_type_t <Vec>;
97+ const auto N = y.size ();
6998
70- int Km1 = y.size (); // K = Km1 + 1
71- plain_type_t <Vec> x (Km1 + 1 );
72- T stick_len (1.0 );
73- for (Eigen::Index k = 0 ; k < Km1; ++k) {
74- double eq_share = -log (Km1 - k); // = logit(1.0/(Km1 + 1 - k));
75- T adj_y_k = y.coeff (k) + eq_share;
76- T z_k = inv_logit (adj_y_k);
77- x.coeffRef (k) = stick_len * z_k;
78- lp += log (stick_len);
79- lp -= log1p_exp (-adj_y_k);
80- lp -= log1p_exp (adj_y_k);
81- stick_len -= x.coeff (k); // equivalently *= (1 - z_k);
99+ plain_type_t <Vec> z = Eigen::VectorXd::Zero (N + 1 );
100+ if (unlikely (N == 0 )) {
101+ z.coeffRef (0 ) = 1 ;
102+ return z;
82103 }
83- x.coeffRef (Km1) = stick_len; // no Jacobian contrib for last dim
84- return x;
104+
105+ auto && y_ref = to_ref (y);
106+ T sum_w (0 );
107+
108+ T d (0 ); // sum of exponentials
109+ T max_val (0 );
110+ T max_val_old (negative_infinity ());
111+
112+ for (int i = N; i > 0 ; --i) {
113+ double n = static_cast <double >(i);
114+ auto w = y_ref (i - 1 ) * inv_sqrt (n * (n + 1 ));
115+ sum_w += w;
116+
117+ z.coeffRef (i - 1 ) += sum_w;
118+ z.coeffRef (i) -= w * n;
119+
120+ max_val = fmax (max_val_old, z.coeff (i));
121+ d = d * exp (max_val_old - max_val) + exp (z.coeff (i) - max_val);
122+ max_val_old = max_val;
123+ }
124+
125+ // above loop doesn't reach i==0
126+ max_val = fmax (max_val_old, z.coeff (0 ));
127+ d = d * exp (max_val_old - max_val) + exp (z.coeff (0 ) - max_val);
128+
129+ z.array () = (z.array () - max_val).exp () / d;
130+
131+ // equivalent to z.log().sum() + 0.5 * log(N + 1)
132+ lp += -(N + 1 ) * (max_val + log (d)) + 0.5 * log (N + 1 );
133+
134+ return z;
85135}
86136
87137/* *
0 commit comments