Skip to content

Commit 0258fca

Browse files
committed
fixes signatures and rev for inv_logit
1 parent afea210 commit 0258fca

2 files changed

Lines changed: 47 additions & 14 deletions

File tree

stan/math/prim/fun/inv_logit.hpp

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,14 @@ namespace math {
4949
* @return Inverse logit of argument.
5050
*/
5151
inline double inv_logit(double a) {
52-
using std::exp;
5352
if (a < 0) {
54-
double exp_a = exp(a);
53+
double exp_a = std::exp(a);
5554
if (a < LOG_EPSILON) {
5655
return exp_a;
5756
}
58-
return exp_a / (1 + exp_a);
57+
return exp_a / (1.0 + exp_a);
5958
}
60-
return inv(1 + exp(-a));
59+
return inv(1 + std::exp(-a));
6160
}
6261

6362
/**
@@ -74,22 +73,31 @@ struct inv_logit_fun {
7473
}
7574
};
7675

76+
7777
/**
78-
* Vectorized version of inv_logit().
78+
* Vectorized version of inv_logit() for Eigen types with arithmetic value type.
7979
*
80-
* @tparam T type of container
81-
* @param x container
80+
* @tparam T type of Eigen expression
81+
* @param x Eigen expression
8282
* @return Inverse logit applied to each value in x.
8383
*/
84-
template <
85-
typename T, require_not_var_matrix_t<T>* = nullptr,
86-
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
87-
inline auto inv_logit(const T& x) {
88-
return apply_scalar_unary<inv_logit_fun, T>::apply(x);
84+
template <typename T, require_eigen_t<T>* = nullptr,
85+
require_not_vt_var<T>* = nullptr>
86+
inline auto inv_logit(T&& x) {
87+
return std::forward<T>(x).array().logistic().matrix();
8988
}
9089

91-
// TODO(Tadej): Eigen is introducing their implementation logistic() of this
92-
// in 3.4. Use that once we switch to Eigen 3.4
90+
/**
91+
* Vectorized version of inv_logit() for std::vector.
92+
*
93+
* @tparam T type of std::vector
94+
* @param x std::vector
95+
* @return Inverse logit applied to each value in x.
96+
*/
97+
template <typename T, require_std_vector_t<T>* = nullptr>
98+
inline auto inv_logit(T&& x) {
99+
return apply_scalar_unary<inv_logit_fun, std::decay_t<T>>::apply(std::forward<T>(x));
100+
}
93101

94102
} // namespace math
95103
} // namespace stan

stan/math/rev/fun/inv_logit.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,31 @@ inline auto inv_logit(const var_value<T>& a) {
3131
});
3232
}
3333

34+
/**
35+
* The inverse logit function for Eigen expressions with var value type.
36+
*
37+
* See inv_logit() for the double-based version.
38+
*
39+
* The derivative of inverse logit is
40+
*
41+
* \f$\frac{d}{dx} \mbox{logit}^{-1}(x) = \mbox{logit}^{-1}(x) (1 -
42+
* \mbox{logit}^{-1}(x))\f$.
43+
*
44+
* @tparam T type of Eigen expression
45+
* @param x Eigen expression
46+
* @return Inverse logit of argument.
47+
*/
48+
template <typename T, require_eigen_vt<is_var, T>* = nullptr>
49+
inline auto inv_logit(T&& x) {
50+
auto x_arena = to_arena(std::forward<T>(x));
51+
arena_t<T> ret = inv_logit(x_arena.val());
52+
reverse_pass_callback([x_arena, ret]() mutable {
53+
x_arena.adj().array()
54+
+= ret.adj().array() * ret.val().array() * (1.0 - ret.val().array());
55+
});
56+
return ret;
57+
}
58+
3459
} // namespace math
3560
} // namespace stan
3661
#endif

0 commit comments

Comments
 (0)