Skip to content

Commit 4943a41

Browse files
committed
apply_scalary_binary -> apply_scalar_unary
1 parent 4f53279 commit 4943a41

2 files changed

Lines changed: 10 additions & 9 deletions

File tree

stan/math/rev/fun/inv_Phi.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,8 @@ template <typename T, require_var_matrix_t<T>* = nullptr>
3838
inline auto inv_Phi(const T& p) {
3939
const auto& arena_rtn = to_arena(inv_Phi(p.val()));
4040
return make_callback_var(arena_rtn, [p, arena_rtn](auto& vi) mutable {
41-
p.adj() += apply_scalar_binary(
42-
vi.adj(), arena_rtn.val(), [](const double adj, const double rtn_val) {
43-
return adj * exp(-std_normal_lpdf(rtn_val));
44-
});
41+
auto deriv = arena_rtn.unaryExpr([](auto x) { return exp(-std_normal_lpdf(x)); });
42+
p.adj() += elt_multiply(vi.adj(), deriv);
4543
});
4644
}
4745

stan/math/rev/prob/std_normal_log_qf.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@ template <typename T, require_stan_scalar_or_eigen_t<T>* = nullptr>
2121
inline auto std_normal_log_qf(const var_value<T>& log_p) {
2222
const auto& arena_rtn = to_arena(std_normal_log_qf(log_p.val()));
2323
return make_callback_var(arena_rtn, [log_p, arena_rtn](auto& vi) mutable {
24-
auto deriv = apply_scalar_binary(
25-
log_p.val(), arena_rtn, [](const auto& logp_val, const auto& rtn_val) {
26-
return exp(logp_val - std_normal_lpdf(rtn_val));
27-
});
28-
log_p.adj() += elt_multiply(vi.adj(), deriv);
24+
if constexpr (is_eigen<decltype(arena_rtn)>::value) {
25+
auto deriv = exp(log_p.val() - arena_rtn.unaryExpr([](auto x) {
26+
return std_normal_lpdf(x);}));
27+
log_p.adj() += elt_multiply(vi.adj(), deriv);
28+
} else {
29+
auto deriv = exp(log_p.val() - std_normal_lpdf(arena_rtn));
30+
log_p.adj() += vi.adj() * deriv;
31+
}
2932
});
3033
}
3134

0 commit comments

Comments
 (0)