55#include < stan/math/rev/meta.hpp>
66#include < stan/math/rev/core/reverse_pass_callback.hpp>
77#include < stan/math/rev/core/arena_matrix.hpp>
8- #include < stan/math/rev/fun/value_of.hpp>
9- #include < stan/math/prim/constraint/sum_to_zero_constrain.hpp>
10- #include < stan/math/prim/fun/softmax.hpp>
11- #include < stan/math/prim/fun/log_softmax.hpp>
8+ #include < stan/math/rev/constraint/sum_to_zero_constrain.hpp>
9+ #include < stan/math/prim/constraint/simplex_constrain.hpp>
1210#include < cmath>
1311#include < tuple>
1412#include < vector>
@@ -32,7 +30,30 @@ namespace math {
3230 */
3331template <typename T, require_rev_col_vector_t <T>* = nullptr >
3432inline auto simplex_constrain (const T& y) {
35- return softmax (sum_to_zero_constrain (y));
33+ using ret_type = plain_type_t <T>;
34+
35+ const auto N = y.size ();
36+ arena_t <T> arena_y = y;
37+
38+ arena_t <ret_type> arena_x = simplex_constrain (arena_y.val ());
39+
40+ if (unlikely (N == 0 )) {
41+ return ret_type (arena_x);
42+ }
43+
44+ reverse_pass_callback ([arena_y, arena_x]() mutable {
45+ const auto & res_val = to_ref (arena_x.val ());
46+
47+ Eigen::VectorXd x_pre_softmax_adj = Eigen::VectorXd::Zero (res_val.size ());
48+ // backprop for softmax
49+ x_pre_softmax_adj += -res_val * arena_x.adj ().dot (res_val)
50+ + res_val.cwiseProduct (arena_x.adj ());
51+
52+ // backprop for sum_to_zero_constrain
53+ internal::sum_to_zero_vector_backprop (arena_y, x_pre_softmax_adj);
54+ });
55+
56+ return ret_type (arena_x);
3657}
3758
3859/* *
@@ -50,16 +71,36 @@ inline auto simplex_constrain(const T& y) {
5071 * @return Simplex of dimensionality N + 1.
5172 */
5273template <typename T, require_rev_col_vector_t <T>* = nullptr >
53- auto simplex_constrain (const T& y, scalar_type_t <T>& lp) {
74+ inline auto simplex_constrain (const T& y, scalar_type_t <T>& lp) {
5475 using ret_type = plain_type_t <T>;
5576
56- arena_t <ret_type> log_x = log_softmax (sum_to_zero_constrain (y));
77+ const auto N = y.size ();
78+ arena_t <T> arena_y = y;
79+
80+ double lp_val = 0.0 ;
81+ arena_t <ret_type> arena_x = simplex_constrain (arena_y.val (), lp_val);
82+ lp += lp_val;
83+
84+ if (unlikely (N == 0 )) {
85+ return ret_type (arena_x);
86+ }
87+
88+ reverse_pass_callback ([arena_y, arena_x, lp]() mutable {
89+ const auto & res_val = to_ref (arena_x.val ());
90+
91+ // backprop for log jacobian contribution to log density
92+ arena_x.adj ().array () += lp.adj () / res_val.array ();
5793
58- const auto N = y.size () + 1 ;
94+ Eigen::VectorXd x_pre_softmax_adj = Eigen::VectorXd::Zero (res_val.size ());
95+ // backprop for softmax
96+ x_pre_softmax_adj += -res_val * arena_x.adj ().dot (res_val)
97+ + res_val.cwiseProduct (arena_x.adj ());
5998
60- lp += sum (log_x) + 0.5 * log (N);
99+ // backprop for sum_to_zero_constrain
100+ internal::sum_to_zero_vector_backprop (arena_y, x_pre_softmax_adj);
101+ });
61102
62- return ret_type (exp (log_x) );
103+ return ret_type (arena_x );
63104}
64105
65106} // namespace math
0 commit comments