Skip to content

Commit afea210

Browse files
authored
Merge pull request #3139 from stan-dev/normal-quantile-gradients
Improve numerical stability of normal quantile gradients
2 parents ac977e4 + d2ec33c commit afea210

2 files changed

Lines changed: 25 additions & 18 deletions

File tree

stan/math/rev/fun/inv_Phi.hpp

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,10 @@
33

44
#include <stan/math/rev/meta.hpp>
55
#include <stan/math/rev/core.hpp>
6-
#include <stan/math/prim/fun/constants.hpp>
6+
#include <stan/math/prim/fun/exp.hpp>
77
#include <stan/math/prim/fun/inv_Phi.hpp>
8+
#include <stan/math/prim/prob/std_normal_lpdf.hpp>
9+
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
810
#include <cmath>
911

1012
namespace stan {
@@ -19,8 +21,9 @@ namespace math {
1921
* @return The unit normal inverse cdf evaluated at p
2022
*/
2123
inline var inv_Phi(const var& p) {
22-
return make_callback_var(inv_Phi(p.val()), [p](auto& vi) mutable {
23-
p.adj() += vi.adj() * SQRT_TWO_PI / std::exp(-0.5 * vi.val() * vi.val());
24+
double val = inv_Phi(p.val());
25+
return make_callback_var(val, [p, val](auto& vi) mutable {
26+
p.adj() += vi.adj() * exp(-std_normal_lpdf(val));
2427
});
2528
}
2629

@@ -33,9 +36,11 @@ inline var inv_Phi(const var& p) {
3336
*/
3437
template <typename T, require_var_matrix_t<T>* = nullptr>
3538
inline auto inv_Phi(const T& p) {
36-
return make_callback_var(inv_Phi(p.val()), [p](auto& vi) mutable {
37-
p.adj().array() += vi.adj().array() * SQRT_TWO_PI
38-
/ (-0.5 * vi.val().array().square()).exp();
39+
auto arena_rtn = to_arena(inv_Phi(p.val()));
40+
return make_callback_var(arena_rtn, [p, arena_rtn](auto& vi) mutable {
41+
auto deriv
42+
= arena_rtn.unaryExpr([](auto x) { return exp(-std_normal_lpdf(x)); });
43+
p.adj() += elt_multiply(vi.adj(), deriv);
3944
});
4045
}
4146

stan/math/rev/prob/std_normal_log_qf.hpp

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -3,9 +3,9 @@
33

44
#include <stan/math/rev/meta.hpp>
55
#include <stan/math/rev/core.hpp>
6-
#include <stan/math/prim/fun/constants.hpp>
7-
#include <stan/math/prim/fun/sign.hpp>
86
#include <stan/math/prim/prob/std_normal_log_qf.hpp>
7+
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
8+
#include <stan/math/prim/fun/elt_multiply.hpp>
99
#include <cmath>
1010

1111
namespace stan {
@@ -19,16 +19,18 @@ namespace math {
1919
*/
2020
template <typename T, require_stan_scalar_or_eigen_t<T>* = nullptr>
2121
inline auto std_normal_log_qf(const var_value<T>& log_p) {
22-
return make_callback_var(
23-
std_normal_log_qf(log_p.val()), [log_p](auto& vi) mutable {
24-
auto vi_array = as_array_or_scalar(vi.val());
25-
auto vi_sign = sign(as_array_or_scalar(vi.adj()));
26-
27-
const auto& deriv = as_array_or_scalar(log_p).val()
28-
+ log(as_array_or_scalar(vi.adj()) * vi_sign)
29-
- NEG_LOG_SQRT_TWO_PI + 0.5 * square(vi_array);
30-
as_array_or_scalar(log_p).adj() += vi_sign * exp(deriv);
31-
});
22+
auto arena_rtn = to_arena(std_normal_log_qf(log_p.val()));
23+
return make_callback_var(arena_rtn, [log_p, arena_rtn](auto& vi) mutable {
24+
if constexpr (is_eigen<decltype(arena_rtn)>::value) {
25+
auto deriv = exp(log_p.val() - arena_rtn.unaryExpr([](auto x) {
26+
return std_normal_lpdf(x);
27+
}));
28+
log_p.adj() += elt_multiply(vi.adj(), deriv);
29+
} else {
30+
auto deriv = exp(log_p.val() - std_normal_lpdf(arena_rtn));
31+
log_p.adj() += vi.adj() * deriv;
32+
}
33+
});
3234
}
3335

3436
} // namespace math

0 commit comments

Comments
 (0)