44#include < stan/math/prim/meta.hpp>
55#include < stan/math/prim/err.hpp>
66#include < stan/math/prim/fun/constants.hpp>
7- #include < stan/math/prim/fun/digamma.hpp>
87#include < stan/math/prim/fun/exp.hpp>
9- #include < stan/math/prim/fun/gamma_q.hpp>
10- #include < stan/math/prim/fun/grad_reg_inc_gamma.hpp>
8+ #include < stan/math/prim/fun/fma.hpp>
9+ #include < stan/math/prim/fun/gamma_p.hpp>
10+ #include < stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
1111#include < stan/math/prim/fun/log.hpp>
12+ #include < stan/math/prim/fun/log1m.hpp>
1213#include < stan/math/prim/fun/max_size.hpp>
1314#include < stan/math/prim/fun/scalar_seq_view.hpp>
1415#include < stan/math/prim/fun/size.hpp>
1516#include < stan/math/prim/fun/size_zero.hpp>
16- #include < stan/math/prim/fun/tgamma.hpp>
1717#include < stan/math/prim/fun/value_of.hpp>
1818#include < stan/math/prim/functor/partials_propagator.hpp>
19+ #include < stan/math/fwd/fun/lgamma.hpp>
20+ #include < stan/math/fwd/fun/log.hpp>
21+ #include < stan/math/fwd/fun/value_of.hpp>
1922#include < cmath>
2023
2124namespace stan {
2225namespace math {
2326
27+ namespace internal {
28+
29+ template <typename T_a>
30+ inline auto log_q_gamma_cf (const T_a& a, const double x,
31+ int max_steps = 250 , double precision = 1e-16 ) {
32+ using std::fabs;
33+ using stan::math::lgamma;
34+ using stan::math::log;
35+ using stan::math::value_of;
36+
37+ const auto log_prefactor = a * log (x) - x - lgamma (a);
38+
39+ auto b = x + 1.0 - a;
40+ auto C = (fabs (value_of (b)) >= EPSILON) ? b : T_a (EPSILON);
41+ auto D = T_a (0.0 );
42+ auto f = C;
43+
44+ for (int i = 1 ; i <= max_steps; ++i) {
45+ auto an = -i * (i - a);
46+ b += 2.0 ;
47+
48+ D = b + an * D;
49+ if (fabs (value_of (D)) < EPSILON) {
50+ D = T_a (EPSILON);
51+ }
52+ C = b + an / C;
53+ if (fabs (value_of (C)) < EPSILON) {
54+ C = T_a (EPSILON);
55+ }
56+
57+ D = 1.0 / D;
58+ auto delta = C * D;
59+ f *= delta;
60+
61+ const double delta_m1 = fabs (value_of (delta) - 1.0 );
62+ if (delta_m1 < precision) {
63+ if constexpr (stan::is_fvar<std::decay_t <T_a>>::value) {
64+ if (fabs (value_of (delta.d_ )) < precision) {
65+ break ;
66+ }
67+ } else {
68+ break ;
69+ }
70+ }
71+ }
72+
73+ return log_prefactor - log (f);
74+ }
75+
76+ } // namespace internal
77+
2478template <typename T_y, typename T_shape, typename T_inv_scale>
25- inline return_type_t <T_y, T_shape, T_inv_scale> gamma_lccdf (
26- const T_y& y, const T_shape& alpha, const T_inv_scale& beta) {
79+ return_type_t <T_y, T_shape, T_inv_scale> gamma_lccdf (const T_y& y,
80+ const T_shape& alpha,
81+ const T_inv_scale& beta) {
2782 using T_partials_return = partials_return_t <T_y, T_shape, T_inv_scale>;
2883 using std::exp;
2984 using std::log;
30- using std::pow;
3185 using T_y_ref = ref_type_t <T_y>;
3286 using T_alpha_ref = ref_type_t <T_shape>;
3387 using T_beta_ref = ref_type_t <T_inv_scale>;
@@ -51,61 +105,106 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
51105 scalar_seq_view<T_y_ref> y_vec (y_ref);
52106 scalar_seq_view<T_alpha_ref> alpha_vec (alpha_ref);
53107 scalar_seq_view<T_beta_ref> beta_vec (beta_ref);
54- size_t N = max_size (y, alpha, beta);
55-
56- // Explicit return for extreme values
57- // The gradients are technically ill-defined, but treated as zero
58- for (size_t i = 0 ; i < stan::math::size (y); i++) {
59- if (y_vec.val (i) == 0 ) {
60- // LCCDF(0) = log(P(Y > 0)) = log(1) = 0
61- return ops_partials.build (0.0 );
62- }
63- }
108+ const size_t N = max_size (y, alpha, beta);
109+
110+ constexpr bool need_y_beta_deriv = !is_constant_all<T_y, T_inv_scale>::value;
64111
65112 for (size_t n = 0 ; n < N; n++) {
66113 // Explicit results for extreme values
67114 // The gradients are technically ill-defined, but treated as zero
68- if (y_vec.val (n) == INFTY) {
69- // LCCDF(∞) = log(P(Y > ∞)) = log(0) = -∞
115+ const T_partials_return y_dbl = value_of (y_vec.val (n));
116+ if (y_dbl == 0.0 ) {
117+ continue ;
118+ }
119+ if (y_dbl == INFTY) {
70120 return ops_partials.build (negative_infinity ());
71121 }
72122
73- const T_partials_return y_dbl = y_vec.val (n);
74- const T_partials_return alpha_dbl = alpha_vec.val (n);
75- const T_partials_return beta_dbl = beta_vec.val (n);
76- const T_partials_return beta_y_dbl = beta_dbl * y_dbl;
123+ const T_partials_return alpha_dbl = value_of (alpha_vec.val (n));
124+ const T_partials_return beta_dbl = value_of (beta_vec.val (n));
125+
126+ const T_partials_return beta_y = beta_dbl * y_dbl;
127+ if (beta_y == INFTY) {
128+ return ops_partials.build (negative_infinity ());
129+ }
77130
78- // Qn = 1 - Pn
79- const T_partials_return Qn = gamma_q (alpha_dbl, beta_y_dbl);
80- const T_partials_return log_Qn = log (Qn);
131+ bool use_cf = beta_y > alpha_dbl + 1.0 ;
132+ T_partials_return log_Qn;
133+ [[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0 ;
134+ // Extract double values for continued fraction - we handle y/beta derivatives via hazard
135+ const double beta_y_dbl = value_of (value_of (beta_y));
136+ const double alpha_dbl_val = value_of (value_of (alpha_dbl));
137+
138+ if (use_cf) {
139+ if constexpr (is_autodiff_v<T_shape>) {
140+ stan::math::fvar<double > a_f (alpha_dbl_val, 1.0 );
141+ const stan::math::fvar<double > logq_f
142+ = internal::log_q_gamma_cf (a_f, beta_y_dbl);
143+ log_Qn = logq_f.val_ ;
144+ dlogQ_dalpha = logq_f.d_ ;
145+ } else {
146+ log_Qn = internal::log_q_gamma_cf (alpha_dbl_val, beta_y_dbl);
147+ }
148+ } else {
149+ const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
150+ log_Qn = log1m (Pn);
151+
152+ if (!std::isfinite (value_of (value_of (log_Qn)))) {
153+ use_cf = beta_y > 0.0 ;
154+ if (use_cf) {
155+ if constexpr (is_autodiff_v<T_shape>) {
156+ stan::math::fvar<double > a_f (alpha_dbl_val, 1.0 );
157+ const stan::math::fvar<double > logq_f
158+ = internal::log_q_gamma_cf (a_f, beta_y_dbl);
159+ log_Qn = logq_f.val_ ;
160+ dlogQ_dalpha = logq_f.d_ ;
161+ } else {
162+ log_Qn = internal::log_q_gamma_cf (alpha_dbl_val, beta_y_dbl);
163+ }
164+ }
165+ }
81166
167+ if constexpr (is_autodiff_v<T_shape>) {
168+ if (!use_cf) {
169+ const T_partials_return Qn = exp (log_Qn);
170+ if (Qn > 0.0 ) {
171+ dlogQ_dalpha = -grad_reg_lower_inc_gamma (alpha_dbl, beta_y) / Qn;
172+ } else {
173+ stan::math::fvar<double > a_f (alpha_dbl_val, 1.0 );
174+ const stan::math::fvar<double > logq_f
175+ = internal::log_q_gamma_cf (a_f, beta_y_dbl);
176+ log_Qn = logq_f.val_ ;
177+ dlogQ_dalpha = logq_f.d_ ;
178+ use_cf = true ;
179+ }
180+ }
181+ }
182+ }
183+ if (!std::isfinite (value_of (value_of (log_Qn)))) {
184+ return ops_partials.build (negative_infinity ());
185+ }
82186 P += log_Qn;
83187
84- if constexpr (is_any_autodiff_v<T_y, T_inv_scale>) {
85- const T_partials_return log_y_dbl = log (y_dbl);
86- const T_partials_return log_beta_dbl = log (beta_dbl);
87- const T_partials_return log_pdf
88- = alpha_dbl * log_beta_dbl - lgamma (alpha_dbl)
89- + (alpha_dbl - 1.0 ) * log_y_dbl - beta_y_dbl;
90- const T_partials_return common_term = exp (log_pdf - log_Qn);
188+ if constexpr (need_y_beta_deriv) {
189+ const T_partials_return log_y = log (y_dbl);
190+ const T_partials_return log_beta = log (beta_dbl);
191+ const T_partials_return lgamma_alpha = lgamma (alpha_dbl);
192+ const T_partials_return alpha_minus_one = fma (alpha_dbl, log_y, -log_y);
193+
194+ const T_partials_return log_pdf = alpha_dbl * log_beta - lgamma_alpha
195+ + alpha_minus_one - beta_y;
196+
197+ const T_partials_return hazard = exp (log_pdf - log_Qn); // f/Q
91198
92199 if constexpr (is_autodiff_v<T_y>) {
93- // d/dy log(1-F(y)) = -f(y)/(1-F(y))
94- partials<0 >(ops_partials)[n] -= common_term;
200+ partials<0 >(ops_partials)[n] -= hazard;
95201 }
96202 if constexpr (is_autodiff_v<T_inv_scale>) {
97- // d/dbeta log(1-F(y)) = -y*f(y)/(beta*(1-F(y)))
98- partials<2 >(ops_partials)[n] -= y_dbl / beta_dbl * common_term;
203+ partials<2 >(ops_partials)[n] -= (y_dbl / beta_dbl) * hazard;
99204 }
100205 }
101-
102206 if constexpr (is_autodiff_v<T_shape>) {
103- const T_partials_return digamma_val = digamma (alpha_dbl);
104- const T_partials_return gamma_val = tgamma (alpha_dbl);
105- // d/dalpha log(1-F(y)) = grad_upper_inc_gamma / (1-F(y))
106- partials<1 >(ops_partials)[n]
107- += grad_reg_inc_gamma (alpha_dbl, beta_y_dbl, gamma_val, digamma_val)
108- / Qn;
207+ partials<1 >(ops_partials)[n] += dlogQ_dalpha;
109208 }
110209 }
111210 return ops_partials.build (P);
0 commit comments