11#ifndef STAN_MATH_REV_CONSTRAINT_STOCHASTIC_COLUMN_CONSTRAIN_HPP
22#define STAN_MATH_REV_CONSTRAINT_STOCHASTIC_COLUMN_CONSTRAIN_HPP
33
4+ #include < stan/math/prim/fun/Eigen.hpp>
45#include < stan/math/rev/meta.hpp>
56#include < stan/math/rev/core/reverse_pass_callback.hpp>
67#include < stan/math/rev/core/arena_matrix.hpp>
78#include < stan/math/rev/fun/value_of.hpp>
8- #include < stan/math/prim/fun/Eigen.hpp>
9- #include < stan/math/prim/fun/inv_logit.hpp>
10- #include < stan/math/prim/fun/log1p_exp.hpp>
9+ #include < stan/math/prim/constraint/stochastic_column_constrain.hpp>
10+ #include < stan/math/rev/constraint/sum_to_zero_constrain.hpp>
1111#include < cmath>
1212#include < tuple>
1313#include < vector>
@@ -27,44 +27,36 @@ namespace math {
2727template <typename T, require_rev_matrix_t <T>* = nullptr >
2828inline plain_type_t <T> stochastic_column_constrain (const T& y) {
2929 using ret_type = plain_type_t <T>;
30- const Eigen::Index N = y.rows ();
31- const Eigen::Index M = y.cols ();
32- using eigen_mat_rowmajor
33- = Eigen::Matrix<double , Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
34- arena_t <eigen_mat_rowmajor> x_val (N + 1 , M);
30+
31+ const auto N = y.rows ();
32+ const auto M = y.cols ();
33+ arena_t <T> arena_y = y;
34+
35+ arena_t <ret_type> arena_x = stochastic_column_constrain (arena_y.val_op ());
36+
3537 if (unlikely (N == 0 || M == 0 )) {
36- return ret_type (x_val);
37- }
38- arena_t <change_eigen_options_t <T, Eigen::RowMajor>> arena_y = y;
39- arena_t <eigen_mat_rowmajor> arena_z (N, M);
40- using arr_vec = Eigen::Array<double , 1 , -1 >;
41- arr_vec stick_len = arr_vec::Constant (M, 1.0 );
42- for (Eigen::Index k = 0 ; k < N; ++k) {
43- const double log_N_minus_k = std::log (N - k);
44- arena_z.row (k)
45- = inv_logit (arena_y.array ().row (k).val_op () - log_N_minus_k).matrix ();
46- x_val.row (k) = stick_len.array () * arena_z.array ().row (k);
47- stick_len -= x_val.array ().row (k);
38+ return arena_x;
4839 }
49- x_val. row (N) = stick_len;
50- arena_t <ret_type> arena_x = x_val;
51- reverse_pass_callback ([arena_y, arena_x, arena_z]() mutable {
52- const Eigen::Index N = arena_y. rows ();
53- auto arena_x_arr = arena_x.array ( );
54- auto arena_y_arr = arena_y. array ( );
55- auto arena_z_arr = arena_z. array ();
56- auto stick_len_val = arena_x. array (). row (N). val (). eval ();
57- auto stick_len_adj = arena_x. array (). row (N). adj (). eval ();
58- for ( Eigen::Index k = N; k-- > 0 ;) {
59- arena_x_arr. row (k). adj () -= stick_len_adj;
60- stick_len_val += arena_x_arr. row (k). val ( );
61- stick_len_adj += arena_x_arr. row (k). adj () * arena_z_arr. row (k);
62- auto arena_z_adj = arena_x_arr. row (k). adj () * stick_len_val;
63- arena_y_arr. row (k ).adj ()
64- += arena_z_adj * arena_z_arr. row (k) * ( 1.0 - arena_z_arr. row (k) );
40+
41+ reverse_pass_callback ([arena_y, arena_x]() mutable {
42+ const auto M = arena_y. cols ();
43+
44+ const auto & x_val = to_ref ( arena_x.val_op () );
45+ const auto & x_adj = to_ref (arena_x. adj_op () );
46+
47+ for (Eigen::Index i = 0 ; i < M; ++i) {
48+ // backprop for softmax
49+ Eigen::VectorXd x_pre_softmax_adj
50+ = -x_val. col (i) * x_adj. col (i). dot (x_val. col (i))
51+ + x_val. col (i). cwiseProduct (x_adj. col (i) );
52+
53+ // backprop for sum_to_zero_constrain
54+ internal::sum_to_zero_vector_backprop (arena_y. col (i ).adj (),
55+ x_pre_softmax_adj );
6556 }
6657 });
67- return ret_type (arena_x);
58+
59+ return arena_x;
6860}
6961
7062/* *
@@ -84,51 +76,43 @@ template <typename T, require_rev_matrix_t<T>* = nullptr>
8476inline plain_type_t <T> stochastic_column_constrain (const T& y,
8577 scalar_type_t <T>& lp) {
8678 using ret_type = plain_type_t <T>;
87- const Eigen::Index N = y.rows ();
88- const Eigen::Index M = y.cols ();
89- using eigen_mat_rowmajor
90- = Eigen::Matrix<double , Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
91- arena_t <eigen_mat_rowmajor> x_val (N + 1 , M);
79+
80+ const auto N = y.rows ();
81+ const auto M = y.cols ();
82+ arena_t <T> arena_y = y;
83+
84+ double lp_val = 0 ;
85+ arena_t <ret_type> arena_x
86+ = stochastic_column_constrain (arena_y.val_op (), lp_val);
87+ lp += lp_val;
88+
9289 if (unlikely (N == 0 || M == 0 )) {
93- return ret_type (x_val) ;
90+ return arena_x ;
9491 }
95- arena_t <change_eigen_options_t <T, Eigen::RowMajor>> arena_y = y;
96- arena_t <eigen_mat_rowmajor> arena_z (N, M);
97- using arr_vec = Eigen::Array<double , 1 , -1 >;
98- arr_vec stick_len = arr_vec::Constant (M, 1.0 );
99- arr_vec adj_y_k (N);
100- for (Eigen::Index k = 0 ; k < N; ++k) {
101- double log_N_minus_k = std::log (N - k);
102- adj_y_k = arena_y.array ().row (k).val () - log_N_minus_k;
103- arena_z.array ().row (k) = inv_logit (adj_y_k);
104- x_val.array ().row (k) = stick_len * arena_z.array ().row (k);
105- lp += sum (log (stick_len)) - sum (log1p_exp (-adj_y_k))
106- - sum (log1p_exp (adj_y_k));
107- stick_len -= x_val.array ().row (k);
108- }
109- x_val.array ().row (N) = stick_len;
110- arena_t <ret_type> arena_x = x_val;
111- reverse_pass_callback ([arena_y, arena_x, arena_z, lp]() mutable {
112- const Eigen::Index N = arena_y.rows ();
113- auto arena_x_arr = arena_x.array ();
114- auto arena_y_arr = arena_y.array ();
115- auto arena_z_arr = arena_z.array ();
116- auto stick_len_val = arena_x.array ().row (N).val ().eval ();
117- auto stick_len_adj = arena_x.array ().row (N).adj ().eval ();
118- for (Eigen::Index k = N; k-- > 0 ;) {
119- const double log_N_minus_k = std::log (N - k);
120- arena_x_arr.row (k).adj () -= stick_len_adj;
121- stick_len_val += arena_x_arr.row (k).val ();
122- stick_len_adj += lp.adj () / stick_len_val
123- + arena_x_arr.row (k).adj () * arena_z_arr.row (k);
124- auto adj_y_k = arena_y_arr.row (k).val () - log_N_minus_k;
125- auto arena_z_adj = arena_x_arr.row (k).adj () * stick_len_val;
126- arena_y_arr.row (k).adj ()
127- += -(lp.adj () * inv_logit (adj_y_k)) + lp.adj () * inv_logit (-adj_y_k)
128- + arena_z_adj * arena_z_arr.row (k) * (1.0 - arena_z_arr.row (k));
92+
93+ reverse_pass_callback ([arena_y, arena_x, lp]() mutable {
94+ const auto M = arena_y.cols ();
95+
96+ const auto & x_val = to_ref (arena_x.val_op ());
97+
98+ // backprop for log jacobian contribution to log density
99+ arena_x.adj ().array () += lp.adj () / x_val.array ();
100+
101+ const auto & x_adj = to_ref (arena_x.adj_op ());
102+
103+ for (Eigen::Index i = 0 ; i < M; ++i) {
104+ // backprop for softmax
105+ Eigen::VectorXd x_pre_softmax_adj
106+ = -x_val.col (i) * x_adj.col (i).dot (x_val.col (i))
107+ + x_val.col (i).cwiseProduct (x_adj.col (i));
108+
109+ // backprop for sum_to_zero_constrain
110+ internal::sum_to_zero_vector_backprop (arena_y.col (i).adj (),
111+ x_pre_softmax_adj);
129112 }
130113 });
131- return ret_type (arena_x);
114+
115+ return arena_x;
132116}
133117
134118} // namespace math
0 commit comments