Skip to content

Commit 01fa440

Browse files
authored
Merge pull request #3171 from stan-dev/ilr-simplex
Replace simplex transform with one based on the ILR transform
2 parents 208340f + 1c20638 commit 01fa440

9 files changed

Lines changed: 330 additions & 327 deletions

stan/math/prim/constraint/simplex_constrain.hpp

Lines changed: 83 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/Eigen.hpp>
6-
#include <stan/math/prim/fun/inv_logit.hpp>
6+
#include <stan/math/prim/fun/constants.hpp>
77
#include <stan/math/prim/fun/log.hpp>
8-
#include <stan/math/prim/fun/log1p_exp.hpp>
9-
#include <stan/math/prim/fun/logit.hpp>
8+
#include <stan/math/prim/fun/fmax.hpp>
9+
#include <stan/math/prim/fun/exp.hpp>
10+
#include <stan/math/prim/fun/inv_sqrt.hpp>
1011
#include <cmath>
1112

1213
namespace stan {
@@ -18,7 +19,11 @@ namespace math {
1819
* to 0 that sum to 1. A vector with (K-1) unconstrained values
1920
* will produce a simplex of size K.
2021
*
21-
* The transform is based on a centered stick-breaking process.
22+
* The simplex transform is defined using the inverse of the
23+
* isometric log ratio (ILR) transform. This code is equivalent to
24+
* `softmax(sum_to_zero_constrain(y))`, but is more efficient and
25+
* stable if computed this way thanks to the use of the online
26+
* softmax algorithm courtesy of https://arxiv.org/abs/1805.02867.
2227
*
2328
* @tparam Vec type of the vector
2429
* @param y Free vector input of dimensionality K - 1.
@@ -27,29 +32,54 @@ namespace math {
2732
template <typename Vec, require_eigen_vector_t<Vec>* = nullptr,
2833
require_not_st_var<Vec>* = nullptr>
2934
inline plain_type_t<Vec> simplex_constrain(const Vec& y) {
30-
// cut & paste simplex_constrain(Eigen::Matrix, T) w/o Jacobian
31-
using std::log;
3235
using T = value_type_t<Vec>;
36+
const auto N = y.size();
37+
38+
plain_type_t<Vec> z = Eigen::VectorXd::Zero(N + 1);
39+
if (unlikely(N == 0)) {
40+
z.coeffRef(0) = 1;
41+
return z;
42+
}
43+
44+
auto&& y_ref = to_ref(y);
45+
T sum_w(0);
46+
47+
T d(0); // sum of exponentials
48+
T max_val(0);
49+
T max_val_old(negative_infinity());
3350

34-
int Km1 = y.size();
35-
plain_type_t<Vec> x(Km1 + 1);
36-
T stick_len(1.0);
37-
for (Eigen::Index k = 0; k < Km1; ++k) {
38-
T z_k = inv_logit(y.coeff(k) - log(Km1 - k));
39-
x.coeffRef(k) = stick_len * z_k;
40-
stick_len -= x.coeff(k);
51+
for (int i = N; i > 0; --i) {
52+
double n = static_cast<double>(i);
53+
auto w = y_ref(i - 1) * inv_sqrt(n * (n + 1));
54+
sum_w += w;
55+
56+
z.coeffRef(i - 1) += sum_w;
57+
z.coeffRef(i) -= w * n;
58+
59+
max_val = fmax(max_val_old, z.coeff(i));
60+
d = d * exp(max_val_old - max_val) + exp(z.coeff(i) - max_val);
61+
max_val_old = max_val;
4162
}
42-
x.coeffRef(Km1) = stick_len;
43-
return x;
63+
64+
// above loop doesn't reach i==0
65+
max_val = fmax(max_val_old, z.coeff(0));
66+
d = d * exp(max_val_old - max_val) + exp(z.coeff(0) - max_val);
67+
68+
z.array() = (z.array() - max_val).exp() / d;
69+
70+
return z;
4471
}
4572

4673
/**
4774
* Return the simplex corresponding to the specified free vector
4875
* and increment the specified log probability reference with
4976
* the log absolute Jacobian determinant of the transform.
5077
*
51-
* The simplex transform is defined through a centered
52-
* stick-breaking process.
78+
* The simplex transform is defined using the inverse of the
79+
* isometric log ratio (ILR) transform. This code is equivalent to
80+
* `softmax(sum_to_zero_constrain(y))`, but is more efficient and
81+
* stable if computed this way thanks to the use of the online
82+
* softmax algorithm courtesy of https://arxiv.org/abs/1805.02867.
5383
*
5484
* @tparam Vec type of the vector
5585
* @tparam Lp A scalar type for the lp argument. The scalar type of Vec should
@@ -62,26 +92,46 @@ template <typename Vec, typename Lp, require_eigen_vector_t<Vec>* = nullptr,
6292
require_not_st_var<Vec>* = nullptr,
6393
require_convertible_t<value_type_t<Vec>, Lp>* = nullptr>
6494
inline plain_type_t<Vec> simplex_constrain(const Vec& y, Lp& lp) {
65-
using Eigen::Dynamic;
66-
using Eigen::Matrix;
6795
using std::log;
6896
using T = value_type_t<Vec>;
97+
const auto N = y.size();
6998

70-
int Km1 = y.size(); // K = Km1 + 1
71-
plain_type_t<Vec> x(Km1 + 1);
72-
T stick_len(1.0);
73-
for (Eigen::Index k = 0; k < Km1; ++k) {
74-
double eq_share = -log(Km1 - k); // = logit(1.0/(Km1 + 1 - k));
75-
T adj_y_k = y.coeff(k) + eq_share;
76-
T z_k = inv_logit(adj_y_k);
77-
x.coeffRef(k) = stick_len * z_k;
78-
lp += log(stick_len);
79-
lp -= log1p_exp(-adj_y_k);
80-
lp -= log1p_exp(adj_y_k);
81-
stick_len -= x.coeff(k); // equivalently *= (1 - z_k);
99+
plain_type_t<Vec> z = Eigen::VectorXd::Zero(N + 1);
100+
if (unlikely(N == 0)) {
101+
z.coeffRef(0) = 1;
102+
return z;
82103
}
83-
x.coeffRef(Km1) = stick_len; // no Jacobian contrib for last dim
84-
return x;
104+
105+
auto&& y_ref = to_ref(y);
106+
T sum_w(0);
107+
108+
T d(0); // sum of exponentials
109+
T max_val(0);
110+
T max_val_old(negative_infinity());
111+
112+
for (int i = N; i > 0; --i) {
113+
double n = static_cast<double>(i);
114+
auto w = y_ref(i - 1) * inv_sqrt(n * (n + 1));
115+
sum_w += w;
116+
117+
z.coeffRef(i - 1) += sum_w;
118+
z.coeffRef(i) -= w * n;
119+
120+
max_val = fmax(max_val_old, z.coeff(i));
121+
d = d * exp(max_val_old - max_val) + exp(z.coeff(i) - max_val);
122+
max_val_old = max_val;
123+
}
124+
125+
// above loop doesn't reach i==0
126+
max_val = fmax(max_val_old, z.coeff(0));
127+
d = d * exp(max_val_old - max_val) + exp(z.coeff(0) - max_val);
128+
129+
z.array() = (z.array() - max_val).exp() / d;
130+
131+
// equivalent to z.log().sum() + 0.5 * log(N + 1)
132+
lp += -(N + 1) * (max_val + log(d)) + 0.5 * log(N + 1);
133+
134+
return z;
85135
}
86136

87137
/**

stan/math/prim/constraint/simplex_free.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include <stan/math/prim/err.hpp>
66
#include <stan/math/prim/fun/Eigen.hpp>
77
#include <stan/math/prim/fun/log.hpp>
8-
#include <stan/math/prim/fun/logit.hpp>
8+
#include <stan/math/prim/fun/sqrt.hpp>
99
#include <stan/math/prim/fun/to_ref.hpp>
1010
#include <cmath>
1111

@@ -17,8 +17,8 @@ namespace math {
1717
* the specified simplex. It applies to a simplex of dimensionality
1818
* K and produces an unconstrained vector of dimensionality (K-1).
1919
*
20-
* <p>The simplex transform is defined through a centered
21-
* stick-breaking process.
20+
* The simplex transform is defined using isometric log ratio (ILR)
21+
* transform
2222
*
2323
* @tparam ColVec type of the simplex (must be a column vector)
2424
* @param x Simplex of dimensionality K.
@@ -28,20 +28,20 @@ namespace math {
2828
*/
2929
template <typename Vec, require_eigen_vector_t<Vec>* = nullptr>
3030
inline plain_type_t<Vec> simplex_free(const Vec& x) {
31-
using std::log;
3231
using T = value_type_t<Vec>;
3332

3433
const auto& x_ref = to_ref(x);
3534
check_simplex("stan::math::simplex_free", "Simplex variable", x_ref);
3635
Eigen::Index Km1 = x_ref.size() - 1;
3736
plain_type_t<Vec> y(Km1);
38-
T stick_len = x_ref.coeff(Km1);
39-
for (Eigen::Index k = Km1; --k >= 0;) {
40-
stick_len += x_ref.coeff(k);
41-
T z_k = x_ref.coeff(k) / stick_len;
42-
y.coeffRef(k) = logit(z_k) + log(Km1 - k);
43-
// note: log(Km1 - k) = logit(1.0 / (Km1 + 1 - k));
37+
38+
T cumsum = 0.0;
39+
for (int i = 0; i < Km1; ++i) {
40+
cumsum += log(x_ref.coeff(i));
41+
double n = static_cast<double>(i + 1);
42+
y.coeffRef(i) = (cumsum - n * log(x_ref.coeff(i + 1))) / sqrt(n * (n + 1));
4443
}
44+
4545
return y;
4646
}
4747

stan/math/prim/constraint/stochastic_column_constrain.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,6 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/Eigen.hpp>
6-
#include <stan/math/prim/fun/inv_logit.hpp>
7-
#include <stan/math/prim/fun/log.hpp>
8-
#include <stan/math/prim/fun/log1p_exp.hpp>
9-
#include <stan/math/prim/fun/logit.hpp>
106
#include <stan/math/prim/constraint/simplex_constrain.hpp>
117
#include <cmath>
128

@@ -16,7 +12,8 @@ namespace math {
1612
/**
1713
* Return a column stochastic matrix.
1814
*
19-
* The transform is based on a centered stick-breaking process.
15+
* The transform is defined using the inverse of the
16+
* isometric log ratio (ILR) transform
2017
*
2118
* @tparam Mat type of the Matrix
2219
* @param y Free Matrix input of dimensionality (K - 1, M)
@@ -39,8 +36,8 @@ inline plain_type_t<Mat> stochastic_column_constrain(const Mat& y) {
3936
* and increment the specified log probability reference with
4037
* the log absolute Jacobian determinant of the transform.
4138
*
42-
* The simplex transform is defined through a centered
43-
* stick-breaking process.
39+
* The simplex transform is defined using the inverse of the
40+
* isometric log ratio (ILR) transform
4441
*
4542
* @tparam Mat type of the Matrix
4643
* @tparam Lp A scalar type for the lp argument. The scalar type of Mat should

stan/math/prim/constraint/stochastic_row_constrain.hpp

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,6 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/Eigen.hpp>
6-
#include <stan/math/prim/fun/inv_logit.hpp>
7-
#include <stan/math/prim/fun/log.hpp>
8-
#include <stan/math/prim/fun/log1p_exp.hpp>
9-
#include <stan/math/prim/fun/logit.hpp>
106
#include <stan/math/prim/constraint/simplex_constrain.hpp>
117
#include <cmath>
128

@@ -16,7 +12,8 @@ namespace math {
1612
/**
1713
* Return a row stochastic matrix.
1814
*
19-
* The transform is based on a centered stick-breaking process.
15+
* The transform is defined using the inverse of the
16+
* isometric log ratio (ILR) transform
2017
*
2118
* @tparam Mat type of the Matrix
2219
* @param y Free Matrix input of dimensionality (N, K - 1).
@@ -27,23 +24,17 @@ template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
2724
inline plain_type_t<Mat> stochastic_row_constrain(const Mat& y) {
2825
auto&& y_ref = to_ref(y);
2926
const Eigen::Index N = y_ref.rows();
30-
int Km1 = y_ref.cols();
31-
plain_type_t<Mat> x(N, Km1 + 1);
32-
using eigen_arr = Eigen::Array<scalar_type_t<Mat>, -1, 1>;
33-
eigen_arr stick_len = eigen_arr::Constant(N, 1.0);
34-
for (Eigen::Index k = 0; k < Km1; ++k) {
35-
auto z_k = inv_logit(y_ref.array().col(k) - log(Km1 - k));
36-
x.array().col(k) = stick_len * z_k;
37-
stick_len -= x.array().col(k);
27+
plain_type_t<Mat> ret(N, y_ref.cols() + 1);
28+
for (Eigen::Index i = 0; i < N; ++i) {
29+
ret.row(i) = simplex_constrain(y_ref.row(i));
3830
}
39-
x.array().col(Km1) = stick_len;
40-
return x;
31+
return ret;
4132
}
4233

4334
/**
4435
* Return a row stochastic matrix.
45-
* The simplex transform is defined through a centered
46-
* stick-breaking process.
36+
* The simplex transform is defined using the inverse of the
37+
* isometric log ratio (ILR) transform
4738
*
4839
* @tparam Mat type of the matrix
4940
* @tparam Lp A scalar type for the lp argument. The scalar type of Mat should
@@ -59,21 +50,11 @@ template <typename Mat, typename Lp,
5950
inline plain_type_t<Mat> stochastic_row_constrain(const Mat& y, Lp& lp) {
6051
auto&& y_ref = to_ref(y);
6152
const Eigen::Index N = y_ref.rows();
62-
Eigen::Index Km1 = y_ref.cols();
63-
plain_type_t<Mat> x(N, Km1 + 1);
64-
Eigen::Array<scalar_type_t<Mat>, -1, 1> stick_len
65-
= Eigen::Array<scalar_type_t<Mat>, -1, 1>::Constant(N, 1.0);
66-
for (Eigen::Index k = 0; k < Km1; ++k) {
67-
const auto eq_share = -log(Km1 - k); // = logit(1.0/(Km1 + 1 - k));
68-
auto adj_y_k = (y_ref.array().col(k) + eq_share).eval();
69-
auto z_k = inv_logit(adj_y_k);
70-
x.array().col(k) = stick_len * z_k;
71-
lp += -sum(log1p_exp(adj_y_k)) - sum(log1p_exp(-adj_y_k))
72-
+ sum(log(stick_len));
73-
stick_len -= x.array().col(k); // equivalently *= (1 - z_k);
53+
plain_type_t<Mat> ret(N, y_ref.cols() + 1);
54+
for (Eigen::Index i = 0; i < N; ++i) {
55+
ret.row(i) = simplex_constrain(y_ref.row(i), lp);
7456
}
75-
x.col(Km1).array() = stick_len;
76-
return x;
57+
return ret;
7758
}
7859

7960
/**

0 commit comments

Comments
 (0)