Skip to content

Commit 62094ed

Browse files
committed
breakup value_type and scalar_type for Eigen types based on whether they have a Scalar type trait
1 parent 3f065fa commit 62094ed

2 files changed

Lines changed: 51 additions & 2 deletions

File tree

stan/math/prim/meta/is_eigen.hpp

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,28 @@ template <typename T>
2020
struct is_eigen
2121
: bool_constant<is_base_pointer_convertible<Eigen::EigenBase, T>::value> {};
2222

23+
namespace internal {
24+
// primary template handles types that have no nested ::type member:
25+
template <class, class = void>
26+
struct has_internal_trait : std::false_type {};
27+
28+
// specialization recognizes types that do have a nested ::type member:
29+
template <class T>
30+
struct has_internal_trait<T, std::void_t<Eigen::internal::traits<std::decay_t<T>>>>
31+
: std::true_type {};
32+
33+
// primary template handles types that have no nested ::type member:
34+
template <class, class = void>
35+
struct has_scalar_trait : std::false_type {};
36+
37+
// specialization recognizes types that do have a nested ::type member:
38+
template <class T>
39+
struct has_scalar_trait<T, std::void_t<typename std::decay_t<T>::Scalar>>
40+
: std::true_type {};
41+
42+
} // namespace internal
43+
44+
2345
/**
2446
* Template metaprogram defining the base scalar type of
2547
* values stored in an Eigen matrix.
@@ -28,7 +50,7 @@ struct is_eigen
2850
* @ingroup type_trait
2951
*/
3052
template <typename T>
31-
struct scalar_type<T, std::enable_if_t<is_eigen<T>::value>> {
53+
struct scalar_type<T, std::enable_if_t<is_eigen<T>::value && internal::has_scalar_trait<T>::value>> {
3254
using type = scalar_type_t<typename std::decay_t<T>::Scalar>;
3355
};
3456

@@ -40,10 +62,36 @@ struct scalar_type<T, std::enable_if_t<is_eigen<T>::value>> {
4062
* @ingroup type_trait
4163
*/
4264
template <typename T>
43-
struct value_type<T, std::enable_if_t<is_eigen<T>::value>> {
65+
struct value_type<T, std::enable_if_t<is_eigen<T>::value && internal::has_scalar_trait<T>::value>> {
66+
using type = typename std::decay_t<T>::Scalar;
67+
};
68+
69+
70+
/**
71+
* Template metaprogram defining the base scalar type of
72+
* values stored in an Eigen matrix.
73+
*
74+
* @tparam T type to check.
75+
* @ingroup type_trait
76+
*/
77+
template <typename T>
78+
struct scalar_type<T, std::enable_if_t<is_eigen<T>::value && !internal::has_scalar_trait<T>::value>> {
79+
using type = scalar_type_t<typename Eigen::internal::traits<std::decay_t<T>>::Scalar>;
80+
};
81+
82+
/**
83+
* Template metaprogram defining the type of values stored in an
84+
* Eigen matrix, vector, or row vector.
85+
*
86+
* @tparam T type to check
87+
* @ingroup type_trait
88+
*/
89+
template <typename T>
90+
struct value_type<T, std::enable_if_t<is_eigen<T>::value && !internal::has_scalar_trait<T>::value>> {
4491
using type = typename Eigen::internal::traits<std::decay_t<T>>::Scalar;
4592
};
4693

94+
4795
/*! \ingroup require_eigens_types */
4896
/*! \defgroup eigen_types eigen */
4997
/*! \addtogroup eigen_types */

stan/math/rev/core/arena_matrix.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ namespace internal {
482482
template <typename T>
483483
struct traits<stan::math::arena_matrix<T>> {
484484
using base = traits<Eigen::Map<T>>;
485+
using Scalar = typename base::Scalar;
485486
using XprKind = typename Eigen::internal::traits<std::decay_t<T>>::XprKind;
486487
enum {
487488
PlainObjectTypeInnerSize = base::PlainObjectTypeInnerSize,

0 commit comments

Comments
 (0)