Skip to content

Commit a47bc08

Browse files
spinkneysyclik
authored andcommitted
update logic per review
1 parent 5f899f6 commit a47bc08

1 file changed

Lines changed: 53 additions & 85 deletions

File tree

stan/math/prim/prob/gamma_lccdf.hpp

Lines changed: 53 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -86,101 +86,69 @@ return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
8686
bool use_cf = beta_y > alpha_dbl + 1.0;
8787
T_partials_return log_Qn;
8888
[[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0;
89-
// Extract double values for the double-only continued fraction path.
90-
[[maybe_unused]] const double beta_y_dbl = value_of(value_of(beta_y));
91-
[[maybe_unused]] const double alpha_dbl_val = value_of(value_of(alpha_dbl));
9289

93-
if (use_cf) {
94-
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
95-
// var-only: use analytical gradient with double inputs
90+
// Branch by autodiff type first, then handle use_cf logic inside each path
91+
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
92+
// var-only path: use log_gamma_q_dgamma which computes both log_q
93+
// and its gradient analytically with double inputs
94+
const double beta_y_dbl = value_of(value_of(beta_y));
95+
const double alpha_dbl_val = value_of(value_of(alpha_dbl));
96+
97+
if (use_cf) {
9698
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
9799
log_Qn = log_q_result.log_q;
98100
dlogQ_dalpha = log_q_result.dlog_q_da;
99101
} else {
100-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
101-
if constexpr (is_autodiff_v<T_shape>) {
102-
if constexpr (partials_fvar) {
103-
auto alpha_unit = alpha_dbl;
104-
alpha_unit.d_ = 1;
105-
auto beta_unit = beta_y;
106-
beta_unit.d_ = 0;
107-
auto log_Qn_fvar = internal::log_q_gamma_cf(alpha_unit, beta_unit);
108-
dlogQ_dalpha = log_Qn_fvar.d_;
109-
} else {
110-
const T_partials_return Qn = exp(log_Qn);
111-
dlogQ_dalpha
112-
= grad_reg_inc_gamma(alpha_dbl, beta_y, tgamma(alpha_dbl),
113-
digamma(alpha_dbl))
114-
/ Qn;
115-
}
102+
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
103+
log_Qn = log1m(Pn);
104+
const T_partials_return Qn = exp(log_Qn);
105+
106+
// Check if we need to fallback to continued fraction
107+
bool need_cf_fallback = !std::isfinite(value_of(value_of(log_Qn)))
108+
|| Qn <= 0.0;
109+
if (need_cf_fallback && beta_y > 0.0) {
110+
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
111+
log_Qn = log_q_result.log_q;
112+
dlogQ_dalpha = log_q_result.dlog_q_da;
113+
} else {
114+
dlogQ_dalpha = -grad_reg_lower_inc_gamma(alpha_dbl, beta_y) / Qn;
116115
}
117116
}
118-
} else {
119-
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
120-
log_Qn = log1m(Pn);
121-
122-
if (!std::isfinite(value_of(value_of(log_Qn)))) {
123-
use_cf = beta_y > 0.0;
124-
if (use_cf) {
125-
// Fallback to continued fraction if log1m fails
126-
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
127-
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
128-
log_Qn = log_q_result.log_q;
129-
dlogQ_dalpha = log_q_result.dlog_q_da;
130-
} else {
131-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
132-
if constexpr (is_autodiff_v<T_shape>) {
133-
if constexpr (partials_fvar) {
134-
auto alpha_unit = alpha_dbl;
135-
alpha_unit.d_ = 1;
136-
auto beta_unit = beta_y;
137-
beta_unit.d_ = 0;
138-
auto log_Qn_fvar
139-
= internal::log_q_gamma_cf(alpha_unit, beta_unit);
140-
dlogQ_dalpha = log_Qn_fvar.d_;
141-
} else {
142-
const T_partials_return Qn = exp(log_Qn);
143-
dlogQ_dalpha
144-
= grad_reg_inc_gamma(alpha_dbl, beta_y, tgamma(alpha_dbl),
145-
digamma(alpha_dbl))
146-
/ Qn;
147-
}
148-
}
149-
}
117+
} else if constexpr (partials_fvar && is_autodiff_v<T_shape>) {
118+
// fvar path: use unit derivative trick to compute gradients
119+
auto alpha_unit = alpha_dbl;
120+
alpha_unit.d_ = 1;
121+
auto beta_unit = beta_y;
122+
beta_unit.d_ = 0;
123+
124+
if (use_cf) {
125+
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
126+
auto log_Qn_fvar = internal::log_q_gamma_cf(alpha_unit, beta_unit);
127+
dlogQ_dalpha = log_Qn_fvar.d_;
128+
} else {
129+
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
130+
log_Qn = log1m(Pn);
131+
132+
if (!std::isfinite(value_of(value_of(log_Qn))) && beta_y > 0.0) {
133+
// Fallback to continued fraction
134+
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
135+
auto log_Qn_fvar = internal::log_q_gamma_cf(alpha_unit, beta_unit);
136+
dlogQ_dalpha = log_Qn_fvar.d_;
137+
} else {
138+
auto log_Qn_fvar = log1m(gamma_p(alpha_unit, beta_unit));
139+
dlogQ_dalpha = log_Qn_fvar.d_;
150140
}
151141
}
142+
} else {
143+
// No alpha derivative needed (alpha is constant or double-only)
144+
if (use_cf) {
145+
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
146+
} else {
147+
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
148+
log_Qn = log1m(Pn);
152149

153-
if constexpr (is_autodiff_v<T_shape>) {
154-
if (!use_cf) {
155-
if constexpr (partials_fvar) {
156-
auto alpha_unit = alpha_dbl;
157-
alpha_unit.d_ = 1;
158-
auto beta_unit = beta_y;
159-
beta_unit.d_ = 0;
160-
auto log_Qn_fvar = log1m(gamma_p(alpha_unit, beta_unit));
161-
dlogQ_dalpha = log_Qn_fvar.d_;
162-
} else {
163-
const T_partials_return Qn = exp(log_Qn);
164-
if (Qn > 0.0) {
165-
dlogQ_dalpha = -grad_reg_lower_inc_gamma(alpha_dbl, beta_y) / Qn;
166-
} else {
167-
// Fallback to continued fraction if Q rounds to zero
168-
if constexpr (!any_fvar) {
169-
auto log_q_result
170-
= log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
171-
log_Qn = log_q_result.log_q;
172-
dlogQ_dalpha = log_q_result.dlog_q_da;
173-
} else {
174-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
175-
const T_partials_return Qn_cf = exp(log_Qn);
176-
dlogQ_dalpha
177-
= grad_reg_inc_gamma(alpha_dbl, beta_y, tgamma(alpha_dbl),
178-
digamma(alpha_dbl))
179-
/ Qn_cf;
180-
}
181-
use_cf = true;
182-
}
183-
}
150+
if (!std::isfinite(value_of(value_of(log_Qn))) && beta_y > 0.0) {
151+
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
184152
}
185153
}
186154
}

0 commit comments

Comments
 (0)