Skip to content

Commit ca1c0cb

Browse files
authored
Merge branch 'stan-dev:develop' into feature/issue-3121-beta-neg-binomial-rng
2 parents 949fddd + 4a812be commit ca1c0cb

4 files changed

Lines changed: 91 additions & 22 deletions

File tree

stan/math/prim/meta/is_eigen.hpp

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,28 @@ template <typename T>
2020
struct is_eigen
2121
: bool_constant<is_base_pointer_convertible<Eigen::EigenBase, T>::value> {};
2222

23+
namespace internal {
24+
// primary template handles types that have no nested ::type member:
25+
template <class, class = void>
26+
struct has_internal_trait : std::false_type {};
27+
28+
// specialization recognizes types that do have a nested ::type member:
29+
template <class T>
30+
struct has_internal_trait<T,
31+
std::void_t<Eigen::internal::traits<std::decay_t<T>>>>
32+
: std::true_type {};
33+
34+
// primary template handles types that have no nested ::type member:
35+
template <class, class = void>
36+
struct has_scalar_trait : std::false_type {};
37+
38+
// specialization recognizes types that do have a nested ::type member:
39+
template <class T>
40+
struct has_scalar_trait<T, std::void_t<typename std::decay_t<T>::Scalar>>
41+
: std::true_type {};
42+
43+
} // namespace internal
44+
2345
/**
2446
* Template metaprogram defining the base scalar type of
2547
* values stored in an Eigen matrix.
@@ -28,7 +50,9 @@ struct is_eigen
2850
* @ingroup type_trait
2951
*/
3052
template <typename T>
31-
struct scalar_type<T, std::enable_if_t<is_eigen<T>::value>> {
53+
struct scalar_type<T,
54+
std::enable_if_t<is_eigen<T>::value
55+
&& internal::has_scalar_trait<T>::value>> {
3256
using type = scalar_type_t<typename std::decay_t<T>::Scalar>;
3357
};
3458

@@ -40,10 +64,41 @@ struct scalar_type<T, std::enable_if_t<is_eigen<T>::value>> {
4064
* @ingroup type_trait
4165
*/
4266
template <typename T>
43-
struct value_type<T, std::enable_if_t<is_eigen<T>::value>> {
67+
struct value_type<T,
68+
std::enable_if_t<is_eigen<T>::value
69+
&& internal::has_scalar_trait<T>::value>> {
4470
using type = typename std::decay_t<T>::Scalar;
4571
};
4672

73+
/**
74+
* Template metaprogram defining the base scalar type of
75+
* values stored in an Eigen matrix.
76+
*
77+
* @tparam T type to check.
78+
* @ingroup type_trait
79+
*/
80+
template <typename T>
81+
struct scalar_type<T,
82+
std::enable_if_t<is_eigen<T>::value
83+
&& !internal::has_scalar_trait<T>::value>> {
84+
using type = scalar_type_t<
85+
typename Eigen::internal::traits<std::decay_t<T>>::Scalar>;
86+
};
87+
88+
/**
89+
* Template metaprogram defining the type of values stored in an
90+
* Eigen matrix, vector, or row vector.
91+
*
92+
* @tparam T type to check
93+
* @ingroup type_trait
94+
*/
95+
template <typename T>
96+
struct value_type<T,
97+
std::enable_if_t<is_eigen<T>::value
98+
&& !internal::has_scalar_trait<T>::value>> {
99+
using type = typename Eigen::internal::traits<std::decay_t<T>>::Scalar;
100+
};
101+
47102
/*! \ingroup require_eigens_types */
48103
/*! \defgroup eigen_types eigen */
49104
/*! \addtogroup eigen_types */

stan/math/prim/prob/beta_neg_binomial_cdf.hpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_cdf(
101101
1.0);
102102
auto C = lgamma(r_plus_n + 1.0) + lbeta(a_plus_r, b_plus_n + 1.0)
103103
- lgamma(r_dbl) - lbeta(alpha_dbl, beta_dbl) - lgamma(n_dbl + 2.0);
104-
auto ccdf = stan::math::exp(C) * F;
104+
auto ccdf = stan::math::exp(C + stan::math::log(F));
105105
cdf *= 1.0 - ccdf;
106106

107107
if constexpr (!is_constant_all<T_r, T_alpha, T_beta>::value) {
@@ -116,15 +116,17 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_cdf(
116116
if constexpr (!is_constant<T_r>::value || !is_constant<T_alpha>::value) {
117117
auto digamma_r_alpha = digamma(a_plus_r);
118118
if constexpr (!is_constant<T_r>::value) {
119-
auto partial_lccdf = digamma(r_plus_n + 1.0)
120-
+ (digamma_r_alpha - digamma_n_r_alpha_beta)
121-
+ (dF[2] + dF[4]) / F - digamma(r_dbl);
122-
partials<0>(ops_partials)[i] += partial_lccdf * chain_rule_term;
119+
partials<0>(ops_partials)[i]
120+
+= (digamma(r_plus_n + 1)
121+
+ (digamma_r_alpha - digamma_n_r_alpha_beta)
122+
+ (dF[2] + dF[4]) / F - digamma(r_dbl))
123+
* chain_rule_term;
123124
}
124125
if constexpr (!is_constant<T_alpha>::value) {
125-
auto partial_lccdf = digamma_r_alpha - digamma_n_r_alpha_beta
126-
+ dF[4] / F - digamma(alpha_dbl);
127-
partials<1>(ops_partials)[i] += partial_lccdf * chain_rule_term;
126+
partials<1>(ops_partials)[i]
127+
+= (digamma_r_alpha - digamma_n_r_alpha_beta + dF[4] / F
128+
- digamma(alpha_dbl))
129+
* chain_rule_term;
128130
}
129131
}
130132

@@ -135,10 +137,11 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_cdf(
135137
partials<1>(ops_partials)[i] += digamma_alpha_beta * chain_rule_term;
136138
}
137139
if constexpr (!is_constant<T_beta>::value) {
138-
auto partial_lccdf = digamma(b_plus_n + 1.0) - digamma_n_r_alpha_beta
139-
+ (dF[1] + dF[4]) / F
140-
- (digamma(beta_dbl) - digamma_alpha_beta);
141-
partials<2>(ops_partials)[i] += partial_lccdf * chain_rule_term;
140+
partials<2>(ops_partials)[i]
141+
+= (digamma(b_plus_n + 1) - digamma_n_r_alpha_beta
142+
+ (dF[1] + dF[4]) / F
143+
- (digamma(beta_dbl) - digamma_alpha_beta))
144+
* chain_rule_term;
142145
}
143146
}
144147
}

stan/math/rev/core/arena_matrix.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ namespace internal {
482482
template <typename T>
483483
struct traits<stan::math::arena_matrix<T>> {
484484
using base = traits<Eigen::Map<T>>;
485+
using Scalar = typename base::Scalar;
485486
using XprKind = typename Eigen::internal::traits<std::decay_t<T>>::XprKind;
486487
enum {
487488
PlainObjectTypeInnerSize = base::PlainObjectTypeInnerSize,

test/unit/math/prim/meta/value_type_test.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
#include <stan/math/prim/meta.hpp>
21
#include <test/unit/util.hpp>
2+
#include <stan/math/prim/meta.hpp>
3+
#include <unsupported/Eigen/KroneckerProduct>
34
#include <gtest/gtest.h>
45
#include <vector>
56

@@ -8,16 +9,16 @@ TEST(MathMetaPrim, value_type_vector) {
89
using std::vector;
910

1011
EXPECT_SAME_TYPE(vector<double>::value_type,
11-
value_type<vector<double> >::type);
12+
value_type<vector<double>>::type);
1213

1314
EXPECT_SAME_TYPE(vector<double>::value_type,
14-
value_type<const vector<double> >::type);
15+
value_type<const vector<double>>::type);
1516

16-
EXPECT_SAME_TYPE(vector<vector<int> >::value_type,
17-
value_type<vector<vector<int> > >::type);
17+
EXPECT_SAME_TYPE(vector<vector<int>>::value_type,
18+
value_type<vector<vector<int>>>::type);
1819

19-
EXPECT_SAME_TYPE(vector<vector<int> >::value_type,
20-
value_type<const vector<vector<int> > >::type);
20+
EXPECT_SAME_TYPE(vector<vector<int>>::value_type,
21+
value_type<const vector<vector<int>>>::type);
2122
}
2223

2324
TEST(MathMetaPrim, value_type_matrix) {
@@ -33,5 +34,14 @@ TEST(MathMetaPrim, value_type_matrix) {
3334
value_type<Eigen::RowVectorXd>::type);
3435

3536
EXPECT_SAME_TYPE(Eigen::RowVectorXd,
36-
value_type<std::vector<Eigen::RowVectorXd> >::type);
37+
value_type<std::vector<Eigen::RowVectorXd>>::type);
38+
}
39+
40+
TEST(MathMetaPrim, value_type_kronecker) {
41+
Eigen::Matrix<double, 2, 2> A;
42+
const auto B
43+
= Eigen::kroneckerProduct(A, Eigen::Matrix<double, 2, 2>::Identity());
44+
Eigen::Matrix<double, 4, 1> C = Eigen::Matrix<double, 4, 1>::Random(4, 1);
45+
EXPECT_TRUE((std::is_same<double, stan::value_type_t<decltype(B)>>::value));
46+
Eigen::MatrixXd D = B * C;
3747
}

0 commit comments

Comments
 (0)