33
44#include < stan/math/rev/meta.hpp>
55#include < stan/math/rev/core.hpp>
6- #include < stan/math/prim/fun/constants .hpp>
6+ #include < stan/math/prim/fun/exp .hpp>
77#include < stan/math/prim/fun/inv_Phi.hpp>
8+ #include < stan/math/prim/prob/std_normal_lpdf.hpp>
9+ #include < stan/math/prim/functor/apply_scalar_binary.hpp>
810#include < cmath>
911
1012namespace stan {
@@ -19,8 +21,9 @@ namespace math {
1921 * @return The unit normal inverse cdf evaluated at p
2022 */
2123inline var inv_Phi (const var& p) {
22- return make_callback_var (inv_Phi (p.val ()), [p](auto & vi) mutable {
23- p.adj () += vi.adj () * SQRT_TWO_PI / std::exp (-0.5 * vi.val () * vi.val ());
24+ double val = inv_Phi (p.val ());
25+ return make_callback_var (val, [p, val](auto & vi) mutable {
26+ p.adj () += vi.adj () * exp (-std_normal_lpdf (val));
2427 });
2528}
2629
@@ -33,9 +36,11 @@ inline var inv_Phi(const var& p) {
3336 */
3437template <typename T, require_var_matrix_t <T>* = nullptr >
3538inline auto inv_Phi (const T& p) {
36- return make_callback_var (inv_Phi (p.val ()), [p](auto & vi) mutable {
37- p.adj ().array () += vi.adj ().array () * SQRT_TWO_PI
38- / (-0.5 * vi.val ().array ().square ()).exp ();
39+ auto arena_rtn = to_arena (inv_Phi (p.val ()));
40+ return make_callback_var (arena_rtn, [p, arena_rtn](auto & vi) mutable {
41+ auto deriv
42+ = arena_rtn.unaryExpr ([](auto x) { return exp (-std_normal_lpdf (x)); });
43+ p.adj () += elt_multiply (vi.adj (), deriv);
3944 });
4045}
4146
0 commit comments