@@ -86,101 +86,69 @@ return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
8686 bool use_cf = beta_y > alpha_dbl + 1.0 ;
8787 T_partials_return log_Qn;
8888 [[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0 ;
89- // Extract double values for the double-only continued fraction path.
90- [[maybe_unused]] const double beta_y_dbl = value_of (value_of (beta_y));
91- [[maybe_unused]] const double alpha_dbl_val = value_of (value_of (alpha_dbl));
9289
93- if (use_cf) {
94- if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
95- // var-only: use analytical gradient with double inputs
90+ // Branch by autodiff type first, then handle use_cf logic inside each path
91+ if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
92+ // var-only path: use log_gamma_q_dgamma which computes both log_q
93+ // and its gradient analytically with double inputs
94+ const double beta_y_dbl = value_of (value_of (beta_y));
95+ const double alpha_dbl_val = value_of (value_of (alpha_dbl));
96+
97+ if (use_cf) {
9698 auto log_q_result = log_gamma_q_dgamma (alpha_dbl_val, beta_y_dbl);
9799 log_Qn = log_q_result.log_q ;
98100 dlogQ_dalpha = log_q_result.dlog_q_da ;
99101 } else {
100- log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
101- if constexpr (is_autodiff_v<T_shape>) {
102- if constexpr (partials_fvar) {
103- auto alpha_unit = alpha_dbl;
104- alpha_unit.d_ = 1 ;
105- auto beta_unit = beta_y;
106- beta_unit.d_ = 0 ;
107- auto log_Qn_fvar = internal::log_q_gamma_cf (alpha_unit, beta_unit);
108- dlogQ_dalpha = log_Qn_fvar.d_ ;
109- } else {
110- const T_partials_return Qn = exp (log_Qn);
111- dlogQ_dalpha
112- = grad_reg_inc_gamma (alpha_dbl, beta_y, tgamma (alpha_dbl),
113- digamma (alpha_dbl))
114- / Qn;
115- }
102+ const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
103+ log_Qn = log1m (Pn);
104+ const T_partials_return Qn = exp (log_Qn);
105+
106+ // Check if we need to fallback to continued fraction
107+ bool need_cf_fallback = !std::isfinite (value_of (value_of (log_Qn)))
108+ || Qn <= 0.0 ;
109+ if (need_cf_fallback && beta_y > 0.0 ) {
110+ auto log_q_result = log_gamma_q_dgamma (alpha_dbl_val, beta_y_dbl);
111+ log_Qn = log_q_result.log_q ;
112+ dlogQ_dalpha = log_q_result.dlog_q_da ;
113+ } else {
114+ dlogQ_dalpha = -grad_reg_lower_inc_gamma (alpha_dbl, beta_y) / Qn;
116115 }
117116 }
118- } else {
119- const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
120- log_Qn = log1m (Pn);
121-
122- if (!std::isfinite (value_of (value_of (log_Qn)))) {
123- use_cf = beta_y > 0.0 ;
124- if (use_cf) {
125- // Fallback to continued fraction if log1m fails
126- if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
127- auto log_q_result = log_gamma_q_dgamma (alpha_dbl_val, beta_y_dbl);
128- log_Qn = log_q_result.log_q ;
129- dlogQ_dalpha = log_q_result.dlog_q_da ;
130- } else {
131- log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
132- if constexpr (is_autodiff_v<T_shape>) {
133- if constexpr (partials_fvar) {
134- auto alpha_unit = alpha_dbl;
135- alpha_unit.d_ = 1 ;
136- auto beta_unit = beta_y;
137- beta_unit.d_ = 0 ;
138- auto log_Qn_fvar
139- = internal::log_q_gamma_cf (alpha_unit, beta_unit);
140- dlogQ_dalpha = log_Qn_fvar.d_ ;
141- } else {
142- const T_partials_return Qn = exp (log_Qn);
143- dlogQ_dalpha
144- = grad_reg_inc_gamma (alpha_dbl, beta_y, tgamma (alpha_dbl),
145- digamma (alpha_dbl))
146- / Qn;
147- }
148- }
149- }
117+ } else if constexpr (partials_fvar && is_autodiff_v<T_shape>) {
118+ // fvar path: use unit derivative trick to compute gradients
119+ auto alpha_unit = alpha_dbl;
120+ alpha_unit.d_ = 1 ;
121+ auto beta_unit = beta_y;
122+ beta_unit.d_ = 0 ;
123+
124+ if (use_cf) {
125+ log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
126+ auto log_Qn_fvar = internal::log_q_gamma_cf (alpha_unit, beta_unit);
127+ dlogQ_dalpha = log_Qn_fvar.d_ ;
128+ } else {
129+ const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
130+ log_Qn = log1m (Pn);
131+
132+ if (!std::isfinite (value_of (value_of (log_Qn))) && beta_y > 0.0 ) {
133+ // Fallback to continued fraction
134+ log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
135+ auto log_Qn_fvar = internal::log_q_gamma_cf (alpha_unit, beta_unit);
136+ dlogQ_dalpha = log_Qn_fvar.d_ ;
137+ } else {
138+ auto log_Qn_fvar = log1m (gamma_p (alpha_unit, beta_unit));
139+ dlogQ_dalpha = log_Qn_fvar.d_ ;
150140 }
151141 }
142+ } else {
143+ // No alpha derivative needed (alpha is constant or double-only)
144+ if (use_cf) {
145+ log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
146+ } else {
147+ const T_partials_return Pn = gamma_p (alpha_dbl, beta_y);
148+ log_Qn = log1m (Pn);
152149
153- if constexpr (is_autodiff_v<T_shape>) {
154- if (!use_cf) {
155- if constexpr (partials_fvar) {
156- auto alpha_unit = alpha_dbl;
157- alpha_unit.d_ = 1 ;
158- auto beta_unit = beta_y;
159- beta_unit.d_ = 0 ;
160- auto log_Qn_fvar = log1m (gamma_p (alpha_unit, beta_unit));
161- dlogQ_dalpha = log_Qn_fvar.d_ ;
162- } else {
163- const T_partials_return Qn = exp (log_Qn);
164- if (Qn > 0.0 ) {
165- dlogQ_dalpha = -grad_reg_lower_inc_gamma (alpha_dbl, beta_y) / Qn;
166- } else {
167- // Fallback to continued fraction if Q rounds to zero
168- if constexpr (!any_fvar) {
169- auto log_q_result
170- = log_gamma_q_dgamma (alpha_dbl_val, beta_y_dbl);
171- log_Qn = log_q_result.log_q ;
172- dlogQ_dalpha = log_q_result.dlog_q_da ;
173- } else {
174- log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
175- const T_partials_return Qn_cf = exp (log_Qn);
176- dlogQ_dalpha
177- = grad_reg_inc_gamma (alpha_dbl, beta_y, tgamma (alpha_dbl),
178- digamma (alpha_dbl))
179- / Qn_cf;
180- }
181- use_cf = true ;
182- }
183- }
150+ if (!std::isfinite (value_of (value_of (log_Qn))) && beta_y > 0.0 ) {
151+ log_Qn = internal::log_q_gamma_cf (alpha_dbl, beta_y);
184152 }
185153 }
186154 }
0 commit comments