|
| 1 | +#ifndef STAN_MATH_PRIM_PROB_BETA_NEG_BINOMIAL_LPMF_HPP |
| 2 | +#define STAN_MATH_PRIM_PROB_BETA_NEG_BINOMIAL_LPMF_HPP |
| 3 | + |
| 4 | +#include <stan/math/prim/meta.hpp> |
| 5 | +#include <stan/math/prim/err.hpp> |
| 6 | +#include <stan/math/prim/fun/constants.hpp> |
| 7 | +#include <stan/math/prim/fun/digamma.hpp> |
| 8 | +#include <stan/math/prim/fun/lbeta.hpp> |
| 9 | +#include <stan/math/prim/fun/lgamma.hpp> |
| 10 | +#include <stan/math/prim/fun/max_size.hpp> |
| 11 | +#include <stan/math/prim/fun/scalar_seq_view.hpp> |
| 12 | +#include <stan/math/prim/fun/size.hpp> |
| 13 | +#include <stan/math/prim/fun/size_zero.hpp> |
| 14 | +#include <stan/math/prim/fun/value_of.hpp> |
| 15 | +#include <stan/math/prim/functor/partials_propagator.hpp> |
| 16 | + |
| 17 | +namespace stan { |
| 18 | +namespace math { |
| 19 | + |
| 20 | +/** \ingroup prob_dists |
| 21 | + * Returns the log PMF of the Beta Negative Binomial distribution with given |
| 22 | + * number of successes, prior success, and prior failure parameters. |
| 23 | + * Given containers of matching sizes, returns the log sum of probabilities. |
| 24 | + * |
| 25 | + * @tparam T_n type of failure parameter |
| 26 | + * @tparam T_r type of number of successes parameter |
| 27 | + * @tparam T_alpha type of prior success parameter |
| 28 | + * @tparam T_beta type of prior failure parameter |
| 29 | + * |
| 30 | + * @param n failure parameter |
| 31 | + * @param r Number of successes parameter |
| 32 | + * @param alpha prior success parameter |
| 33 | + * @param beta prior failure parameter |
| 34 | + * @return log probability or log sum of probabilities |
| 35 | + * @throw std::domain_error if r, alpha, or beta fails to be positive |
| 36 | + * @throw std::invalid_argument if container sizes mismatch |
| 37 | + */ |
| 38 | +template <bool propto, typename T_n, typename T_r, typename T_alpha, |
| 39 | + typename T_beta, |
| 40 | + require_all_not_nonscalar_prim_or_rev_kernel_expression_t< |
| 41 | + T_n, T_r, T_alpha, T_beta>* = nullptr> |
| 42 | +inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_lpmf( |
| 43 | + const T_n& n, const T_r& r, const T_alpha& alpha, const T_beta& beta) { |
| 44 | + using T_partials_return = partials_return_t<T_n, T_r, T_alpha, T_beta>; |
| 45 | + using T_n_ref = ref_type_t<T_n>; |
| 46 | + using T_r_ref = ref_type_t<T_r>; |
| 47 | + using T_alpha_ref = ref_type_t<T_alpha>; |
| 48 | + using T_beta_ref = ref_type_t<T_beta>; |
| 49 | + static constexpr const char* function = "beta_neg_binomial_lpmf"; |
| 50 | + check_consistent_sizes( |
| 51 | + function, "Failures variable", n, "Number of successes parameter", r, |
| 52 | + "Prior success parameter", alpha, "Prior failure parameter", beta); |
| 53 | + if (size_zero(n, r, alpha, beta)) { |
| 54 | + return 0.0; |
| 55 | + } |
| 56 | + |
| 57 | + T_n_ref n_ref = n; |
| 58 | + T_r_ref r_ref = r; |
| 59 | + T_alpha_ref alpha_ref = alpha; |
| 60 | + T_beta_ref beta_ref = beta; |
| 61 | + check_nonnegative(function, "Failures variable", n_ref); |
| 62 | + check_positive_finite(function, "Number of successes parameter", r_ref); |
| 63 | + check_positive_finite(function, "Prior success parameter", alpha_ref); |
| 64 | + check_positive_finite(function, "Prior failure parameter", beta_ref); |
| 65 | + |
| 66 | + if (!include_summand<propto, T_r, T_alpha, T_beta>::value) { |
| 67 | + return 0.0; |
| 68 | + } |
| 69 | + |
| 70 | + T_partials_return logp(0.0); |
| 71 | + auto ops_partials = make_partials_propagator(r_ref, alpha_ref, beta_ref); |
| 72 | + |
| 73 | + scalar_seq_view<T_n> n_vec(n); |
| 74 | + scalar_seq_view<T_r_ref> r_vec(r_ref); |
| 75 | + scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref); |
| 76 | + scalar_seq_view<T_beta_ref> beta_vec(beta_ref); |
| 77 | + size_t size_n = stan::math::size(n); |
| 78 | + size_t size_r = stan::math::size(r); |
| 79 | + size_t size_alpha = stan::math::size(alpha); |
| 80 | + size_t size_beta = stan::math::size(beta); |
| 81 | + size_t size_n_r = max_size(n, r); |
| 82 | + size_t size_r_alpha = max_size(r, alpha); |
| 83 | + size_t size_n_beta = max_size(n, beta); |
| 84 | + size_t size_alpha_beta = max_size(alpha, beta); |
| 85 | + size_t max_size_seq_view = max_size(n, r, alpha, beta); |
| 86 | + |
| 87 | + VectorBuilder<include_summand<propto>::value, T_partials_return, T_n> |
| 88 | + normalizing_constant(size_n); |
| 89 | + for (size_t i = 0; i < size_n; i++) |
| 90 | + if (include_summand<propto>::value) |
| 91 | + normalizing_constant[i] = -lgamma(n_vec[i] + 1); |
| 92 | + |
| 93 | + VectorBuilder<true, T_partials_return, T_r, T_alpha> lbeta_denominator( |
| 94 | + size_r_alpha); |
| 95 | + for (size_t i = 0; i < size_r_alpha; i++) { |
| 96 | + lbeta_denominator[i] = lbeta(r_vec.val(i), alpha_vec.val(i)); |
| 97 | + } |
| 98 | + |
| 99 | + VectorBuilder<true, T_partials_return, T_beta> lgamma_denominator(size_beta); |
| 100 | + for (size_t i = 0; i < size_beta; i++) { |
| 101 | + lgamma_denominator[i] = lgamma(beta_vec.val(i)); |
| 102 | + } |
| 103 | + |
| 104 | + VectorBuilder<true, T_partials_return, T_n, T_beta> lgamma_numerator( |
| 105 | + size_n_beta); |
| 106 | + for (size_t i = 0; i < size_n_beta; i++) { |
| 107 | + lgamma_numerator[i] = lgamma(n_vec[i] + beta_vec.val(i)); |
| 108 | + } |
| 109 | + |
| 110 | + VectorBuilder<true, T_partials_return, T_n, T_r, T_alpha, T_beta> lbeta_diff( |
| 111 | + max_size_seq_view); |
| 112 | + for (size_t i = 0; i < max_size_seq_view; i++) { |
| 113 | + lbeta_diff[i] |
| 114 | + = lbeta(n_vec[i] + r_vec.val(i), alpha_vec.val(i) + beta_vec.val(i)) |
| 115 | + + lgamma_numerator[i] - lbeta_denominator[i] - lgamma_denominator[i]; |
| 116 | + } |
| 117 | + |
| 118 | + // derivative w.r.t. r, alpha and beta |
| 119 | + |
| 120 | + VectorBuilder<!is_constant_all<T_r, T_alpha, T_beta>::value, |
| 121 | + T_partials_return, T_n, T_r, T_alpha, T_beta> |
| 122 | + digamma_n_r_alpha_beta(max_size_seq_view); |
| 123 | + if (!is_constant_all<T_r, T_alpha, T_beta>::value) { |
| 124 | + for (size_t i = 0; i < max_size_seq_view; i++) { |
| 125 | + digamma_n_r_alpha_beta[i] = digamma(n_vec[i] + r_vec.val(i) |
| 126 | + + alpha_vec.val(i) + beta_vec.val(i)); |
| 127 | + } |
| 128 | + } |
| 129 | + |
| 130 | + VectorBuilder<!is_constant_all<T_alpha, T_beta>::value, T_partials_return, |
| 131 | + T_alpha, T_beta> |
| 132 | + digamma_alpha_beta(size_alpha_beta); |
| 133 | + if (!is_constant_all<T_alpha, T_beta>::value) { |
| 134 | + for (size_t i = 0; i < size_alpha_beta; i++) { |
| 135 | + digamma_alpha_beta[i] = digamma(alpha_vec.val(i) + beta_vec.val(i)); |
| 136 | + } |
| 137 | + } |
| 138 | + |
| 139 | + VectorBuilder<!is_constant_all<T_r>::value, T_partials_return, T_n, T_r> |
| 140 | + digamma_n_r(size_n_r); |
| 141 | + if (!is_constant_all<T_r>::value) { |
| 142 | + for (size_t i = 0; i < size_n_r; i++) { |
| 143 | + digamma_n_r[i] = digamma(n_vec[i] + r_vec.val(i)); |
| 144 | + } |
| 145 | + } |
| 146 | + |
| 147 | + VectorBuilder<!is_constant_all<T_r, T_alpha>::value, T_partials_return, T_r, |
| 148 | + T_alpha> |
| 149 | + digamma_r_alpha(size_r_alpha); |
| 150 | + if (!is_constant_all<T_r, T_alpha>::value) { |
| 151 | + for (size_t i = 0; i < size_r_alpha; i++) { |
| 152 | + digamma_r_alpha[i] = digamma(r_vec.val(i) + alpha_vec.val(i)); |
| 153 | + } |
| 154 | + } |
| 155 | + |
| 156 | + VectorBuilder<!is_constant_all<T_beta>::value, T_partials_return, T_n, |
| 157 | + T_beta> |
| 158 | + digamma_n_beta(size_n_beta); |
| 159 | + if (!is_constant_all<T_n, T_beta>::value) { |
| 160 | + for (size_t i = 0; i < size_n_beta; i++) { |
| 161 | + digamma_n_beta[i] = digamma(n_vec[i] + beta_vec.val(i)); |
| 162 | + } |
| 163 | + } |
| 164 | + |
| 165 | + VectorBuilder<!is_constant_all<T_r>::value, T_partials_return, T_r> digamma_r( |
| 166 | + size_r); |
| 167 | + if (!is_constant_all<T_r>::value) { |
| 168 | + for (size_t i = 0; i < size_r; i++) { |
| 169 | + digamma_r[i] = digamma(r_vec.val(i)); |
| 170 | + } |
| 171 | + } |
| 172 | + |
| 173 | + VectorBuilder<!is_constant_all<T_alpha>::value, T_partials_return, T_alpha> |
| 174 | + digamma_alpha(size_alpha); |
| 175 | + if (!is_constant_all<T_alpha>::value) { |
| 176 | + for (size_t i = 0; i < size_alpha; i++) { |
| 177 | + digamma_alpha[i] = digamma(alpha_vec.val(i)); |
| 178 | + } |
| 179 | + } |
| 180 | + |
| 181 | + VectorBuilder<!is_constant_all<T_beta>::value, T_partials_return, T_beta> |
| 182 | + digamma_beta(size_beta); |
| 183 | + if (!is_constant_all<T_beta>::value) { |
| 184 | + for (size_t i = 0; i < size_beta; i++) { |
| 185 | + digamma_beta[i] = digamma(beta_vec.val(i)); |
| 186 | + } |
| 187 | + } |
| 188 | + |
| 189 | + for (size_t i = 0; i < max_size_seq_view; i++) { |
| 190 | + if (include_summand<propto>::value) |
| 191 | + logp += normalizing_constant[i]; |
| 192 | + logp += lbeta_diff[i]; |
| 193 | + |
| 194 | + if (!is_constant_all<T_r>::value) { |
| 195 | + partials<0>(ops_partials)[i] += digamma_n_r[i] - digamma_n_r_alpha_beta[i] |
| 196 | + - (digamma_r[i] - digamma_r_alpha[i]); |
| 197 | + } |
| 198 | + if (!is_constant_all<T_alpha>::value) { |
| 199 | + partials<1>(ops_partials)[i] += digamma_alpha_beta[i] |
| 200 | + - digamma_n_r_alpha_beta[i] |
| 201 | + - (digamma_alpha[i] - digamma_r_alpha[i]); |
| 202 | + } |
| 203 | + if (!is_constant_all<T_beta>::value) { |
| 204 | + partials<2>(ops_partials)[i] += digamma_alpha_beta[i] |
| 205 | + - digamma_n_r_alpha_beta[i] |
| 206 | + + digamma_n_beta[i] - digamma_beta[i]; |
| 207 | + } |
| 208 | + } |
| 209 | + return ops_partials.build(logp); |
| 210 | +} |
| 211 | + |
| 212 | +template <typename T_n, typename T_r, typename T_alpha, typename T_beta> |
| 213 | +inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_lpmf( |
| 214 | + const T_n& n, const T_r& r, const T_alpha& alpha, const T_beta& beta) { |
| 215 | + return beta_neg_binomial_lpmf<false>(n, r, alpha, beta); |
| 216 | +} |
| 217 | + |
| 218 | +} // namespace math |
| 219 | +} // namespace stan |
| 220 | +#endif |
0 commit comments