@@ -63,71 +63,62 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_lpmf(
6363 check_positive_finite (function, " Prior success parameter" , alpha_ref);
6464 check_positive_finite (function, " Prior failure parameter" , beta_ref);
6565
66- if (!include_summand<propto, T_r, T_alpha, T_beta>::value) {
66+ if constexpr (!include_summand<propto, T_r, T_alpha, T_beta>::value) {
6767 return 0.0 ;
6868 }
6969
70- T_partials_return logp (0.0 );
7170 auto ops_partials = make_partials_propagator (r_ref, alpha_ref, beta_ref);
7271
7372 scalar_seq_view<T_n> n_vec (n);
7473 scalar_seq_view<T_r_ref> r_vec (r_ref);
7574 scalar_seq_view<T_alpha_ref> alpha_vec (alpha_ref);
7675 scalar_seq_view<T_beta_ref> beta_vec (beta_ref);
77- size_t size_n = stan::math::size (n);
78- size_t size_r = stan::math::size (r);
79- size_t size_alpha = stan::math::size (alpha);
80- size_t size_beta = stan::math::size (beta);
81- size_t size_n_r = max_size (n, r);
82- size_t size_r_alpha = max_size (r, alpha);
83- size_t size_n_beta = max_size (n, beta);
84- size_t size_alpha_beta = max_size (alpha, beta);
85- size_t max_size_seq_view = max_size (n, r, alpha, beta);
86-
76+ const size_t max_size_seq_view = max_size (n, r, alpha, beta);
77+ T_partials_return logp (0.0 );
8778 for (size_t i = 0 ; i < max_size_seq_view; i++) {
88- const T_partials_return lbeta_denominator
79+ if constexpr (include_summand<propto>::value) {
80+ logp -= lgamma (n_vec[i] + 1 );
81+ }
82+ T_partials_return lbeta_denominator
8983 = lbeta (r_vec.val (i), alpha_vec.val (i));
90- const T_partials_return lgamma_numerator
84+ T_partials_return lgamma_numerator
9185 = lgamma (n_vec[i] + beta_vec.val (i));
92- const T_partials_return lgamma_denominator = lgamma (beta_vec.val (i));
93- const T_partials_return lbeta_numerator
86+ T_partials_return lgamma_denominator = lgamma (beta_vec.val (i));
87+ T_partials_return lbeta_numerator
9488 = lbeta (n_vec[i] + r_vec.val (i), alpha_vec.val (i) + beta_vec.val (i));
95- if (include_summand<propto>::value) {
96- logp -= lgamma (n_vec[i] + 1 );
97- }
9889 logp += lbeta_numerator + lgamma_numerator - lbeta_denominator
9990 - lgamma_denominator;
91+ if (!is_constant_all<T_r, T_alpha, T_beta>::value) {
92+ T_partials_return digamma_n_r_alpha_beta
93+ = digamma (n_vec[i] + r_vec.val (i) + alpha_vec.val (i)
94+ + beta_vec.val (i));
10095
101- T_partials_return digamma_n_r_alpha_beta
102- = is_constant_all<T_r, T_alpha, T_beta>::value
103- ? 0
104- : digamma (n_vec[i] + r_vec.val (i) + alpha_vec.val (i)
105- + beta_vec.val (i));
106-
107- T_partials_return digamma_r_alpha
108- = is_constant_all<T_r, T_alpha>::value
109- ? 0
110- : digamma (r_vec.val (i) + alpha_vec.val (i));
111-
112- T_partials_return digamma_alpha_beta
113- = is_constant_all<T_alpha, T_beta>::value
114- ? 0
115- : digamma (alpha_vec.val (i) + beta_vec.val (i));
96+ if constexpr (!is_constant<T_r>::value || !is_constant<T_alpha>::value) {
97+ T_partials_return digamma_r_alpha = digamma (r_vec.val (i) + alpha_vec.val (i));
98+ if constexpr (!is_constant_all<T_r>::value) {
99+ partials<0 >(ops_partials)[i]
100+ += digamma (n_vec[i] + r_vec.val (i)) - digamma_n_r_alpha_beta
101+ - (digamma (r_vec.val (i)) - digamma_r_alpha);
102+ }
103+ if constexpr (!is_constant_all<T_alpha>::value) {
104+ partials<1 >(ops_partials)[i]
105+ += -digamma_n_r_alpha_beta
106+ - (digamma (alpha_vec.val (i)) - digamma_r_alpha);
107+ }
108+ }
109+ if constexpr (!is_constant<T_beta>::value || !is_constant<T_alpha>::value) {
110+ T_partials_return digamma_alpha_beta
111+ = digamma (alpha_vec.val (i) + beta_vec.val (i));
112+ if constexpr (!is_constant_all<T_beta>::value) {
113+ partials<2 >(ops_partials)[i]
114+ += digamma_alpha_beta - digamma_n_r_alpha_beta
115+ + digamma (n_vec[i] + beta_vec.val (i)) - digamma (beta_vec.val (i));
116+ }
117+ if constexpr (!is_constant_all<T_alpha>::value) {
118+ partials<1 >(ops_partials)[i] += digamma_alpha_beta;
119+ }
120+ }
116121
117- if (!is_constant_all<T_r>::value) {
118- partials<0 >(ops_partials)[i]
119- += digamma (n_vec[i] + r_vec.val (i)) - digamma_n_r_alpha_beta
120- - (digamma (r_vec.val (i)) - digamma_r_alpha);
121- }
122- if (!is_constant_all<T_alpha>::value) {
123- partials<1 >(ops_partials)[i]
124- += digamma_alpha_beta - digamma_n_r_alpha_beta
125- - (digamma (alpha_vec.val (i)) - digamma_r_alpha);
126- }
127- if (!is_constant_all<T_beta>::value) {
128- partials<2 >(ops_partials)[i]
129- += digamma_alpha_beta - digamma_n_r_alpha_beta
130- + digamma (n_vec[i] + beta_vec.val (i)) - digamma (beta_vec.val (i));
131122 }
132123 }
133124 return ops_partials.build (logp);
0 commit comments