Skip to content

Commit b93e44d

Browse files
authored
Merge pull request #3220 from lingium/feature/issue-3219-yule-simon-lpmf
add yule_simon_lpmf
2 parents 7de3318 + 533bcc2 commit b93e44d

3 files changed

Lines changed: 167 additions & 1 deletion

File tree

stan/math/prim/prob.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,5 +316,5 @@
316316
#include <stan/math/prim/prob/wishart_cholesky_rng.hpp>
317317
#include <stan/math/prim/prob/wishart_lpdf.hpp>
318318
#include <stan/math/prim/prob/wishart_rng.hpp>
319-
319+
#include <stan/math/prim/prob/yule_simon_lpmf.hpp>
320320
#endif
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#ifndef STAN_MATH_PRIM_PROB_YULE_SIMON_LPMF_HPP
2+
#define STAN_MATH_PRIM_PROB_YULE_SIMON_LPMF_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/constants.hpp>
7+
#include <stan/math/prim/fun/digamma.hpp>
8+
#include <stan/math/prim/fun/lbeta.hpp>
9+
#include <stan/math/prim/fun/lgamma.hpp>
10+
#include <stan/math/prim/fun/max_size.hpp>
11+
#include <stan/math/prim/fun/scalar_seq_view.hpp>
12+
#include <stan/math/prim/fun/size.hpp>
13+
#include <stan/math/prim/fun/size_zero.hpp>
14+
#include <stan/math/prim/fun/value_of.hpp>
15+
#include <stan/math/prim/functor/partials_propagator.hpp>
16+
17+
namespace stan {
18+
namespace math {
19+
20+
/** \ingroup prob_dists
21+
* Returns the log PMF of the Yule-Simon distribution with shape parameter.
22+
* Given containers of matching sizes, returns the log sum of probabilities.
23+
*
24+
* @tparam T_n type of outcome variable
25+
* @tparam T_alpha type of shape parameter
26+
*
27+
* @param n outcome variable
28+
* @param alpha shape parameter
29+
* @return log probability or log sum of probabilities
30+
* @throw std::domain_error if alpha fails to be positive
31+
* @throw std::invalid_argument if container sizes mismatch
32+
*/
33+
template <bool propto, typename T_n, typename T_alpha,
34+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
35+
T_n, T_alpha> * = nullptr>
36+
inline return_type_t<T_alpha> yule_simon_lpmf(const T_n &n,
37+
const T_alpha &alpha) {
38+
using std::log;
39+
using T_partials_return = partials_return_t<T_n, T_alpha>;
40+
using T_n_ref = ref_type_t<T_n>;
41+
using T_alpha_ref = ref_type_t<T_alpha>;
42+
static constexpr const char *function = "yule_simon_lpmf";
43+
check_consistent_sizes(function, "Failures variable", n, "Shape parameter",
44+
alpha);
45+
if (size_zero(n, alpha)) {
46+
return 0.0;
47+
}
48+
49+
T_n_ref n_ref = n;
50+
T_alpha_ref alpha_ref = alpha;
51+
check_greater_or_equal(function, "Outcome variable", n_ref, 1);
52+
check_positive_finite(function, "Shape parameter", alpha_ref);
53+
54+
if constexpr (!include_summand<propto, T_alpha>::value) {
55+
return 0.0;
56+
}
57+
58+
auto ops_partials = make_partials_propagator(alpha_ref);
59+
60+
scalar_seq_view<T_n_ref> n_vec(n_ref);
61+
scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref);
62+
const size_t max_size_seq_view = max_size(n_ref, alpha_ref);
63+
T_partials_return logp(0.0);
64+
if constexpr (include_summand<propto>::value) {
65+
if constexpr (is_stan_scalar_v<T_n>) {
66+
logp += lgamma(n_ref) * max_size_seq_view;
67+
}
68+
}
69+
for (size_t i = 0; i < max_size_seq_view; i++) {
70+
if constexpr (include_summand<propto>::value) {
71+
if constexpr (!is_stan_scalar_v<T_n>) {
72+
logp += lgamma(n_vec.val(i));
73+
}
74+
}
75+
T_partials_return alpha_plus_one = alpha_vec.val(i) + 1.0;
76+
logp += log(alpha_vec.val(i)) + lgamma(alpha_plus_one)
77+
- lgamma(n_vec.val(i) + alpha_plus_one);
78+
if constexpr (!is_constant<T_alpha>::value) {
79+
partials<0>(ops_partials)[i] += 1.0 / alpha_vec.val(i)
80+
+ digamma(alpha_plus_one)
81+
- digamma(n_vec.val(i) + alpha_plus_one);
82+
}
83+
}
84+
return ops_partials.build(logp);
85+
}
86+
87+
template <typename T_n, typename T_alpha>
88+
inline return_type_t<T_alpha> yule_simon_lpmf(const T_n &n,
89+
const T_alpha &alpha) {
90+
return yule_simon_lpmf<false>(n, alpha);
91+
}
92+
93+
} // namespace math
94+
} // namespace stan
95+
#endif
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Arguments: Ints, Doubles
2+
#include <stan/math/prim/prob/yule_simon_lpmf.hpp>
3+
#include <stan/math/prim/fun/lbeta.hpp>
4+
#include <stan/math/prim/fun/lgamma.hpp>
5+
6+
using std::numeric_limits;
7+
using std::vector;
8+
9+
class AgradDistributionsYuleSimon : public AgradDistributionTest {
10+
public:
11+
void valid_values(vector<vector<double> >& parameters,
12+
vector<double>& log_prob) {
13+
vector<double> param(2);
14+
15+
param[0] = 5; // n
16+
param[1] = 20.0; // alpha
17+
parameters.push_back(param);
18+
log_prob.push_back(-9.494202658325099); // expected log_prob
19+
20+
param[0] = 10; // n
21+
param[1] = 5.5; // alpha
22+
parameters.push_back(param);
23+
log_prob.push_back(-9.108616882863778); // expected log_prob
24+
}
25+
26+
void invalid_values(vector<size_t>& index, vector<double>& value) {
27+
// n
28+
index.push_back(0U);
29+
value.push_back(-1);
30+
31+
index.push_back(0U);
32+
value.push_back(0);
33+
34+
// alpha
35+
index.push_back(1U);
36+
value.push_back(0.0);
37+
38+
index.push_back(1U);
39+
value.push_back(-1.0);
40+
41+
index.push_back(1U);
42+
value.push_back(std::numeric_limits<double>::infinity());
43+
}
44+
45+
template <class T_n, class T_alpha, typename T2, typename T3, typename T4,
46+
typename T5>
47+
stan::return_type_t<T_alpha> log_prob(const T_n& n, const T_alpha& alpha,
48+
const T2&, const T3&, const T4&,
49+
const T5&) {
50+
return stan::math::yule_simon_lpmf(n, alpha);
51+
}
52+
53+
template <bool propto, class T_n, class T_alpha, typename T2, typename T3,
54+
typename T4, typename T5>
55+
stan::return_type_t<T_alpha> log_prob(const T_n& n, const T_alpha& alpha,
56+
const T2&, const T3&, const T4&,
57+
const T5&) {
58+
return stan::math::yule_simon_lpmf<propto>(n, alpha);
59+
}
60+
61+
template <class T_n, class T_alpha, typename T2, typename T3, typename T4,
62+
typename T5>
63+
stan::return_type_t<T_alpha> log_prob_function(const T_n& n,
64+
const T_alpha& alpha,
65+
const T2&, const T3&,
66+
const T4&, const T5&) {
67+
using stan::math::lbeta;
68+
using std::log;
69+
return log(alpha) + lbeta(n, alpha + 1.0);
70+
}
71+
};

0 commit comments

Comments
 (0)