Skip to content

Commit 0731eb3

Browse files
committed
uses constexpr for whether to do extra calculations
1 parent 11ba289 commit 0731eb3

1 file changed

Lines changed: 39 additions & 48 deletions

File tree

stan/math/prim/prob/beta_neg_binomial_lpmf.hpp

Lines changed: 39 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -63,71 +63,62 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_lpmf(
6363
check_positive_finite(function, "Prior success parameter", alpha_ref);
6464
check_positive_finite(function, "Prior failure parameter", beta_ref);
6565

66-
if (!include_summand<propto, T_r, T_alpha, T_beta>::value) {
66+
if constexpr (!include_summand<propto, T_r, T_alpha, T_beta>::value) {
6767
return 0.0;
6868
}
6969

70-
T_partials_return logp(0.0);
7170
auto ops_partials = make_partials_propagator(r_ref, alpha_ref, beta_ref);
7271

7372
scalar_seq_view<T_n> n_vec(n);
7473
scalar_seq_view<T_r_ref> r_vec(r_ref);
7574
scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref);
7675
scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
77-
size_t size_n = stan::math::size(n);
78-
size_t size_r = stan::math::size(r);
79-
size_t size_alpha = stan::math::size(alpha);
80-
size_t size_beta = stan::math::size(beta);
81-
size_t size_n_r = max_size(n, r);
82-
size_t size_r_alpha = max_size(r, alpha);
83-
size_t size_n_beta = max_size(n, beta);
84-
size_t size_alpha_beta = max_size(alpha, beta);
85-
size_t max_size_seq_view = max_size(n, r, alpha, beta);
86-
76+
const size_t max_size_seq_view = max_size(n, r, alpha, beta);
77+
T_partials_return logp(0.0);
8778
for (size_t i = 0; i < max_size_seq_view; i++) {
88-
const T_partials_return lbeta_denominator
79+
if constexpr (include_summand<propto>::value) {
80+
logp -= lgamma(n_vec[i] + 1);
81+
}
82+
T_partials_return lbeta_denominator
8983
= lbeta(r_vec.val(i), alpha_vec.val(i));
90-
const T_partials_return lgamma_numerator
84+
T_partials_return lgamma_numerator
9185
= lgamma(n_vec[i] + beta_vec.val(i));
92-
const T_partials_return lgamma_denominator = lgamma(beta_vec.val(i));
93-
const T_partials_return lbeta_numerator
86+
T_partials_return lgamma_denominator = lgamma(beta_vec.val(i));
87+
T_partials_return lbeta_numerator
9488
= lbeta(n_vec[i] + r_vec.val(i), alpha_vec.val(i) + beta_vec.val(i));
95-
if (include_summand<propto>::value) {
96-
logp -= lgamma(n_vec[i] + 1);
97-
}
9889
logp += lbeta_numerator + lgamma_numerator - lbeta_denominator
9990
- lgamma_denominator;
91+
if (!is_constant_all<T_r, T_alpha, T_beta>::value) {
92+
T_partials_return digamma_n_r_alpha_beta
93+
= digamma(n_vec[i] + r_vec.val(i) + alpha_vec.val(i)
94+
+ beta_vec.val(i));
10095

101-
T_partials_return digamma_n_r_alpha_beta
102-
= is_constant_all<T_r, T_alpha, T_beta>::value
103-
? 0
104-
: digamma(n_vec[i] + r_vec.val(i) + alpha_vec.val(i)
105-
+ beta_vec.val(i));
106-
107-
T_partials_return digamma_r_alpha
108-
= is_constant_all<T_r, T_alpha>::value
109-
? 0
110-
: digamma(r_vec.val(i) + alpha_vec.val(i));
111-
112-
T_partials_return digamma_alpha_beta
113-
= is_constant_all<T_alpha, T_beta>::value
114-
? 0
115-
: digamma(alpha_vec.val(i) + beta_vec.val(i));
96+
if constexpr (!is_constant<T_r>::value || !is_constant<T_alpha>::value) {
97+
T_partials_return digamma_r_alpha = digamma(r_vec.val(i) + alpha_vec.val(i));
98+
if constexpr (!is_constant_all<T_r>::value) {
99+
partials<0>(ops_partials)[i]
100+
+= digamma(n_vec[i] + r_vec.val(i)) - digamma_n_r_alpha_beta
101+
- (digamma(r_vec.val(i)) - digamma_r_alpha);
102+
}
103+
if constexpr (!is_constant_all<T_alpha>::value) {
104+
partials<1>(ops_partials)[i]
105+
+= -digamma_n_r_alpha_beta
106+
- (digamma(alpha_vec.val(i)) - digamma_r_alpha);
107+
}
108+
}
109+
if constexpr (!is_constant<T_beta>::value || !is_constant<T_alpha>::value) {
110+
T_partials_return digamma_alpha_beta
111+
= digamma(alpha_vec.val(i) + beta_vec.val(i));
112+
if constexpr (!is_constant_all<T_beta>::value) {
113+
partials<2>(ops_partials)[i]
114+
+= digamma_alpha_beta - digamma_n_r_alpha_beta
115+
+ digamma(n_vec[i] + beta_vec.val(i)) - digamma(beta_vec.val(i));
116+
}
117+
if constexpr (!is_constant_all<T_alpha>::value) {
118+
partials<1>(ops_partials)[i] += digamma_alpha_beta;
119+
}
120+
}
116121

117-
if (!is_constant_all<T_r>::value) {
118-
partials<0>(ops_partials)[i]
119-
+= digamma(n_vec[i] + r_vec.val(i)) - digamma_n_r_alpha_beta
120-
- (digamma(r_vec.val(i)) - digamma_r_alpha);
121-
}
122-
if (!is_constant_all<T_alpha>::value) {
123-
partials<1>(ops_partials)[i]
124-
+= digamma_alpha_beta - digamma_n_r_alpha_beta
125-
- (digamma(alpha_vec.val(i)) - digamma_r_alpha);
126-
}
127-
if (!is_constant_all<T_beta>::value) {
128-
partials<2>(ops_partials)[i]
129-
+= digamma_alpha_beta - digamma_n_r_alpha_beta
130-
+ digamma(n_vec[i] + beta_vec.val(i)) - digamma(beta_vec.val(i));
131122
}
132123
}
133124
return ops_partials.build(logp);

0 commit comments

Comments
 (0)