Skip to content

Commit 4a812be

Browse files
authored
Merge pull request #3120 from lingium/feature/issue-3119-beta-neg-binomial-cdf
add beta_neg_binomial_cdf
2 parents 0a0831d + 9750889 commit 4a812be

3 files changed

Lines changed: 263 additions & 0 deletions

File tree

stan/math/prim/prob.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <stan/math/prim/prob/beta_lccdf.hpp>
2626
#include <stan/math/prim/prob/beta_lcdf.hpp>
2727
#include <stan/math/prim/prob/beta_lpdf.hpp>
28+
#include <stan/math/prim/prob/beta_neg_binomial_cdf.hpp>
2829
#include <stan/math/prim/prob/beta_neg_binomial_lccdf.hpp>
2930
#include <stan/math/prim/prob/beta_neg_binomial_lcdf.hpp>
3031
#include <stan/math/prim/prob/beta_neg_binomial_lpmf.hpp>
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
#ifndef STAN_MATH_PRIM_PROB_BETA_NEG_BINOMIAL_CDF_HPP
2+
#define STAN_MATH_PRIM_PROB_BETA_NEG_BINOMIAL_CDF_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/constants.hpp>
7+
#include <stan/math/prim/fun/digamma.hpp>
8+
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
9+
#include <stan/math/prim/fun/grad_F32.hpp>
10+
#include <stan/math/prim/fun/lbeta.hpp>
11+
#include <stan/math/prim/fun/lgamma.hpp>
12+
#include <stan/math/prim/fun/max_size.hpp>
13+
#include <stan/math/prim/fun/scalar_seq_view.hpp>
14+
#include <stan/math/prim/fun/size.hpp>
15+
#include <stan/math/prim/fun/size_zero.hpp>
16+
#include <stan/math/prim/functor/partials_propagator.hpp>
17+
#include <cmath>
#include <initializer_list>
#include <limits>
18+
19+
namespace stan {
20+
namespace math {
21+
22+
/** \ingroup prob_dists
23+
* Returns the CDF of the Beta-Negative Binomial distribution with given
24+
* number of successes, prior success, and prior failure parameters.
25+
* Given containers of matching sizes, returns the product of probabilities.
26+
*
27+
* @tparam T_n type of failure parameter
28+
* @tparam T_r type of number of successes parameter
29+
* @tparam T_alpha type of prior success parameter
30+
* @tparam T_beta type of prior failure parameter
31+
*
32+
* @param n failure parameter
33+
* @param r Number of successes parameter
34+
* @param alpha prior success parameter
35+
* @param beta prior failure parameter
36+
* @param precision precision for `grad_F32`, default \f$10^{-8}\f$
37+
* @param max_steps max iteration allowed for `grad_F32`, default \f$10^{8}\f$
38+
* @return probability or sum of probabilities
39+
* @throw std::domain_error if r, alpha, or beta fails to be positive
40+
* @throw std::invalid_argument if container sizes mismatch
41+
*/
42+
template <typename T_n, typename T_r, typename T_alpha, typename T_beta>
43+
inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_cdf(
44+
const T_n& n, const T_r& r, const T_alpha& alpha, const T_beta& beta,
45+
const double precision = 1e-8, const int max_steps = 1e8) {
46+
static constexpr const char* function = "beta_neg_binomial_cdf";
47+
check_consistent_sizes(
48+
function, "Failures variable", n, "Number of successes parameter", r,
49+
"Prior success parameter", alpha, "Prior failure parameter", beta);
50+
if (size_zero(n, r, alpha, beta)) {
51+
return 1.0;
52+
}
53+
54+
using T_r_ref = ref_type_t<T_r>;
55+
T_r_ref r_ref = r;
56+
using T_alpha_ref = ref_type_t<T_alpha>;
57+
T_alpha_ref alpha_ref = alpha;
58+
using T_beta_ref = ref_type_t<T_beta>;
59+
T_beta_ref beta_ref = beta;
60+
check_positive_finite(function, "Number of successes parameter", r_ref);
61+
check_positive_finite(function, "Prior success parameter", alpha_ref);
62+
check_positive_finite(function, "Prior failure parameter", beta_ref);
63+
64+
scalar_seq_view<T_n> n_vec(n);
65+
scalar_seq_view<T_r_ref> r_vec(r_ref);
66+
scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref);
67+
scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
68+
int size_n = stan::math::size(n);
69+
size_t max_size_seq_view = max_size(n, r, alpha, beta);
70+
71+
// Explicit return for extreme values
72+
// The gradients are technically ill-defined, but treated as zero
73+
for (int i = 0; i < size_n; i++) {
74+
if (n_vec.val(i) < 0) {
75+
return 0.0;
76+
}
77+
}
78+
79+
using T_partials_return = partials_return_t<T_n, T_r, T_alpha, T_beta>;
80+
T_partials_return cdf(1.0);
81+
auto ops_partials = make_partials_propagator(r_ref, alpha_ref, beta_ref);
82+
for (size_t i = 0; i < max_size_seq_view; i++) {
83+
// Explicit return for extreme values
84+
// The gradients are technically ill-defined, but treated as zero
85+
if (n_vec.val(i) == std::numeric_limits<int>::max()) {
86+
return 1.0;
87+
}
88+
auto n_dbl = n_vec.val(i);
89+
auto r_dbl = r_vec.val(i);
90+
auto alpha_dbl = alpha_vec.val(i);
91+
auto beta_dbl = beta_vec.val(i);
92+
auto b_plus_n = beta_dbl + n_dbl;
93+
auto r_plus_n = r_dbl + n_dbl;
94+
auto a_plus_r = alpha_dbl + r_dbl;
95+
using a_t = return_type_t<decltype(b_plus_n), decltype(r_plus_n)>;
96+
using b_t = return_type_t<decltype(n_dbl), decltype(a_plus_r),
97+
decltype(b_plus_n)>;
98+
auto F = hypergeometric_3F2(
99+
std::initializer_list<a_t>{1.0, b_plus_n + 1.0, r_plus_n + 1.0},
100+
std::initializer_list<b_t>{n_dbl + 2.0, a_plus_r + b_plus_n + 1.0},
101+
1.0);
102+
auto C = lgamma(r_plus_n + 1.0) + lbeta(a_plus_r, b_plus_n + 1.0)
103+
- lgamma(r_dbl) - lbeta(alpha_dbl, beta_dbl) - lgamma(n_dbl + 2.0);
104+
auto ccdf = stan::math::exp(C + stan::math::log(F));
105+
cdf *= 1.0 - ccdf;
106+
107+
if constexpr (!is_constant_all<T_r, T_alpha, T_beta>::value) {
108+
auto chain_rule_term = -ccdf / (1.0 - ccdf);
109+
auto digamma_n_r_alpha_beta = digamma(a_plus_r + b_plus_n + 1.0);
110+
T_partials_return dF[6];
111+
grad_F32<false, !is_constant<T_beta>::value, !is_constant_all<T_r>::value,
112+
false, true, false>(dF, 1.0, b_plus_n + 1.0, r_plus_n + 1.0,
113+
n_dbl + 2.0, a_plus_r + b_plus_n + 1.0, 1.0,
114+
precision, max_steps);
115+
116+
if constexpr (!is_constant<T_r>::value || !is_constant<T_alpha>::value) {
117+
auto digamma_r_alpha = digamma(a_plus_r);
118+
if constexpr (!is_constant<T_r>::value) {
119+
partials<0>(ops_partials)[i]
120+
+= (digamma(r_plus_n + 1)
121+
+ (digamma_r_alpha - digamma_n_r_alpha_beta)
122+
+ (dF[2] + dF[4]) / F - digamma(r_dbl))
123+
* chain_rule_term;
124+
}
125+
if constexpr (!is_constant<T_alpha>::value) {
126+
partials<1>(ops_partials)[i]
127+
+= (digamma_r_alpha - digamma_n_r_alpha_beta + dF[4] / F
128+
- digamma(alpha_dbl))
129+
* chain_rule_term;
130+
}
131+
}
132+
133+
if constexpr (!is_constant<T_alpha>::value
134+
|| !is_constant<T_beta>::value) {
135+
auto digamma_alpha_beta = digamma(alpha_dbl + beta_dbl);
136+
if constexpr (!is_constant<T_alpha>::value) {
137+
partials<1>(ops_partials)[i] += digamma_alpha_beta * chain_rule_term;
138+
}
139+
if constexpr (!is_constant<T_beta>::value) {
140+
partials<2>(ops_partials)[i]
141+
+= (digamma(b_plus_n + 1) - digamma_n_r_alpha_beta
142+
+ (dF[1] + dF[4]) / F
143+
- (digamma(beta_dbl) - digamma_alpha_beta))
144+
* chain_rule_term;
145+
}
146+
}
147+
}
148+
}
149+
150+
if constexpr (!is_constant<T_r>::value) {
151+
for (size_t i = 0; i < stan::math::size(r); ++i) {
152+
partials<0>(ops_partials)[i] *= cdf;
153+
}
154+
}
155+
if constexpr (!is_constant<T_alpha>::value) {
156+
for (size_t i = 0; i < stan::math::size(alpha); ++i) {
157+
partials<1>(ops_partials)[i] *= cdf;
158+
}
159+
}
160+
if constexpr (!is_constant<T_beta>::value) {
161+
for (size_t i = 0; i < stan::math::size(beta); ++i) {
162+
partials<2>(ops_partials)[i] *= cdf;
163+
}
164+
}
165+
166+
return ops_partials.build(cdf);
167+
}
168+
169+
} // namespace math
170+
} // namespace stan
171+
#endif
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Arguments: Ints, Doubles, Doubles, Doubles
2+
#include <stan/math/prim/prob/beta_neg_binomial_cdf.hpp>
3+
#include <stan/math/prim/fun/lbeta.hpp>
4+
#include <stan/math/prim/fun/lgamma.hpp>
5+
6+
using stan::math::var;
7+
using std::numeric_limits;
8+
using std::vector;
9+
10+
class AgradCdfBetaNegBinomial : public AgradCdfTest {
11+
public:
12+
void valid_values(vector<vector<double>>& parameters, vector<double>& cdf) {
13+
vector<double> param(4);
14+
15+
param[0] = 0; // n
16+
param[1] = 1.0; // r
17+
param[2] = 5.0; // alpha
18+
param[3] = 1.0; // beta
19+
parameters.push_back(param);
20+
cdf.push_back(0.833333333333333); // expected cdf
21+
}
22+
23+
void invalid_values(vector<size_t>& index, vector<double>& value) {
24+
// n
25+
26+
// r
27+
index.push_back(1U);
28+
value.push_back(0.0);
29+
30+
index.push_back(1U);
31+
value.push_back(-1.0);
32+
33+
index.push_back(1U);
34+
value.push_back(std::numeric_limits<double>::infinity());
35+
36+
// alpha
37+
index.push_back(2U);
38+
value.push_back(0.0);
39+
40+
index.push_back(2U);
41+
value.push_back(-1.0);
42+
43+
index.push_back(2U);
44+
value.push_back(std::numeric_limits<double>::infinity());
45+
46+
// beta
47+
index.push_back(3U);
48+
value.push_back(0.0);
49+
50+
index.push_back(3U);
51+
value.push_back(-1.0);
52+
53+
index.push_back(3U);
54+
value.push_back(std::numeric_limits<double>::infinity());
55+
}
56+
57+
// BOUND INCLUDED IN ORDER FOR TEST TO PASS WITH CURRENT FRAMEWORK
58+
bool has_lower_bound() { return false; }
59+
60+
bool has_upper_bound() { return false; }
61+
62+
template <typename T_n, typename T_r, typename T_size1, typename T_size2,
63+
typename T4, typename T5>
64+
stan::return_type_t<T_r, T_size1, T_size2> cdf(const T_n& n, const T_r& r,
65+
const T_size1& alpha,
66+
const T_size2& beta, const T4&,
67+
const T5&) {
68+
return stan::math::beta_neg_binomial_cdf(n, r, alpha, beta);
69+
}
70+
71+
template <typename T_n, typename T_r, typename T_size1, typename T_size2,
72+
typename T4, typename T5>
73+
stan::return_type_t<T_r, T_size1, T_size2> cdf_function(
74+
const T_n& n, const T_r& r, const T_size1& alpha, const T_size2& beta,
75+
const T4&, const T5&) {
76+
using stan::math::lbeta;
77+
using stan::math::lgamma;
78+
using stan::math::log_sum_exp;
79+
using std::vector;
80+
81+
vector<stan::return_type_t<T_r, T_size1, T_size2>> lpmf_values;
82+
83+
for (int i = 0; i <= n; i++) {
84+
auto lpmf = lbeta(i + r, alpha + beta) - lbeta(r, alpha)
85+
+ lgamma(i + beta) - lgamma(i + 1) - lgamma(beta);
86+
lpmf_values.push_back(lpmf);
87+
}
88+
89+
return exp(log_sum_exp(lpmf_values));
90+
}
91+
};

0 commit comments

Comments
 (0)