Skip to content

Commit 11ba289

Browse files
committed
move loops into one
1 parent d2b5f26 commit 11ba289

1 file changed

Lines changed: 33 additions & 109 deletions

File tree

stan/math/prim/prob/beta_neg_binomial_lpmf.hpp

Lines changed: 33 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -84,126 +84,50 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_lpmf(
8484
size_t size_alpha_beta = max_size(alpha, beta);
8585
size_t max_size_seq_view = max_size(n, r, alpha, beta);
8686

87-
VectorBuilder<include_summand<propto>::value, T_partials_return, T_n>
88-
normalizing_constant(size_n);
89-
for (size_t i = 0; i < size_n; i++)
90-
if (include_summand<propto>::value)
91-
normalizing_constant[i] = -lgamma(n_vec[i] + 1);
92-
93-
VectorBuilder<true, T_partials_return, T_r, T_alpha> lbeta_denominator(
94-
size_r_alpha);
95-
for (size_t i = 0; i < size_r_alpha; i++) {
96-
lbeta_denominator[i] = lbeta(r_vec.val(i), alpha_vec.val(i));
97-
}
98-
99-
VectorBuilder<true, T_partials_return, T_beta> lgamma_denominator(size_beta);
100-
for (size_t i = 0; i < size_beta; i++) {
101-
lgamma_denominator[i] = lgamma(beta_vec.val(i));
102-
}
103-
104-
VectorBuilder<true, T_partials_return, T_n, T_beta> lgamma_numerator(
105-
size_n_beta);
106-
for (size_t i = 0; i < size_n_beta; i++) {
107-
lgamma_numerator[i] = lgamma(n_vec[i] + beta_vec.val(i));
108-
}
109-
110-
VectorBuilder<true, T_partials_return, T_n, T_r, T_alpha, T_beta> lbeta_diff(
111-
max_size_seq_view);
11287
for (size_t i = 0; i < max_size_seq_view; i++) {
113-
lbeta_diff[i]
114-
= lbeta(n_vec[i] + r_vec.val(i), alpha_vec.val(i) + beta_vec.val(i))
115-
+ lgamma_numerator[i] - lbeta_denominator[i] - lgamma_denominator[i];
116-
}
117-
118-
// derivative w.r.t. r, alpha and beta
119-
120-
VectorBuilder<!is_constant_all<T_r, T_alpha, T_beta>::value,
121-
T_partials_return, T_n, T_r, T_alpha, T_beta>
122-
digamma_n_r_alpha_beta(max_size_seq_view);
123-
if (!is_constant_all<T_r, T_alpha, T_beta>::value) {
124-
for (size_t i = 0; i < max_size_seq_view; i++) {
125-
digamma_n_r_alpha_beta[i] = digamma(n_vec[i] + r_vec.val(i)
126-
+ alpha_vec.val(i) + beta_vec.val(i));
88+
const T_partials_return lbeta_denominator
89+
= lbeta(r_vec.val(i), alpha_vec.val(i));
90+
const T_partials_return lgamma_numerator
91+
= lgamma(n_vec[i] + beta_vec.val(i));
92+
const T_partials_return lgamma_denominator = lgamma(beta_vec.val(i));
93+
const T_partials_return lbeta_numerator
94+
= lbeta(n_vec[i] + r_vec.val(i), alpha_vec.val(i) + beta_vec.val(i));
95+
if (include_summand<propto>::value) {
96+
logp -= lgamma(n_vec[i] + 1);
12797
}
128-
}
98+
logp += lbeta_numerator + lgamma_numerator - lbeta_denominator
99+
- lgamma_denominator;
129100

130-
VectorBuilder<!is_constant_all<T_alpha, T_beta>::value, T_partials_return,
131-
T_alpha, T_beta>
132-
digamma_alpha_beta(size_alpha_beta);
133-
if (!is_constant_all<T_alpha, T_beta>::value) {
134-
for (size_t i = 0; i < size_alpha_beta; i++) {
135-
digamma_alpha_beta[i] = digamma(alpha_vec.val(i) + beta_vec.val(i));
136-
}
137-
}
101+
T_partials_return digamma_n_r_alpha_beta
102+
= is_constant_all<T_r, T_alpha, T_beta>::value
103+
? 0
104+
: digamma(n_vec[i] + r_vec.val(i) + alpha_vec.val(i)
105+
+ beta_vec.val(i));
138106

139-
VectorBuilder<!is_constant_all<T_r>::value, T_partials_return, T_n, T_r>
140-
digamma_n_r(size_n_r);
141-
if (!is_constant_all<T_r>::value) {
142-
for (size_t i = 0; i < size_n_r; i++) {
143-
digamma_n_r[i] = digamma(n_vec[i] + r_vec.val(i));
144-
}
145-
}
107+
T_partials_return digamma_r_alpha
108+
= is_constant_all<T_r, T_alpha>::value
109+
? 0
110+
: digamma(r_vec.val(i) + alpha_vec.val(i));
146111

147-
VectorBuilder<!is_constant_all<T_r, T_alpha>::value, T_partials_return, T_r,
148-
T_alpha>
149-
digamma_r_alpha(size_r_alpha);
150-
if (!is_constant_all<T_r, T_alpha>::value) {
151-
for (size_t i = 0; i < size_r_alpha; i++) {
152-
digamma_r_alpha[i] = digamma(r_vec.val(i) + alpha_vec.val(i));
153-
}
154-
}
155-
156-
VectorBuilder<!is_constant_all<T_beta>::value, T_partials_return, T_n,
157-
T_beta>
158-
digamma_n_beta(size_n_beta);
159-
if (!is_constant_all<T_n, T_beta>::value) {
160-
for (size_t i = 0; i < size_n_beta; i++) {
161-
digamma_n_beta[i] = digamma(n_vec[i] + beta_vec.val(i));
162-
}
163-
}
164-
165-
VectorBuilder<!is_constant_all<T_r>::value, T_partials_return, T_r> digamma_r(
166-
size_r);
167-
if (!is_constant_all<T_r>::value) {
168-
for (size_t i = 0; i < size_r; i++) {
169-
digamma_r[i] = digamma(r_vec.val(i));
170-
}
171-
}
172-
173-
VectorBuilder<!is_constant_all<T_alpha>::value, T_partials_return, T_alpha>
174-
digamma_alpha(size_alpha);
175-
if (!is_constant_all<T_alpha>::value) {
176-
for (size_t i = 0; i < size_alpha; i++) {
177-
digamma_alpha[i] = digamma(alpha_vec.val(i));
178-
}
179-
}
180-
181-
VectorBuilder<!is_constant_all<T_beta>::value, T_partials_return, T_beta>
182-
digamma_beta(size_beta);
183-
if (!is_constant_all<T_beta>::value) {
184-
for (size_t i = 0; i < size_beta; i++) {
185-
digamma_beta[i] = digamma(beta_vec.val(i));
186-
}
187-
}
188-
189-
for (size_t i = 0; i < max_size_seq_view; i++) {
190-
if (include_summand<propto>::value)
191-
logp += normalizing_constant[i];
192-
logp += lbeta_diff[i];
112+
T_partials_return digamma_alpha_beta
113+
= is_constant_all<T_alpha, T_beta>::value
114+
? 0
115+
: digamma(alpha_vec.val(i) + beta_vec.val(i));
193116

194117
if (!is_constant_all<T_r>::value) {
195-
partials<0>(ops_partials)[i] += digamma_n_r[i] - digamma_n_r_alpha_beta[i]
196-
- (digamma_r[i] - digamma_r_alpha[i]);
118+
partials<0>(ops_partials)[i]
119+
+= digamma(n_vec[i] + r_vec.val(i)) - digamma_n_r_alpha_beta
120+
- (digamma(r_vec.val(i)) - digamma_r_alpha);
197121
}
198122
if (!is_constant_all<T_alpha>::value) {
199-
partials<1>(ops_partials)[i] += digamma_alpha_beta[i]
200-
- digamma_n_r_alpha_beta[i]
201-
- (digamma_alpha[i] - digamma_r_alpha[i]);
123+
partials<1>(ops_partials)[i]
124+
+= digamma_alpha_beta - digamma_n_r_alpha_beta
125+
- (digamma(alpha_vec.val(i)) - digamma_r_alpha);
202126
}
203127
if (!is_constant_all<T_beta>::value) {
204-
partials<2>(ops_partials)[i] += digamma_alpha_beta[i]
205-
- digamma_n_r_alpha_beta[i]
206-
+ digamma_n_beta[i] - digamma_beta[i];
128+
partials<2>(ops_partials)[i]
129+
+= digamma_alpha_beta - digamma_n_r_alpha_beta
130+
+ digamma(n_vec[i] + beta_vec.val(i)) - digamma(beta_vec.val(i));
207131
}
208132
}
209133
return ops_partials.build(logp);

0 commit comments

Comments
 (0)