Skip to content

Commit 87bf545

Browse files
committed
Allow Eigen to use the apply_scalar_unary framework
1 parent 3703953 commit 87bf545

2 files changed

Lines changed: 30 additions & 27 deletions

File tree

stan/math/prim/fun/inv_logit.hpp

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -68,38 +68,41 @@ inline double inv_logit(double a) {
6868
*/
6969
struct inv_logit_fun {
7070
template <typename T>
71-
static inline auto fun(const T& x) {
72-
return inv_logit(x);
71+
static inline auto fun(T&& x) {
72+
return inv_logit(std::forward<T>(x));
7373
}
7474
};
7575

7676
/**
77-
* Vectorized version of inv_logit() for Eigen types.
77+
* Vectorized version of inv_logit() for std::vector's containing ad types.
7878
*
79-
* @tparam T A type inheriting from `Eigen::DenseBase` that does not have a
80-
* `var` scalar type.
81-
* @param x Eigen expression
79+
* @tparam T type of std::vector
80+
* @param x std::vector
8281
* @return Inverse logit applied to each value in x.
8382
*/
84-
template <typename T, require_eigen_t<T>* = nullptr,
85-
require_not_vt_var<T>* = nullptr>
86-
inline auto inv_logit(T&& x) {
87-
return std::forward<T>(x).array().logistic().matrix();
83+
template <typename Container, require_ad_container_t<Container>* = nullptr,
84+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<Container>* = nullptr,
85+
require_not_rev_matrix_t<Container>* = nullptr>
86+
inline auto inv_logit(Container&& x) {
87+
return apply_scalar_unary<inv_logit_fun, Container>::apply(std::forward<Container>(x));
8888
}
8989

9090
/**
91-
* Vectorized version of inv_logit() for std::vector.
91+
* Vectorized version of inv_logit() for Eigen types.
9292
*
93-
* @tparam T type of std::vector
94-
* @param x std::vector
93+
* @tparam T A type of either `std::vector` whose inner type inherits from `Eigen::DenseBase` or a
94+
* type that directly inherits from `Eigen::DenseBase`. The inner scalar type must not have a
95+
* `var` scalar type.
96+
* @param x Eigen expression
9597
* @return Inverse logit applied to each value in x.
9698
*/
97-
template <typename T, require_std_vector_t<T>* = nullptr>
98-
inline auto inv_logit(T&& x) {
99-
return apply_scalar_unary<inv_logit_fun, std::decay_t<T>>::apply(
100-
std::forward<T>(x));
99+
template <typename Container,
100+
require_container_bt<std::is_arithmetic, Container>* = nullptr,
101+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<Container>* = nullptr>
102+
inline auto inv_logit(Container&& x) {
103+
return apply_vector_unary<Container>::apply(
104+
std::forward<Container>(x), [](const auto& v) { return v.array().logistic(); });
101105
}
102-
103106
} // namespace math
104107
} // namespace stan
105108

stan/math/prim/functor/apply_scalar_unary.hpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ struct apply_scalar_unary<F, T, require_eigen_t<T>> {
5858
* @return Componentwise application of the function specified
5959
* by F to the specified matrix.
6060
*/
61-
static inline auto apply(const T& x) {
61+
static inline auto apply(const std::decay_t<T>& x) {
6262
return x.unaryExpr([](auto&& x) {
6363
return apply_scalar_unary<F, std::decay_t<decltype(x)>>::apply(x);
6464
});
@@ -69,7 +69,7 @@ struct apply_scalar_unary<F, T, require_eigen_t<T>> {
6969
* expression template of type T.
7070
*/
7171
using return_t = std::decay_t<decltype(
72-
apply_scalar_unary<F, T>::apply(std::declval<T>()))>;
72+
apply_scalar_unary<F, std::decay_t<T>>::apply(std::declval<T>()))>;
7373
};
7474

7575
/**
@@ -83,7 +83,7 @@ struct apply_scalar_unary<F, T, require_floating_point_t<T>> {
8383
/**
8484
* The return type, double.
8585
*/
86-
using return_t = std::decay_t<decltype(F::fun(std::declval<T>()))>;
86+
using return_t = std::decay_t<decltype(F::fun(std::declval<std::decay_t<T>>()))>;
8787

8888
/**
8989
* Apply the function specified by F to the specified argument.
@@ -114,11 +114,11 @@ struct apply_scalar_unary<F, T, require_complex_t<T>> {
114114
* @param x Argument scalar.
115115
* @return Result of applying F to the scalar.
116116
*/
117-
static inline auto apply(const T& x) { return F::fun(x); }
117+
static inline auto apply(const std::decay_t<T>& x) { return F::fun(x); }
118118
/**
119119
* The return type
120120
*/
121-
using return_t = std::decay_t<decltype(F::fun(std::declval<T>()))>;
121+
using return_t = std::decay_t<decltype(F::fun(std::declval<std::decay_t<T>>()))>;
122122
};
123123

124124
/**
@@ -157,13 +157,13 @@ struct apply_scalar_unary<F, T, require_integral_t<T>> {
157157
* @tparam T Type of element contained in standard vector.
158158
*/
159159
template <typename F, typename T>
160-
struct apply_scalar_unary<F, std::vector<T>> {
160+
struct apply_scalar_unary<F, T, require_std_vector_t<T>> {
161161
/**
162162
* Return type, which is calculated recursively as a standard
163163
* vector of the return type of the contained type T.
164164
*/
165165
using return_t = typename std::vector<
166-
plain_type_t<typename apply_scalar_unary<F, T>::return_t>>;
166+
plain_type_t<typename apply_scalar_unary<F, value_type_t<std::decay_t<T>>>::return_t>>;
167167

168168
/**
169169
* Apply the function specified by F elementwise to the
@@ -174,10 +174,10 @@ struct apply_scalar_unary<F, std::vector<T>> {
174174
* @return Elementwise application of F to the elements of the
175175
* container.
176176
*/
177-
static inline auto apply(const std::vector<T>& x) {
177+
static inline auto apply(const std::decay_t<T>& x) {
178178
return_t fx(x.size());
179179
for (size_t i = 0; i < x.size(); ++i) {
180-
fx[i] = apply_scalar_unary<F, T>::apply(x[i]);
180+
fx[i] = apply_scalar_unary<F, value_type_t<T>>::apply(x[i]);
181181
}
182182
return fx;
183183
}

0 commit comments

Comments
 (0)