Skip to content

Commit 0a0831d

Browse files
authored
Merge pull request #3124 from stan-dev/fix/kronecker_scalar
use eigen internal traits for getting the scalar type of a eigen type
2 parents 7ada875 + 283f7bf commit 0a0831d

3 files changed

Lines changed: 76 additions & 10 deletions

File tree

stan/math/prim/meta/is_eigen.hpp

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,28 @@ template <typename T>
2020
struct is_eigen
2121
: bool_constant<is_base_pointer_convertible<Eigen::EigenBase, T>::value> {};
2222

23+
namespace internal {
24+
// primary template handles types that have no nested ::type member:
25+
template <class, class = void>
26+
struct has_internal_trait : std::false_type {};
27+
28+
// specialization recognizes types that do have a nested ::type member:
29+
template <class T>
30+
struct has_internal_trait<T,
31+
std::void_t<Eigen::internal::traits<std::decay_t<T>>>>
32+
: std::true_type {};
33+
34+
// primary template handles types that have no nested ::type member:
35+
template <class, class = void>
36+
struct has_scalar_trait : std::false_type {};
37+
38+
// specialization recognizes types that do have a nested ::type member:
39+
template <class T>
40+
struct has_scalar_trait<T, std::void_t<typename std::decay_t<T>::Scalar>>
41+
: std::true_type {};
42+
43+
} // namespace internal
44+
2345
/**
2446
* Template metaprogram defining the base scalar type of
2547
* values stored in an Eigen matrix.
@@ -28,7 +50,9 @@ struct is_eigen
2850
* @ingroup type_trait
2951
*/
3052
template <typename T>
31-
struct scalar_type<T, std::enable_if_t<is_eigen<T>::value>> {
53+
struct scalar_type<T,
54+
std::enable_if_t<is_eigen<T>::value
55+
&& internal::has_scalar_trait<T>::value>> {
3256
using type = scalar_type_t<typename std::decay_t<T>::Scalar>;
3357
};
3458

@@ -40,10 +64,41 @@ struct scalar_type<T, std::enable_if_t<is_eigen<T>::value>> {
4064
* @ingroup type_trait
4165
*/
4266
template <typename T>
43-
struct value_type<T, std::enable_if_t<is_eigen<T>::value>> {
67+
struct value_type<T,
68+
std::enable_if_t<is_eigen<T>::value
69+
&& internal::has_scalar_trait<T>::value>> {
4470
using type = typename std::decay_t<T>::Scalar;
4571
};
4672

73+
/**
74+
* Template metaprogram defining the base scalar type of
75+
* values stored in an Eigen matrix.
76+
*
77+
* @tparam T type to check.
78+
* @ingroup type_trait
79+
*/
80+
template <typename T>
81+
struct scalar_type<T,
82+
std::enable_if_t<is_eigen<T>::value
83+
&& !internal::has_scalar_trait<T>::value>> {
84+
using type = scalar_type_t<
85+
typename Eigen::internal::traits<std::decay_t<T>>::Scalar>;
86+
};
87+
88+
/**
89+
* Template metaprogram defining the type of values stored in an
90+
* Eigen matrix, vector, or row vector.
91+
*
92+
* @tparam T type to check
93+
* @ingroup type_trait
94+
*/
95+
template <typename T>
96+
struct value_type<T,
97+
std::enable_if_t<is_eigen<T>::value
98+
&& !internal::has_scalar_trait<T>::value>> {
99+
using type = typename Eigen::internal::traits<std::decay_t<T>>::Scalar;
100+
};
101+
47102
/*! \ingroup require_eigens_types */
48103
/*! \defgroup eigen_types eigen */
49104
/*! \addtogroup eigen_types */

stan/math/rev/core/arena_matrix.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ namespace internal {
482482
template <typename T>
483483
struct traits<stan::math::arena_matrix<T>> {
484484
using base = traits<Eigen::Map<T>>;
485+
using Scalar = typename base::Scalar;
485486
using XprKind = typename Eigen::internal::traits<std::decay_t<T>>::XprKind;
486487
enum {
487488
PlainObjectTypeInnerSize = base::PlainObjectTypeInnerSize,

test/unit/math/prim/meta/value_type_test.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
#include <stan/math/prim/meta.hpp>
21
#include <test/unit/util.hpp>
2+
#include <stan/math/prim/meta.hpp>
3+
#include <unsupported/Eigen/KroneckerProduct>
34
#include <gtest/gtest.h>
45
#include <vector>
56

@@ -8,16 +9,16 @@ TEST(MathMetaPrim, value_type_vector) {
89
using std::vector;
910

1011
EXPECT_SAME_TYPE(vector<double>::value_type,
11-
value_type<vector<double> >::type);
12+
value_type<vector<double>>::type);
1213

1314
EXPECT_SAME_TYPE(vector<double>::value_type,
14-
value_type<const vector<double> >::type);
15+
value_type<const vector<double>>::type);
1516

16-
EXPECT_SAME_TYPE(vector<vector<int> >::value_type,
17-
value_type<vector<vector<int> > >::type);
17+
EXPECT_SAME_TYPE(vector<vector<int>>::value_type,
18+
value_type<vector<vector<int>>>::type);
1819

19-
EXPECT_SAME_TYPE(vector<vector<int> >::value_type,
20-
value_type<const vector<vector<int> > >::type);
20+
EXPECT_SAME_TYPE(vector<vector<int>>::value_type,
21+
value_type<const vector<vector<int>>>::type);
2122
}
2223

2324
TEST(MathMetaPrim, value_type_matrix) {
@@ -33,5 +34,14 @@ TEST(MathMetaPrim, value_type_matrix) {
3334
value_type<Eigen::RowVectorXd>::type);
3435

3536
EXPECT_SAME_TYPE(Eigen::RowVectorXd,
36-
value_type<std::vector<Eigen::RowVectorXd> >::type);
37+
value_type<std::vector<Eigen::RowVectorXd>>::type);
38+
}
39+
40+
TEST(MathMetaPrim, value_type_kronecker) {
41+
Eigen::Matrix<double, 2, 2> A;
42+
const auto B
43+
= Eigen::kroneckerProduct(A, Eigen::Matrix<double, 2, 2>::Identity());
44+
Eigen::Matrix<double, 4, 1> C = Eigen::Matrix<double, 4, 1>::Random(4, 1);
45+
EXPECT_TRUE((std::is_same<double, stan::value_type_t<decltype(B)>>::value));
46+
Eigen::MatrixXd D = B * C;
3747
}

0 commit comments

Comments
 (0)