@@ -41,12 +41,13 @@ inline plain_type_t<T> stochastic_column_constrain(const T& y) {
4141 reverse_pass_callback ([arena_y, arena_x]() mutable {
4242 const auto M = arena_y.cols ();
4343
44- const auto & x_val = to_ref ( arena_x.val_op () );
45- const auto & x_adj = to_ref ( arena_x.adj_op () );
44+ auto && x_val = arena_x.val_op ();
45+ auto && x_adj = arena_x.adj_op ();
4646
47+ Eigen::VectorXd x_pre_softmax_adj (x_val.rows ());
4748 for (Eigen::Index i = 0 ; i < M; ++i) {
4849 // backprop for softmax
49- Eigen::VectorXd x_pre_softmax_adj
50+ x_pre_softmax_adj. noalias ()
5051 = -x_val.col (i) * x_adj.col (i).dot (x_val.col (i))
5152 + x_val.col (i).cwiseProduct (x_adj.col (i));
5253
@@ -93,16 +94,17 @@ inline plain_type_t<T> stochastic_column_constrain(const T& y,
9394 reverse_pass_callback ([arena_y, arena_x, lp]() mutable {
9495 const auto M = arena_y.cols ();
9596
96- const auto & x_val = to_ref ( arena_x.val_op () );
97+ auto && x_val = arena_x.val_op ();
9798
9899 // backprop for log jacobian contribution to log density
99100 arena_x.adj ().array () += lp.adj () / x_val.array ();
100101
101- const auto & x_adj = to_ref ( arena_x.adj_op () );
102+ auto && x_adj = arena_x.adj_op ();
102103
104+ Eigen::VectorXd x_pre_softmax_adj (x_val.rows ());
103105 for (Eigen::Index i = 0 ; i < M; ++i) {
104106 // backprop for softmax
105- Eigen::VectorXd x_pre_softmax_adj
107+ x_pre_softmax_adj. noalias ()
106108 = -x_val.col (i) * x_adj.col (i).dot (x_val.col (i))
107109 + x_val.col (i).cwiseProduct (x_adj.col (i));
108110
0 commit comments