Skip to content

Commit d2b5f26

Browse files
committed
Feature/issue-3107 add beta negative binomial lpmf with test
1 parent 9b1f08b commit d2b5f26

3 files changed

Lines changed: 319 additions & 0 deletions

File tree

stan/math/prim/prob.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <stan/math/prim/prob/beta_lccdf.hpp>
2626
#include <stan/math/prim/prob/beta_lcdf.hpp>
2727
#include <stan/math/prim/prob/beta_lpdf.hpp>
28+
#include <stan/math/prim/prob/beta_neg_binomial_lpmf.hpp>
2829
#include <stan/math/prim/prob/beta_proportion_ccdf_log.hpp>
2930
#include <stan/math/prim/prob/beta_proportion_cdf_log.hpp>
3031
#include <stan/math/prim/prob/beta_proportion_lccdf.hpp>
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
#ifndef STAN_MATH_PRIM_PROB_BETA_NEG_BINOMIAL_LPMF_HPP
2+
#define STAN_MATH_PRIM_PROB_BETA_NEG_BINOMIAL_LPMF_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/constants.hpp>
7+
#include <stan/math/prim/fun/digamma.hpp>
8+
#include <stan/math/prim/fun/lbeta.hpp>
9+
#include <stan/math/prim/fun/lgamma.hpp>
10+
#include <stan/math/prim/fun/max_size.hpp>
11+
#include <stan/math/prim/fun/scalar_seq_view.hpp>
12+
#include <stan/math/prim/fun/size.hpp>
13+
#include <stan/math/prim/fun/size_zero.hpp>
14+
#include <stan/math/prim/fun/value_of.hpp>
15+
#include <stan/math/prim/functor/partials_propagator.hpp>
16+
17+
namespace stan {
18+
namespace math {
19+
20+
/** \ingroup prob_dists
21+
* Returns the log PMF of the Beta Negative Binomial distribution with given
22+
* number of successes, prior success, and prior failure parameters.
23+
* Given containers of matching sizes, returns the log sum of probabilities.
24+
*
25+
* @tparam T_n type of failure parameter
26+
* @tparam T_r type of number of successes parameter
27+
* @tparam T_alpha type of prior success parameter
28+
* @tparam T_beta type of prior failure parameter
29+
*
30+
* @param n failure parameter
31+
* @param r Number of successes parameter
32+
* @param alpha prior success parameter
33+
* @param beta prior failure parameter
34+
* @return log probability or log sum of probabilities
35+
* @throw std::domain_error if r, alpha, or beta fails to be positive
36+
* @throw std::invalid_argument if container sizes mismatch
37+
*/
38+
template <bool propto, typename T_n, typename T_r, typename T_alpha,
39+
typename T_beta,
40+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
41+
T_n, T_r, T_alpha, T_beta>* = nullptr>
42+
inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_lpmf(
43+
const T_n& n, const T_r& r, const T_alpha& alpha, const T_beta& beta) {
44+
using T_partials_return = partials_return_t<T_n, T_r, T_alpha, T_beta>;
45+
using T_n_ref = ref_type_t<T_n>;
46+
using T_r_ref = ref_type_t<T_r>;
47+
using T_alpha_ref = ref_type_t<T_alpha>;
48+
using T_beta_ref = ref_type_t<T_beta>;
49+
static constexpr const char* function = "beta_neg_binomial_lpmf";
50+
check_consistent_sizes(
51+
function, "Failures variable", n, "Number of successes parameter", r,
52+
"Prior success parameter", alpha, "Prior failure parameter", beta);
53+
if (size_zero(n, r, alpha, beta)) {
54+
return 0.0;
55+
}
56+
57+
T_n_ref n_ref = n;
58+
T_r_ref r_ref = r;
59+
T_alpha_ref alpha_ref = alpha;
60+
T_beta_ref beta_ref = beta;
61+
check_nonnegative(function, "Failures variable", n_ref);
62+
check_positive_finite(function, "Number of successes parameter", r_ref);
63+
check_positive_finite(function, "Prior success parameter", alpha_ref);
64+
check_positive_finite(function, "Prior failure parameter", beta_ref);
65+
66+
if (!include_summand<propto, T_r, T_alpha, T_beta>::value) {
67+
return 0.0;
68+
}
69+
70+
T_partials_return logp(0.0);
71+
auto ops_partials = make_partials_propagator(r_ref, alpha_ref, beta_ref);
72+
73+
scalar_seq_view<T_n> n_vec(n);
74+
scalar_seq_view<T_r_ref> r_vec(r_ref);
75+
scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref);
76+
scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
77+
size_t size_n = stan::math::size(n);
78+
size_t size_r = stan::math::size(r);
79+
size_t size_alpha = stan::math::size(alpha);
80+
size_t size_beta = stan::math::size(beta);
81+
size_t size_n_r = max_size(n, r);
82+
size_t size_r_alpha = max_size(r, alpha);
83+
size_t size_n_beta = max_size(n, beta);
84+
size_t size_alpha_beta = max_size(alpha, beta);
85+
size_t max_size_seq_view = max_size(n, r, alpha, beta);
86+
87+
VectorBuilder<include_summand<propto>::value, T_partials_return, T_n>
88+
normalizing_constant(size_n);
89+
for (size_t i = 0; i < size_n; i++)
90+
if (include_summand<propto>::value)
91+
normalizing_constant[i] = -lgamma(n_vec[i] + 1);
92+
93+
VectorBuilder<true, T_partials_return, T_r, T_alpha> lbeta_denominator(
94+
size_r_alpha);
95+
for (size_t i = 0; i < size_r_alpha; i++) {
96+
lbeta_denominator[i] = lbeta(r_vec.val(i), alpha_vec.val(i));
97+
}
98+
99+
VectorBuilder<true, T_partials_return, T_beta> lgamma_denominator(size_beta);
100+
for (size_t i = 0; i < size_beta; i++) {
101+
lgamma_denominator[i] = lgamma(beta_vec.val(i));
102+
}
103+
104+
VectorBuilder<true, T_partials_return, T_n, T_beta> lgamma_numerator(
105+
size_n_beta);
106+
for (size_t i = 0; i < size_n_beta; i++) {
107+
lgamma_numerator[i] = lgamma(n_vec[i] + beta_vec.val(i));
108+
}
109+
110+
VectorBuilder<true, T_partials_return, T_n, T_r, T_alpha, T_beta> lbeta_diff(
111+
max_size_seq_view);
112+
for (size_t i = 0; i < max_size_seq_view; i++) {
113+
lbeta_diff[i]
114+
= lbeta(n_vec[i] + r_vec.val(i), alpha_vec.val(i) + beta_vec.val(i))
115+
+ lgamma_numerator[i] - lbeta_denominator[i] - lgamma_denominator[i];
116+
}
117+
118+
// derivative w.r.t. r, alpha and beta
119+
120+
VectorBuilder<!is_constant_all<T_r, T_alpha, T_beta>::value,
121+
T_partials_return, T_n, T_r, T_alpha, T_beta>
122+
digamma_n_r_alpha_beta(max_size_seq_view);
123+
if (!is_constant_all<T_r, T_alpha, T_beta>::value) {
124+
for (size_t i = 0; i < max_size_seq_view; i++) {
125+
digamma_n_r_alpha_beta[i] = digamma(n_vec[i] + r_vec.val(i)
126+
+ alpha_vec.val(i) + beta_vec.val(i));
127+
}
128+
}
129+
130+
VectorBuilder<!is_constant_all<T_alpha, T_beta>::value, T_partials_return,
131+
T_alpha, T_beta>
132+
digamma_alpha_beta(size_alpha_beta);
133+
if (!is_constant_all<T_alpha, T_beta>::value) {
134+
for (size_t i = 0; i < size_alpha_beta; i++) {
135+
digamma_alpha_beta[i] = digamma(alpha_vec.val(i) + beta_vec.val(i));
136+
}
137+
}
138+
139+
VectorBuilder<!is_constant_all<T_r>::value, T_partials_return, T_n, T_r>
140+
digamma_n_r(size_n_r);
141+
if (!is_constant_all<T_r>::value) {
142+
for (size_t i = 0; i < size_n_r; i++) {
143+
digamma_n_r[i] = digamma(n_vec[i] + r_vec.val(i));
144+
}
145+
}
146+
147+
VectorBuilder<!is_constant_all<T_r, T_alpha>::value, T_partials_return, T_r,
148+
T_alpha>
149+
digamma_r_alpha(size_r_alpha);
150+
if (!is_constant_all<T_r, T_alpha>::value) {
151+
for (size_t i = 0; i < size_r_alpha; i++) {
152+
digamma_r_alpha[i] = digamma(r_vec.val(i) + alpha_vec.val(i));
153+
}
154+
}
155+
156+
VectorBuilder<!is_constant_all<T_beta>::value, T_partials_return, T_n,
157+
T_beta>
158+
digamma_n_beta(size_n_beta);
159+
if (!is_constant_all<T_n, T_beta>::value) {
160+
for (size_t i = 0; i < size_n_beta; i++) {
161+
digamma_n_beta[i] = digamma(n_vec[i] + beta_vec.val(i));
162+
}
163+
}
164+
165+
VectorBuilder<!is_constant_all<T_r>::value, T_partials_return, T_r> digamma_r(
166+
size_r);
167+
if (!is_constant_all<T_r>::value) {
168+
for (size_t i = 0; i < size_r; i++) {
169+
digamma_r[i] = digamma(r_vec.val(i));
170+
}
171+
}
172+
173+
VectorBuilder<!is_constant_all<T_alpha>::value, T_partials_return, T_alpha>
174+
digamma_alpha(size_alpha);
175+
if (!is_constant_all<T_alpha>::value) {
176+
for (size_t i = 0; i < size_alpha; i++) {
177+
digamma_alpha[i] = digamma(alpha_vec.val(i));
178+
}
179+
}
180+
181+
VectorBuilder<!is_constant_all<T_beta>::value, T_partials_return, T_beta>
182+
digamma_beta(size_beta);
183+
if (!is_constant_all<T_beta>::value) {
184+
for (size_t i = 0; i < size_beta; i++) {
185+
digamma_beta[i] = digamma(beta_vec.val(i));
186+
}
187+
}
188+
189+
for (size_t i = 0; i < max_size_seq_view; i++) {
190+
if (include_summand<propto>::value)
191+
logp += normalizing_constant[i];
192+
logp += lbeta_diff[i];
193+
194+
if (!is_constant_all<T_r>::value) {
195+
partials<0>(ops_partials)[i] += digamma_n_r[i] - digamma_n_r_alpha_beta[i]
196+
- (digamma_r[i] - digamma_r_alpha[i]);
197+
}
198+
if (!is_constant_all<T_alpha>::value) {
199+
partials<1>(ops_partials)[i] += digamma_alpha_beta[i]
200+
- digamma_n_r_alpha_beta[i]
201+
- (digamma_alpha[i] - digamma_r_alpha[i]);
202+
}
203+
if (!is_constant_all<T_beta>::value) {
204+
partials<2>(ops_partials)[i] += digamma_alpha_beta[i]
205+
- digamma_n_r_alpha_beta[i]
206+
+ digamma_n_beta[i] - digamma_beta[i];
207+
}
208+
}
209+
return ops_partials.build(logp);
210+
}
211+
212+
template <typename T_n, typename T_r, typename T_alpha, typename T_beta>
213+
inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_lpmf(
214+
const T_n& n, const T_r& r, const T_alpha& alpha, const T_beta& beta) {
215+
return beta_neg_binomial_lpmf<false>(n, r, alpha, beta);
216+
}
217+
218+
} // namespace math
219+
} // namespace stan
220+
#endif
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// Arguments: Ints, Doubles, Doubles, Doubles
2+
#include <stan/math/prim/prob/beta_neg_binomial_lpmf.hpp>
3+
#include <stan/math/prim/fun/lbeta.hpp>
4+
#include <stan/math/prim/fun/lgamma.hpp>
5+
6+
using stan::math::var;
7+
using std::numeric_limits;
8+
using std::vector;
9+
10+
class AgradDistributionsBetaNegBinomial : public AgradDistributionTest {
11+
public:
12+
void valid_values(vector<vector<double> >& parameters,
13+
vector<double>& log_prob) {
14+
vector<double> param(4);
15+
16+
param[0] = 5; // n
17+
param[1] = 20.0; // r
18+
param[2] = 10.0; // alpha
19+
param[3] = 25.0; // beta
20+
parameters.push_back(param);
21+
log_prob.push_back(-10.3681267949788); // expected log_prob
22+
23+
param[0] = 10; // n
24+
param[1] = 5.5; // r
25+
param[2] = 2.5; // alpha
26+
param[3] = 0.5; // beta
27+
parameters.push_back(param);
28+
log_prob.push_back(-5.166741878823932); // expected log_prob
29+
}
30+
31+
void invalid_values(vector<size_t>& index, vector<double>& value) {
32+
// n
33+
index.push_back(0U);
34+
value.push_back(-1);
35+
36+
// r
37+
index.push_back(1U);
38+
value.push_back(0.0);
39+
40+
index.push_back(1U);
41+
value.push_back(-1.0);
42+
43+
index.push_back(1U);
44+
value.push_back(std::numeric_limits<double>::infinity());
45+
46+
// alpha
47+
index.push_back(2U);
48+
value.push_back(0.0);
49+
50+
index.push_back(2U);
51+
value.push_back(-1.0);
52+
53+
index.push_back(2U);
54+
value.push_back(std::numeric_limits<double>::infinity());
55+
56+
// beta
57+
index.push_back(3U);
58+
value.push_back(0.0);
59+
60+
index.push_back(3U);
61+
value.push_back(-1.0);
62+
63+
index.push_back(3U);
64+
value.push_back(std::numeric_limits<double>::infinity());
65+
}
66+
67+
template <class T_n, class T_r, class T_size1, class T_size2, typename T4,
68+
typename T5>
69+
stan::return_type_t<T_r, T_size1, T_size2> log_prob(const T_n& n,
70+
const T_r& r,
71+
const T_size1& alpha,
72+
const T_size2& beta,
73+
const T4&, const T5&) {
74+
return stan::math::beta_neg_binomial_lpmf(n, r, alpha, beta);
75+
}
76+
77+
template <bool propto, class T_n, class T_r, class T_size1, class T_size2,
78+
typename T4, typename T5>
79+
stan::return_type_t<T_r, T_size1, T_size2> log_prob(const T_n& n,
80+
const T_r& r,
81+
const T_size1& alpha,
82+
const T_size2& beta,
83+
const T4&, const T5&) {
84+
return stan::math::beta_neg_binomial_lpmf<propto>(n, r, alpha, beta);
85+
}
86+
87+
template <class T_n, class T_r, class T_size1, class T_size2, typename T4,
88+
typename T5>
89+
stan::return_type_t<T_r, T_size1, T_size2> log_prob_function(
90+
const T_n& n, const T_r& r, const T_size1& alpha, const T_size2& beta,
91+
const T4&, const T5&) {
92+
using stan::math::lbeta;
93+
using stan::math::lgamma;
94+
95+
return lbeta(n + r, alpha + beta) - lbeta(r, alpha) + lgamma(n + beta)
96+
- lgamma(n + 1) - lgamma(beta);
97+
}
98+
};

0 commit comments

Comments
 (0)