@@ -44,13 +44,12 @@ inline auto simplex_constrain(const T& y) {
4444 reverse_pass_callback ([arena_y, arena_x]() mutable {
4545 const auto & res_val = to_ref (arena_x.val ());
4646
47- Eigen::VectorXd x_pre_softmax_adj = Eigen::VectorXd::Zero (res_val.size ());
4847 // backprop for softmax
49- x_pre_softmax_adj + = -res_val * arena_x.adj ().dot (res_val)
50- + res_val.cwiseProduct (arena_x.adj ());
48+ Eigen::VectorXd x_pre_softmax_adj = -res_val * arena_x.adj ().dot (res_val)
49+ + res_val.cwiseProduct (arena_x.adj ());
5150
5251 // backprop for sum_to_zero_constrain
53- internal::sum_to_zero_vector_backprop (arena_y, x_pre_softmax_adj);
52+ internal::sum_to_zero_vector_backprop (arena_y. adj () , x_pre_softmax_adj);
5453 });
5554
5655 return ret_type (arena_x);
@@ -91,13 +90,12 @@ inline auto simplex_constrain(const T& y, scalar_type_t<T>& lp) {
9190 // backprop for log jacobian contribution to log density
9291 arena_x.adj ().array () += lp.adj () / res_val.array ();
9392
94- Eigen::VectorXd x_pre_softmax_adj = Eigen::VectorXd::Zero (res_val.size ());
9593 // backprop for softmax
96- x_pre_softmax_adj + = -res_val * arena_x.adj ().dot (res_val)
97- + res_val.cwiseProduct (arena_x.adj ());
94+ Eigen::VectorXd x_pre_softmax_adj = -res_val * arena_x.adj ().dot (res_val)
95+ + res_val.cwiseProduct (arena_x.adj ());
9896
9997 // backprop for sum_to_zero_constrain
100- internal::sum_to_zero_vector_backprop (arena_y, x_pre_softmax_adj);
98+ internal::sum_to_zero_vector_backprop (arena_y. adj () , x_pre_softmax_adj);
10199 });
102100
103101 return ret_type (arena_x);
0 commit comments