@@ -84,126 +84,50 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_lpmf(
8484 size_t size_alpha_beta = max_size (alpha, beta);
8585 size_t max_size_seq_view = max_size (n, r, alpha, beta);
8686
87- VectorBuilder<include_summand<propto>::value, T_partials_return, T_n>
88- normalizing_constant (size_n);
89- for (size_t i = 0 ; i < size_n; i++)
90- if (include_summand<propto>::value)
91- normalizing_constant[i] = -lgamma (n_vec[i] + 1 );
92-
93- VectorBuilder<true , T_partials_return, T_r, T_alpha> lbeta_denominator (
94- size_r_alpha);
95- for (size_t i = 0 ; i < size_r_alpha; i++) {
96- lbeta_denominator[i] = lbeta (r_vec.val (i), alpha_vec.val (i));
97- }
98-
99- VectorBuilder<true , T_partials_return, T_beta> lgamma_denominator (size_beta);
100- for (size_t i = 0 ; i < size_beta; i++) {
101- lgamma_denominator[i] = lgamma (beta_vec.val (i));
102- }
103-
104- VectorBuilder<true , T_partials_return, T_n, T_beta> lgamma_numerator (
105- size_n_beta);
106- for (size_t i = 0 ; i < size_n_beta; i++) {
107- lgamma_numerator[i] = lgamma (n_vec[i] + beta_vec.val (i));
108- }
109-
110- VectorBuilder<true , T_partials_return, T_n, T_r, T_alpha, T_beta> lbeta_diff (
111- max_size_seq_view);
11287 for (size_t i = 0 ; i < max_size_seq_view; i++) {
113- lbeta_diff[i]
114- = lbeta (n_vec[i] + r_vec.val (i), alpha_vec.val (i) + beta_vec.val (i))
115- + lgamma_numerator[i] - lbeta_denominator[i] - lgamma_denominator[i];
116- }
117-
118- // derivative w.r.t. r, alpha and beta
119-
120- VectorBuilder<!is_constant_all<T_r, T_alpha, T_beta>::value,
121- T_partials_return, T_n, T_r, T_alpha, T_beta>
122- digamma_n_r_alpha_beta (max_size_seq_view);
123- if (!is_constant_all<T_r, T_alpha, T_beta>::value) {
124- for (size_t i = 0 ; i < max_size_seq_view; i++) {
125- digamma_n_r_alpha_beta[i] = digamma (n_vec[i] + r_vec.val (i)
126- + alpha_vec.val (i) + beta_vec.val (i));
88+ const T_partials_return lbeta_denominator
89+ = lbeta (r_vec.val (i), alpha_vec.val (i));
90+ const T_partials_return lgamma_numerator
91+ = lgamma (n_vec[i] + beta_vec.val (i));
92+ const T_partials_return lgamma_denominator = lgamma (beta_vec.val (i));
93+ const T_partials_return lbeta_numerator
94+ = lbeta (n_vec[i] + r_vec.val (i), alpha_vec.val (i) + beta_vec.val (i));
95+ if (include_summand<propto>::value) {
96+ logp -= lgamma (n_vec[i] + 1 );
12797 }
128- }
98+ logp += lbeta_numerator + lgamma_numerator - lbeta_denominator
99+ - lgamma_denominator;
129100
130- VectorBuilder<!is_constant_all<T_alpha, T_beta>::value, T_partials_return,
131- T_alpha, T_beta>
132- digamma_alpha_beta (size_alpha_beta);
133- if (!is_constant_all<T_alpha, T_beta>::value) {
134- for (size_t i = 0 ; i < size_alpha_beta; i++) {
135- digamma_alpha_beta[i] = digamma (alpha_vec.val (i) + beta_vec.val (i));
136- }
137- }
101+ T_partials_return digamma_n_r_alpha_beta
102+ = is_constant_all<T_r, T_alpha, T_beta>::value
103+ ? 0
104+ : digamma (n_vec[i] + r_vec.val (i) + alpha_vec.val (i)
105+ + beta_vec.val (i));
138106
139- VectorBuilder<!is_constant_all<T_r>::value, T_partials_return, T_n, T_r>
140- digamma_n_r (size_n_r);
141- if (!is_constant_all<T_r>::value) {
142- for (size_t i = 0 ; i < size_n_r; i++) {
143- digamma_n_r[i] = digamma (n_vec[i] + r_vec.val (i));
144- }
145- }
107+ T_partials_return digamma_r_alpha
108+ = is_constant_all<T_r, T_alpha>::value
109+ ? 0
110+ : digamma (r_vec.val (i) + alpha_vec.val (i));
146111
147- VectorBuilder<!is_constant_all<T_r, T_alpha>::value, T_partials_return, T_r,
148- T_alpha>
149- digamma_r_alpha (size_r_alpha);
150- if (!is_constant_all<T_r, T_alpha>::value) {
151- for (size_t i = 0 ; i < size_r_alpha; i++) {
152- digamma_r_alpha[i] = digamma (r_vec.val (i) + alpha_vec.val (i));
153- }
154- }
155-
156- VectorBuilder<!is_constant_all<T_beta>::value, T_partials_return, T_n,
157- T_beta>
158- digamma_n_beta (size_n_beta);
159- if (!is_constant_all<T_n, T_beta>::value) {
160- for (size_t i = 0 ; i < size_n_beta; i++) {
161- digamma_n_beta[i] = digamma (n_vec[i] + beta_vec.val (i));
162- }
163- }
164-
165- VectorBuilder<!is_constant_all<T_r>::value, T_partials_return, T_r> digamma_r (
166- size_r);
167- if (!is_constant_all<T_r>::value) {
168- for (size_t i = 0 ; i < size_r; i++) {
169- digamma_r[i] = digamma (r_vec.val (i));
170- }
171- }
172-
173- VectorBuilder<!is_constant_all<T_alpha>::value, T_partials_return, T_alpha>
174- digamma_alpha (size_alpha);
175- if (!is_constant_all<T_alpha>::value) {
176- for (size_t i = 0 ; i < size_alpha; i++) {
177- digamma_alpha[i] = digamma (alpha_vec.val (i));
178- }
179- }
180-
181- VectorBuilder<!is_constant_all<T_beta>::value, T_partials_return, T_beta>
182- digamma_beta (size_beta);
183- if (!is_constant_all<T_beta>::value) {
184- for (size_t i = 0 ; i < size_beta; i++) {
185- digamma_beta[i] = digamma (beta_vec.val (i));
186- }
187- }
188-
189- for (size_t i = 0 ; i < max_size_seq_view; i++) {
190- if (include_summand<propto>::value)
191- logp += normalizing_constant[i];
192- logp += lbeta_diff[i];
112+ T_partials_return digamma_alpha_beta
113+ = is_constant_all<T_alpha, T_beta>::value
114+ ? 0
115+ : digamma (alpha_vec.val (i) + beta_vec.val (i));
193116
194117 if (!is_constant_all<T_r>::value) {
195- partials<0 >(ops_partials)[i] += digamma_n_r[i] - digamma_n_r_alpha_beta[i]
196- - (digamma_r[i] - digamma_r_alpha[i]);
118+ partials<0 >(ops_partials)[i]
119+ += digamma (n_vec[i] + r_vec.val (i)) - digamma_n_r_alpha_beta
120+ - (digamma (r_vec.val (i)) - digamma_r_alpha);
197121 }
198122 if (!is_constant_all<T_alpha>::value) {
199- partials<1 >(ops_partials)[i] += digamma_alpha_beta[i]
200- - digamma_n_r_alpha_beta[i]
201- - (digamma_alpha[i] - digamma_r_alpha[i] );
123+ partials<1 >(ops_partials)[i]
124+ += digamma_alpha_beta - digamma_n_r_alpha_beta
125+ - (digamma (alpha_vec. val (i)) - digamma_r_alpha);
202126 }
203127 if (!is_constant_all<T_beta>::value) {
204- partials<2 >(ops_partials)[i] += digamma_alpha_beta[i]
205- - digamma_n_r_alpha_beta[i]
206- + digamma_n_beta [i] - digamma_beta[i] ;
128+ partials<2 >(ops_partials)[i]
129+ += digamma_alpha_beta - digamma_n_r_alpha_beta
130+ + digamma (n_vec [i] + beta_vec. val (i)) - digamma (beta_vec. val (i)) ;
207131 }
208132 }
209133 return ops_partials.build (logp);
0 commit comments