Skip to content

Commit 65fc8e6

Browse files
authored
Merge pull request #3114 from lingium/feature/issue-3113-beta-neg-binomial-lccdf
add beta_neg_binomial_lccdf
2 parents 2fdd3ed + f2bebaf commit 65fc8e6

6 files changed

Lines changed: 355 additions & 83 deletions

File tree

stan/math/prim/fun/grad_F32.hpp

Lines changed: 80 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,20 @@ namespace math {
2424
* This power-series representation converges for all gradients
2525
* under the same conditions as the 3F2 function itself.
2626
*
27-
* @tparam T type of arguments and result
27+
* @tparam grad_a1 boolean indicating if gradient with respect to a1 is required
28+
* @tparam grad_a2 boolean indicating if gradient with respect to a2 is required
29+
* @tparam grad_a3 boolean indicating if gradient with respect to a3 is required
30+
* @tparam grad_b1 boolean indicating if gradient with respect to b1 is required
31+
* @tparam grad_b2 boolean indicating if gradient with respect to b2 is required
32+
* @tparam grad_z boolean indicating if gradient with respect to z is required
33+
* @tparam T1 a scalar type
34+
* @tparam T2 a scalar type
35+
* @tparam T3 a scalar type
36+
* @tparam T4 a scalar type
37+
* @tparam T5 a scalar type
38+
* @tparam T6 a scalar type
39+
* @tparam T7 a scalar type
40+
* @tparam T8 a scalar type
2841
* @param[out] g g pointer to array of six values of type T, result.
2942
* @param[in] a1 a1 see generalized hypergeometric function definition.
3043
* @param[in] a2 a2 see generalized hypergeometric function definition.
@@ -35,84 +48,96 @@ namespace math {
3548
* @param[in] precision precision of the infinite sum
3649
* @param[in] max_steps number of steps to take
3750
*/
38-
template <typename T>
39-
void grad_F32(T* g, const T& a1, const T& a2, const T& a3, const T& b1,
40-
const T& b2, const T& z, const T& precision = 1e-6,
51+
template <bool grad_a1 = true, bool grad_a2 = true, bool grad_a3 = true,
52+
bool grad_b1 = true, bool grad_b2 = true, bool grad_z = true,
53+
typename T1, typename T2, typename T3, typename T4, typename T5,
54+
typename T6, typename T7, typename T8 = double>
55+
void grad_F32(T1* g, const T2& a1, const T3& a2, const T4& a3, const T5& b1,
56+
const T6& b2, const T7& z, const T8& precision = 1e-6,
4157
int max_steps = 1e5) {
4258
check_3F2_converges("grad_F32", a1, a2, a3, b1, b2, z);
4359

44-
using std::exp;
45-
using std::fabs;
46-
using std::log;
47-
4860
for (int i = 0; i < 6; ++i) {
4961
g[i] = 0.0;
5062
}
5163

52-
T log_g_old[6];
64+
T1 log_g_old[6];
5365
for (auto& x : log_g_old) {
5466
x = NEGATIVE_INFTY;
5567
}
5668

57-
T log_t_old = 0.0;
58-
T log_t_new = 0.0;
69+
T1 log_t_old = 0.0;
70+
T1 log_t_new = 0.0;
5971

60-
T log_z = log(z);
72+
T7 log_z = log(z);
6173

62-
double log_t_new_sign = 1.0;
63-
double log_t_old_sign = 1.0;
64-
double log_g_old_sign[6];
74+
T1 log_t_new_sign = 1.0;
75+
T1 log_t_old_sign = 1.0;
76+
T1 log_g_old_sign[6];
6577
for (int i = 0; i < 6; ++i) {
6678
log_g_old_sign[i] = 1.0;
6779
}
68-
80+
std::array<T1, 6> term{0};
6981
for (int k = 0; k <= max_steps; ++k) {
70-
T p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
82+
T1 p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
7183
if (p == 0) {
7284
return;
7385
}
7486

7587
log_t_new += log(fabs(p)) + log_z;
7688
log_t_new_sign = p >= 0.0 ? log_t_new_sign : -log_t_new_sign;
89+
if constexpr (grad_a1) {
90+
term[0]
91+
= log_g_old_sign[0] * log_t_old_sign * exp(log_g_old[0] - log_t_old)
92+
+ inv(a1 + k);
93+
log_g_old[0] = log_t_new + log(fabs(term[0]));
94+
log_g_old_sign[0] = term[0] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
95+
g[0] += log_g_old_sign[0] * exp(log_g_old[0]);
96+
}
97+
98+
if constexpr (grad_a2) {
99+
term[1]
100+
= log_g_old_sign[1] * log_t_old_sign * exp(log_g_old[1] - log_t_old)
101+
+ inv(a2 + k);
102+
log_g_old[1] = log_t_new + log(fabs(term[1]));
103+
log_g_old_sign[1] = term[1] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
104+
g[1] += log_g_old_sign[1] * exp(log_g_old[1]);
105+
}
106+
107+
if constexpr (grad_a3) {
108+
term[2]
109+
= log_g_old_sign[2] * log_t_old_sign * exp(log_g_old[2] - log_t_old)
110+
+ inv(a3 + k);
111+
log_g_old[2] = log_t_new + log(fabs(term[2]));
112+
log_g_old_sign[2] = term[2] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
113+
g[2] += log_g_old_sign[2] * exp(log_g_old[2]);
114+
}
115+
116+
if constexpr (grad_b1) {
117+
term[3]
118+
= log_g_old_sign[3] * log_t_old_sign * exp(log_g_old[3] - log_t_old)
119+
- inv(b1 + k);
120+
log_g_old[3] = log_t_new + log(fabs(term[3]));
121+
log_g_old_sign[3] = term[3] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
122+
g[3] += log_g_old_sign[3] * exp(log_g_old[3]);
123+
}
124+
125+
if constexpr (grad_b2) {
126+
term[4]
127+
= log_g_old_sign[4] * log_t_old_sign * exp(log_g_old[4] - log_t_old)
128+
- inv(b2 + k);
129+
log_g_old[4] = log_t_new + log(fabs(term[4]));
130+
log_g_old_sign[4] = term[4] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
131+
g[4] += log_g_old_sign[4] * exp(log_g_old[4]);
132+
}
77133

78-
// g_old[0] = t_new * (g_old[0] / t_old + 1.0 / (a1 + k));
79-
T term = log_g_old_sign[0] * log_t_old_sign * exp(log_g_old[0] - log_t_old)
80-
+ inv(a1 + k);
81-
log_g_old[0] = log_t_new + log(fabs(term));
82-
log_g_old_sign[0] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
83-
84-
// g_old[1] = t_new * (g_old[1] / t_old + 1.0 / (a2 + k));
85-
term = log_g_old_sign[1] * log_t_old_sign * exp(log_g_old[1] - log_t_old)
86-
+ inv(a2 + k);
87-
log_g_old[1] = log_t_new + log(fabs(term));
88-
log_g_old_sign[1] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
89-
90-
// g_old[2] = t_new * (g_old[2] / t_old + 1.0 / (a3 + k));
91-
term = log_g_old_sign[2] * log_t_old_sign * exp(log_g_old[2] - log_t_old)
92-
+ inv(a3 + k);
93-
log_g_old[2] = log_t_new + log(fabs(term));
94-
log_g_old_sign[2] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
95-
96-
// g_old[3] = t_new * (g_old[3] / t_old - 1.0 / (b1 + k));
97-
term = log_g_old_sign[3] * log_t_old_sign * exp(log_g_old[3] - log_t_old)
98-
- inv(b1 + k);
99-
log_g_old[3] = log_t_new + log(fabs(term));
100-
log_g_old_sign[3] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
101-
102-
// g_old[4] = t_new * (g_old[4] / t_old - 1.0 / (b2 + k));
103-
term = log_g_old_sign[4] * log_t_old_sign * exp(log_g_old[4] - log_t_old)
104-
- inv(b2 + k);
105-
log_g_old[4] = log_t_new + log(fabs(term));
106-
log_g_old_sign[4] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
107-
108-
// g_old[5] = t_new * (g_old[5] / t_old + 1.0 / z);
109-
term = log_g_old_sign[5] * log_t_old_sign * exp(log_g_old[5] - log_t_old)
110-
+ inv(z);
111-
log_g_old[5] = log_t_new + log(fabs(term));
112-
log_g_old_sign[5] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
113-
114-
for (int i = 0; i < 6; ++i) {
115-
g[i] += log_g_old_sign[i] * exp(log_g_old[i]);
134+
if constexpr (grad_z) {
135+
term[5]
136+
= log_g_old_sign[5] * log_t_old_sign * exp(log_g_old[5] - log_t_old)
137+
+ inv(z);
138+
log_g_old[5] = log_t_new + log(fabs(term[5]));
139+
log_g_old_sign[5] = term[5] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
140+
g[5] += log_g_old_sign[5] * exp(log_g_old[5]);
116141
}
117142

118143
if (log_t_new <= log(precision)) {

stan/math/prim/fun/grad_pFq.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,11 @@ template <bool calc_a = true, bool calc_b = true, bool calc_z = true,
8989
typename T_Rtn = return_type_t<Ta, Tb, Tz>,
9090
typename Ta_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Ta>>,
9191
typename Tb_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Tb>>>
92-
std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> grad_pFq(const TpFq& pfq_val, const Ta& a,
93-
const Tb& b, const Tz& z,
94-
double precision = 1e-14,
95-
int max_steps = 1e6) {
92+
inline std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> grad_pFq(const TpFq& pfq_val,
93+
const Ta& a, const Tb& b,
94+
const Tz& z,
95+
double precision = 1e-14,
96+
int max_steps = 1e6) {
9697
using std::max;
9798
using Ta_Array = Eigen::Array<return_type_t<Ta>, -1, 1>;
9899
using Tb_Array = Eigen::Array<return_type_t<Tb>, -1, 1>;

stan/math/prim/fun/hypergeometric_3F2.hpp

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,34 @@ namespace stan {
1717
namespace math {
1818
namespace internal {
1919
template <typename Ta, typename Tb, typename Tz,
20-
typename T_return = return_type_t<Ta, Tb, Tz>,
21-
typename ArrayAT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
22-
typename ArrayBT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
2320
require_all_vector_t<Ta, Tb>* = nullptr,
2421
require_stan_scalar_t<Tz>* = nullptr>
25-
T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
26-
double precision = 1e-6,
27-
int max_steps = 1e5) {
28-
ArrayAT a_array = as_array_or_scalar(a);
29-
ArrayBT b_array = append_row(as_array_or_scalar(b), 1.0);
22+
inline return_type_t<Ta, Tb, Tz> hypergeometric_3F2_infsum(
23+
const Ta& a, const Tb& b, const Tz& z, double precision = 1e-6,
24+
int max_steps = 1e5) {
25+
using T_return = return_type_t<Ta, Tb, Tz>;
26+
Eigen::Array<scalar_type_t<Ta>, 3, 1> a_array = as_array_or_scalar(a);
27+
Eigen::Array<scalar_type_t<Tb>, 3, 1> b_array
28+
= append_row(as_array_or_scalar(b), 1.0);
3029
check_3F2_converges("hypergeometric_3F2", a_array[0], a_array[1], a_array[2],
3130
b_array[0], b_array[1], z);
3231

3332
T_return t_acc = 1.0;
3433
T_return log_t = 0.0;
35-
T_return log_z = log(fabs(z));
36-
Eigen::ArrayXi a_signs = sign(value_of_rec(a_array));
37-
Eigen::ArrayXi b_signs = sign(value_of_rec(b_array));
38-
plain_type_t<decltype(a_array)> apk = a_array;
39-
plain_type_t<decltype(b_array)> bpk = b_array;
34+
auto log_z = log(fabs(z));
35+
Eigen::Array<int, 3, 1> a_signs = sign(value_of_rec(a_array));
36+
Eigen::Array<int, 3, 1> b_signs = sign(value_of_rec(b_array));
4037
int z_sign = sign(value_of_rec(z));
4138
int t_sign = z_sign * a_signs.prod() * b_signs.prod();
4239

4340
int k = 0;
44-
while (k <= max_steps && log_t >= log(precision)) {
41+
const double log_precision = log(precision);
42+
while (k <= max_steps && log_t >= log_precision) {
4543
// Replace zero values with 1 prior to taking the log so that we accumulate
4644
// 0.0 rather than -inf
47-
const auto& abs_apk = math::fabs((apk == 0).select(1.0, apk));
48-
const auto& abs_bpk = math::fabs((bpk == 0).select(1.0, bpk));
49-
T_return p = sum(log(abs_apk)) - sum(log(abs_bpk));
45+
const auto& abs_apk = math::fabs((a_array == 0).select(1.0, a_array));
46+
const auto& abs_bpk = math::fabs((b_array == 0).select(1.0, b_array));
47+
auto p = sum(log(abs_apk)) - sum(log(abs_bpk));
5048
if (p == NEGATIVE_INFTY) {
5149
return t_acc;
5250
}
@@ -59,10 +57,10 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
5957
"overflow hypergeometric function did not converge.");
6058
}
6159
k++;
62-
apk.array() += 1.0;
63-
bpk.array() += 1.0;
64-
a_signs = sign(value_of_rec(apk));
65-
b_signs = sign(value_of_rec(bpk));
60+
a_array += 1.0;
61+
b_array += 1.0;
62+
a_signs = sign(value_of_rec(a_array));
63+
b_signs = sign(value_of_rec(b_array));
6664
t_sign = a_signs.prod() * b_signs.prod() * t_sign;
6765
}
6866
if (k == max_steps) {
@@ -115,7 +113,7 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
115113
template <typename Ta, typename Tb, typename Tz,
116114
require_all_vector_t<Ta, Tb>* = nullptr,
117115
require_stan_scalar_t<Tz>* = nullptr>
118-
auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
116+
inline auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
119117
check_3F2_converges("hypergeometric_3F2", a[0], a[1], a[2], b[0], b[1], z);
120118
// Boost's pFq throws convergence errors in some cases, fallback to naive
121119
// infinite-sum approach (tests pass for these)
@@ -143,8 +141,9 @@ auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
143141
*/
144142
template <typename Ta, typename Tb, typename Tz,
145143
require_all_stan_scalar_t<Ta, Tb, Tz>* = nullptr>
146-
auto hypergeometric_3F2(const std::initializer_list<Ta>& a,
147-
const std::initializer_list<Tb>& b, const Tz& z) {
144+
inline auto hypergeometric_3F2(const std::initializer_list<Ta>& a,
145+
const std::initializer_list<Tb>& b,
146+
const Tz& z) {
148147
return hypergeometric_3F2(std::vector<Ta>(a), std::vector<Tb>(b), z);
149148
}
150149

stan/math/prim/prob.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <stan/math/prim/prob/beta_lccdf.hpp>
2626
#include <stan/math/prim/prob/beta_lcdf.hpp>
2727
#include <stan/math/prim/prob/beta_lpdf.hpp>
28+
#include <stan/math/prim/prob/beta_neg_binomial_lccdf.hpp>
2829
#include <stan/math/prim/prob/beta_neg_binomial_lpmf.hpp>
2930
#include <stan/math/prim/prob/beta_proportion_ccdf_log.hpp>
3031
#include <stan/math/prim/prob/beta_proportion_cdf_log.hpp>

0 commit comments

Comments
 (0)