44#include < stan/math/prim/meta.hpp>
55#include < stan/math/prim/err.hpp>
66#include < stan/math/prim/fun/constants.hpp>
7+ #include < stan/math/prim/fun/digamma.hpp>
78#include < stan/math/prim/fun/exp.hpp>
89#include < stan/math/prim/fun/fma.hpp>
910#include < stan/math/prim/fun/gamma_p.hpp>
11+ #include < stan/math/prim/fun/grad_reg_inc_gamma.hpp>
1012#include < stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
11- // #include <stan/math/prim/fun/lgamma.hpp>
12- // #include <stan/math/prim/fun/log.hpp>
13+ #include < stan/math/prim/fun/lgamma.hpp>
14+ #include < stan/math/prim/fun/log.hpp>
1315#include < stan/math/prim/fun/log1m.hpp>
1416#include < stan/math/prim/fun/max_size.hpp>
1517#include < stan/math/prim/fun/scalar_seq_view.hpp>
1618#include < stan/math/prim/fun/size.hpp>
1719#include < stan/math/prim/fun/size_zero.hpp>
20+ #include < stan/math/prim/fun/tgamma.hpp>
1821#include < stan/math/prim/fun/value_of.hpp>
22+ #include < stan/math/prim/fun/log_gamma_q_dgamma.hpp>
1923#include < stan/math/prim/functor/partials_propagator.hpp>
20-
21- #include < stan/math/fwd/fun/lgamma.hpp>
22- #include < stan/math/fwd/fun/log.hpp>
23- #include < stan/math/fwd/fun/value_of.hpp>
2424#include < cmath>
2525
2626namespace stan {
@@ -30,34 +30,35 @@ namespace internal {
3030
3131/* *
3232 * Compute log(Q(a,x)) using continued fraction expansion for upper incomplete
33- * gamma function, where Q(a,x) = Gamma(a,x) / Gamma(a) is the regularized
34- * upper incomplete gamma function.
33+ * gamma function. When used with fvar types, automatically computes derivatives.
3534 *
36- * @tparam T_a Type of shape parameter a; can be either double or fvar<double>
37- * for forward-mode automatic differentiation
35+ * @tparam T_a Type of shape parameter a (double or fvar types)
3836 * @param a Shape parameter
3937 * @param x Value at which to evaluate
4038 * @param max_steps Maximum number of continued fraction iterations
4139 * @param precision Convergence threshold
4240 * @return log(Q(a,x)) with same type as T_a
4341 */
44- template <typename T_a>
45- inline auto log_q_gamma_cf (const T_a& a, const double x, int max_steps = 250 ,
42+ template <typename T_a, typename T_x >
43+ inline auto log_q_gamma_cf (const T_a& a, const T_x& x, int max_steps = 250 ,
4644 double precision = 1e-16 ) {
4745 using stan::math::lgamma;
4846 using stan::math::log;
4947 using stan::math::value_of;
5048 using std::fabs;
49+ using T_return = return_type_t <T_a, T_x>;
5150
52- const auto log_prefactor = a * log (x) - x - lgamma (a);
51+ const T_return a_ret = a;
52+ const T_return x_ret = x;
53+ const auto log_prefactor = a_ret * log (x_ret) - x_ret - lgamma (a_ret);
5354
54- auto b = x + 1.0 - a ;
55- auto C = (fabs (value_of (b)) >= EPSILON) ? b : T_a (EPSILON);
56- auto D = T_a (0.0 );
55+ auto b = x_ret + 1.0 - a_ret ;
56+ auto C = (fabs (value_of (b)) >= EPSILON) ? b : T_return (EPSILON);
57+ auto D = T_return (0.0 );
5758 auto f = C;
5859
5960 for (int i = 1 ; i <= max_steps; ++i) {
60- auto an = -i * (i - a );
61+ auto an = -i * (i - a_ret );
6162 b += 2.0 ;
6263
6364 D = b + an * D;
@@ -73,15 +74,9 @@ inline auto log_q_gamma_cf(const T_a& a, const double x, int max_steps = 250,
7374 auto delta = C * D;
7475 f *= delta;
7576
76- const double delta_m1 = fabs (value_of (delta) - 1.0 );
77+ const double delta_m1 = value_of ( fabs (value_of (delta) - 1.0 ) );
7778 if (delta_m1 < precision) {
78- if constexpr (stan::is_fvar<std::decay_t <T_a>>::value) {
79- if (fabs (value_of (delta.d_ )) < precision) {
80- break ;
81- }
82- } else {
83- break ;
84- }
79+ break ;
8580 }
8681 }
8782
@@ -123,20 +118,25 @@ return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
123118 const size_t N = max_size (y, alpha, beta);
124119
125120 constexpr bool need_y_beta_deriv = !is_constant_all<T_y, T_inv_scale>::value;
121+ constexpr bool any_fvar
122+ = is_fvar<scalar_type_t <T_y>>::value
123+ || is_fvar<scalar_type_t <T_shape>>::value
124+ || is_fvar<scalar_type_t <T_inv_scale>>::value;
125+ constexpr bool partials_fvar = is_fvar<T_partials_return>::value;
126126
127127 for (size_t n = 0 ; n < N; n++) {
128128 // Explicit results for extreme values
129129 // The gradients are technically ill-defined, but treated as zero
130- const T_partials_return y_dbl = value_of ( y_vec.val (n) );
130+ const T_partials_return y_dbl = y_vec.val (n);
131131 if (y_dbl == 0.0 ) {
132132 continue ;
133133 }
134134 if (y_dbl == INFTY) {
135135 return ops_partials.build (negative_infinity ());
136136 }
137137
138- const T_partials_return alpha_dbl = value_of ( alpha_vec.val (n) );
139- const T_partials_return beta_dbl = value_of ( beta_vec.val (n) );
138+ const T_partials_return alpha_dbl = alpha_vec.val (n);
139+ const T_partials_return beta_dbl = beta_vec.val (n);
140140
141141 const T_partials_return beta_y = beta_dbl * y_dbl;
142142 if (beta_y == INFTY) {
@@ -146,20 +146,35 @@ return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
146146 bool use_cf = beta_y > alpha_dbl + 1.0 ;
147147 T_partials_return log_Qn;
148148 [[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0 ;
149- // Extract double values for continued fraction - we handle y/beta
150- // derivatives via hazard
151- const double beta_y_dbl = value_of (value_of (beta_y));
152- const double alpha_dbl_val = value_of (value_of (alpha_dbl));
149+ // Extract double values for the double-only continued fraction path.
150+ [[maybe_unused]] const double beta_y_dbl = value_of (value_of (beta_y));
151+ [[maybe_unused]] const double alpha_dbl_val = value_of (value_of (alpha_dbl));
153152
154153 if (use_cf) {
155- if constexpr (is_autodiff_v<T_shape>) {
156- stan::math::fvar<double > a_f (alpha_dbl_val, 1.0 );
157- const stan::math::fvar<double > logq_f
158- = internal::log_q_gamma_cf (a_f, beta_y_dbl);
159- log_Qn = logq_f.val_ ;
160- dlogQ_dalpha = logq_f.d_ ;
154+ if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
155+ // var-only: use analytical gradient with double inputs
156+ auto log_q_result = log_gamma_q_dgamma (alpha_dbl_val, beta_y_dbl);
157+ log_Qn = log_q_result.log_q ;
158+ dlogQ_dalpha = log_q_result.dlog_q_da ;
161159 } else {
162- log_Qn = internal::log_q_gamma_cf (alpha_dbl_val, beta_y_dbl);
160+ log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
161+ if constexpr (is_autodiff_v<T_shape>) {
162+ if constexpr (partials_fvar) {
163+ auto alpha_unit = alpha_dbl;
164+ alpha_unit.d_ = 1 ;
165+ auto beta_unit = beta_y;
166+ beta_unit.d_ = 0 ;
167+ auto log_Qn_fvar
168+ = internal::log_q_gamma_cf (alpha_unit, beta_unit);
169+ dlogQ_dalpha = log_Qn_fvar.d_ ;
170+ } else {
171+ const T_partials_return Qn = exp (log_Qn);
172+ dlogQ_dalpha
173+ = grad_reg_inc_gamma (alpha_dbl, beta_y, tgamma (alpha_dbl),
174+ digamma (alpha_dbl))
175+ / Qn;
176+ }
177+ }
163178 }
164179 } else {
165180 const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
@@ -168,30 +183,64 @@ return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
168183 if (!std::isfinite (value_of (value_of (log_Qn)))) {
169184 use_cf = beta_y > 0.0 ;
170185 if (use_cf) {
171- if constexpr (is_autodiff_v<T_shape>) {
172- stan::math::fvar<double > a_f (alpha_dbl_val, 1.0 );
173- const stan::math::fvar<double > logq_f
174- = internal::log_q_gamma_cf (a_f, beta_y_dbl);
175- log_Qn = logq_f.val_ ;
176- dlogQ_dalpha = logq_f.d_ ;
186+ // Fallback to continued fraction if log1m fails
187+ if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
188+ auto log_q_result = log_gamma_q_dgamma (alpha_dbl_val, beta_y_dbl);
189+ log_Qn = log_q_result.log_q ;
190+ dlogQ_dalpha = log_q_result.dlog_q_da ;
177191 } else {
178- log_Qn = internal::log_q_gamma_cf (alpha_dbl_val, beta_y_dbl);
192+ log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
193+ if constexpr (is_autodiff_v<T_shape>) {
194+ if constexpr (partials_fvar) {
195+ auto alpha_unit = alpha_dbl;
196+ alpha_unit.d_ = 1 ;
197+ auto beta_unit = beta_y;
198+ beta_unit.d_ = 0 ;
199+ auto log_Qn_fvar
200+ = internal::log_q_gamma_cf (alpha_unit, beta_unit);
201+ dlogQ_dalpha = log_Qn_fvar.d_ ;
202+ } else {
203+ const T_partials_return Qn = exp (log_Qn);
204+ dlogQ_dalpha
205+ = grad_reg_inc_gamma (alpha_dbl, beta_y, tgamma (alpha_dbl),
206+ digamma (alpha_dbl))
207+ / Qn;
208+ }
209+ }
179210 }
180211 }
181212 }
182213
183214 if constexpr (is_autodiff_v<T_shape>) {
184215 if (!use_cf) {
185- const T_partials_return Qn = exp (log_Qn);
186- if (Qn > 0.0 ) {
187- dlogQ_dalpha = -grad_reg_lower_inc_gamma (alpha_dbl, beta_y) / Qn;
216+ if constexpr (partials_fvar) {
217+ auto alpha_unit = alpha_dbl;
218+ alpha_unit.d_ = 1 ;
219+ auto beta_unit = beta_y;
220+ beta_unit.d_ = 0 ;
221+ auto log_Qn_fvar = log1m (gamma_p (alpha_unit, beta_unit));
222+ dlogQ_dalpha = log_Qn_fvar.d_ ;
188223 } else {
189- stan::math::fvar<double > a_f (alpha_dbl_val, 1.0 );
190- const stan::math::fvar<double > logq_f
191- = internal::log_q_gamma_cf (a_f, beta_y_dbl);
192- log_Qn = logq_f.val_ ;
193- dlogQ_dalpha = logq_f.d_ ;
194- use_cf = true ;
224+ const T_partials_return Qn = exp (log_Qn);
225+ if (Qn > 0.0 ) {
226+ dlogQ_dalpha = -grad_reg_lower_inc_gamma (alpha_dbl, beta_y) / Qn;
227+ } else {
228+ // Fallback to continued fraction if Q rounds to zero
229+ if constexpr (!any_fvar) {
230+ auto log_q_result
231+ = log_gamma_q_dgamma (alpha_dbl_val, beta_y_dbl);
232+ log_Qn = log_q_result.log_q ;
233+ dlogQ_dalpha = log_q_result.dlog_q_da ;
234+ } else {
235+ log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
236+ const T_partials_return Qn_cf = exp (log_Qn);
237+ dlogQ_dalpha
238+ = grad_reg_inc_gamma (alpha_dbl, beta_y, tgamma (alpha_dbl),
239+ digamma (alpha_dbl))
240+ / Qn_cf;
241+ }
242+ use_cf = true ;
243+ }
195244 }
196245 }
197246 }
0 commit comments