Skip to content

Commit 9a934dc

Browse files
committed
Avoid need to modify adj by re-working algebraically
1 parent ad4d067 commit 9a934dc

3 files changed

Lines changed: 19 additions & 15 deletions

File tree

stan/math/rev/constraint/simplex_constrain.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,15 @@ inline auto simplex_constrain(const T& y, scalar_type_t<T>& lp) {
8787
reverse_pass_callback([arena_y, arena_x, lp]() mutable {
8888
auto&& res_val = arena_x.val();
8989

90-
// backprop for log jacobian contribution to log density
91-
arena_x.adj().array() += lp.adj() / res_val.array();
90+
// backprop for log jacobian contribution to log density is equivalent to
91+
// arena_x.adj().array() += lp.adj() / res_val.array();
92+
// but is folded into the following to avoid needing to modify the adjoints
93+
// in-place
9294

9395
// backprop for softmax
94-
Eigen::VectorXd x_pre_softmax_adj = -res_val * arena_x.adj().dot(res_val)
95-
+ res_val.cwiseProduct(arena_x.adj());
96+
Eigen::VectorXd x_pre_softmax_adj
97+
= -res_val * (arena_x.adj().dot(res_val) + res_val.size() * lp.adj())
98+
+ (res_val.cwiseProduct(arena_x.adj()).array() + lp.adj()).matrix();
9699

97100
// backprop for sum_to_zero_constrain
98101
internal::sum_to_zero_vector_backprop(arena_y.adj(), x_pre_softmax_adj);

stan/math/rev/constraint/stochastic_column_constrain.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,18 +95,18 @@ inline plain_type_t<T> stochastic_column_constrain(const T& y,
9595
const auto M = arena_y.cols();
9696

9797
auto&& x_val = arena_x.val_op();
98-
99-
// backprop for log jacobian contribution to log density
100-
arena_x.adj().array() += lp.adj() / x_val.array();
101-
10298
auto&& x_adj = arena_x.adj_op();
10399

100+
const auto x_val_rows = x_val.rows();
101+
104102
Eigen::VectorXd x_pre_softmax_adj(x_val.rows());
105103
for (Eigen::Index i = 0; i < M; ++i) {
106104
// backprop for softmax
107105
x_pre_softmax_adj.noalias()
108-
= -x_val.col(i) * x_adj.col(i).dot(x_val.col(i))
109-
+ x_val.col(i).cwiseProduct(x_adj.col(i));
106+
= -x_val.col(i)
107+
* (x_adj.col(i).dot(x_val.col(i)) + lp.adj() * x_val_rows)
108+
+ (x_val.col(i).cwiseProduct(x_adj.col(i)).array() + lp.adj())
109+
.matrix();
110110

111111
// backprop for sum_to_zero_constrain
112112
internal::sum_to_zero_vector_backprop(arena_y.col(i).adj(),

stan/math/rev/constraint/stochastic_row_constrain.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,18 @@ inline plain_type_t<T> stochastic_row_constrain(const T& y,
9393
const auto N = arena_y.rows();
9494

9595
auto&& x_val = arena_x.val_op();
96-
// backprop for log jacobian contribution to log density
97-
arena_x.adj().array() += lp.adj() / x_val.array();
98-
9996
auto&& x_adj = arena_x.adj_op();
10097

98+
const auto x_val_cols = x_val.cols();
99+
101100
Eigen::VectorXd x_pre_softmax_adj(x_val.cols());
102101
for (Eigen::Index i = 0; i < N; ++i) {
103102
// backprop for softmax
104103
x_pre_softmax_adj.noalias()
105-
= -x_val.row(i) * x_adj.row(i).dot(x_val.row(i))
106-
+ x_val.row(i).cwiseProduct(x_adj.row(i));
104+
= -x_val.row(i)
105+
* (x_adj.row(i).dot(x_val.row(i)) + lp.adj() * x_val_cols)
106+
+ (x_val.row(i).cwiseProduct(x_adj.row(i)).array() + lp.adj())
107+
.matrix();
107108

108109
// backprop for sum_to_zero_constrain
109110
internal::sum_to_zero_vector_backprop(arena_y.row(i).adj(),

0 commit comments

Comments
 (0)