Skip to content

Commit fa9a987

Browse files
spinkney authored and syclik committed
remove fwd and fix templating
1 parent 01940dd commit fa9a987

2 files changed

Lines changed: 232 additions & 55 deletions

File tree

Lines changed: 128 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,128 @@
1+
#ifndef STAN_MATH_PRIM_FUN_LOG_GAMMA_Q_DGAMMA_HPP
2+
#define STAN_MATH_PRIM_FUN_LOG_GAMMA_Q_DGAMMA_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/fun/constants.hpp>
6+
#include <stan/math/prim/fun/digamma.hpp>
7+
#include <stan/math/prim/fun/exp.hpp>
8+
#include <stan/math/prim/fun/gamma_p.hpp>
9+
#include <stan/math/prim/fun/gamma_q.hpp>
10+
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
11+
#include <stan/math/prim/fun/lgamma.hpp>
12+
#include <stan/math/prim/fun/log.hpp>
13+
#include <stan/math/prim/fun/log1m.hpp>
14+
#include <stan/math/prim/fun/tgamma.hpp>
15+
#include <stan/math/prim/fun/value_of.hpp>
16+
#include <cmath>
17+
18+
namespace stan {
19+
namespace math {
20+
21+
/**
 * Value/derivative pair for the log of the regularized upper incomplete
 * gamma function Q(a,z).
 *
 * The struct is a plain aggregate: it holds the two numbers and nothing else.
 *
 * @tparam T scalar type used to store both members
 */
template <class T>
struct log_gamma_q_result {
  /** log(Q(a,z)), with Q the upper regularized incomplete gamma function. */
  T log_q;
  /** Partial derivative of log(Q(a,z)) with respect to the shape a. */
  T dlog_q_da;
};
31+
32+
/**
33+
* Compute log(Q(a,z)) and its gradient with respect to a using continued
34+
* fraction expansion, where Q(a,z) = Gamma(a,z) / Gamma(a) is the regularized
35+
* upper incomplete gamma function.
36+
*
37+
* This uses a continued fraction representation for numerical stability when
38+
* computing the upper incomplete gamma function in log space, along with
39+
* analytical gradient computation.
40+
*
41+
* @tparam T_a type of the shape parameter
42+
* @tparam T_z type of the value parameter
43+
* @param a shape parameter (must be positive)
44+
* @param z value parameter (must be non-negative)
45+
* @param max_steps maximum iterations for continued fraction
46+
* @param precision convergence threshold
47+
* @return structure containing log(Q(a,z)) and d/da log(Q(a,z))
48+
*/
49+
template <typename T_a, typename T_z>
50+
inline log_gamma_q_result<return_type_t<T_a, T_z>> log_gamma_q_dgamma(
51+
const T_a& a, const T_z& z, int max_steps = 250,
52+
double precision = 1e-16) {
53+
using std::exp;
54+
using std::fabs;
55+
using std::log;
56+
using T_return = return_type_t<T_a, T_z>;
57+
58+
const double a_dbl = value_of(a);
59+
const double z_dbl = value_of(z);
60+
61+
log_gamma_q_result<T_return> result;
62+
63+
// For z > a + 1, use continued fraction for better numerical stability
64+
if (z_dbl > a_dbl + 1.0) {
65+
// Continued fraction for Q(a,z) in log space
66+
// log(Q(a,z)) = log_prefactor - log(continued_fraction)
67+
const double log_prefactor = a_dbl * log(z_dbl) - z_dbl - lgamma(a_dbl);
68+
69+
double b = z_dbl + 1.0 - a_dbl;
70+
double C = (fabs(b) >= EPSILON) ? b : EPSILON;
71+
double D = 0.0;
72+
double f = C;
73+
74+
for (int i = 1; i <= max_steps; ++i) {
75+
const double an = -i * (i - a_dbl);
76+
b += 2.0;
77+
78+
D = b + an * D;
79+
if (fabs(D) < EPSILON) {
80+
D = EPSILON;
81+
}
82+
C = b + an / C;
83+
if (fabs(C) < EPSILON) {
84+
C = EPSILON;
85+
}
86+
87+
D = 1.0 / D;
88+
const double delta = C * D;
89+
f *= delta;
90+
91+
const double delta_m1 = fabs(delta - 1.0);
92+
if (delta_m1 < precision) {
93+
break;
94+
}
95+
}
96+
97+
result.log_q = log_prefactor - log(f);
98+
99+
// For gradient, use: d/da log(Q) = (1/Q) * dQ/da
100+
// grad_reg_inc_gamma computes dQ/da
101+
const double Q_val = exp(result.log_q);
102+
const double dQ_da = grad_reg_inc_gamma(a_dbl, z_dbl, tgamma(a_dbl), digamma(a_dbl));
103+
result.dlog_q_da = dQ_da / Q_val;
104+
105+
} else {
106+
// For z <= a + 1, use log1m(P(a,z)) for better numerical accuracy
107+
const double P_val = gamma_p(a_dbl, z_dbl);
108+
result.log_q = log1m(P_val);
109+
110+
// Gradient: d/da log(Q) = (1/Q) * dQ/da
111+
// grad_reg_inc_gamma computes dQ/da
112+
const double Q_val = exp(result.log_q);
113+
if (Q_val > 0) {
114+
const double dQ_da = grad_reg_inc_gamma(a_dbl, z_dbl, tgamma(a_dbl), digamma(a_dbl));
115+
result.dlog_q_da = dQ_da / Q_val;
116+
} else {
117+
// Fallback if Q rounds to zero - use asymptotic approximation
118+
result.dlog_q_da = log(z_dbl) - digamma(a_dbl);
119+
}
120+
}
121+
122+
return result;
123+
}
124+
125+
} // namespace math
126+
} // namespace stan
127+
128+
#endif

stan/math/prim/prob/gamma_lccdf.hpp

Lines changed: 104 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -4,23 +4,23 @@
44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
66
#include <stan/math/prim/fun/constants.hpp>
7+
#include <stan/math/prim/fun/digamma.hpp>
78
#include <stan/math/prim/fun/exp.hpp>
89
#include <stan/math/prim/fun/fma.hpp>
910
#include <stan/math/prim/fun/gamma_p.hpp>
11+
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
1012
#include <stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
11-
// #include <stan/math/prim/fun/lgamma.hpp>
12-
// #include <stan/math/prim/fun/log.hpp>
13+
#include <stan/math/prim/fun/lgamma.hpp>
14+
#include <stan/math/prim/fun/log.hpp>
1315
#include <stan/math/prim/fun/log1m.hpp>
1416
#include <stan/math/prim/fun/max_size.hpp>
1517
#include <stan/math/prim/fun/scalar_seq_view.hpp>
1618
#include <stan/math/prim/fun/size.hpp>
1719
#include <stan/math/prim/fun/size_zero.hpp>
20+
#include <stan/math/prim/fun/tgamma.hpp>
1821
#include <stan/math/prim/fun/value_of.hpp>
22+
#include <stan/math/prim/fun/log_gamma_q_dgamma.hpp>
1923
#include <stan/math/prim/functor/partials_propagator.hpp>
20-
21-
#include <stan/math/fwd/fun/lgamma.hpp>
22-
#include <stan/math/fwd/fun/log.hpp>
23-
#include <stan/math/fwd/fun/value_of.hpp>
2424
#include <cmath>
2525

2626
namespace stan {
@@ -30,34 +30,35 @@ namespace internal {
3030

3131
/**
3232
* Compute log(Q(a,x)) using continued fraction expansion for upper incomplete
33-
* gamma function, where Q(a,x) = Gamma(a,x) / Gamma(a) is the regularized
34-
* upper incomplete gamma function.
33+
* gamma function. When used with fvar types, automatically computes derivatives.
3534
*
36-
* @tparam T_a Type of shape parameter a; can be either double or fvar<double>
37-
* for forward-mode automatic differentiation
35+
* @tparam T_a Type of shape parameter a (double or fvar types)
3836
* @param a Shape parameter
3937
* @param x Value at which to evaluate
4038
* @param max_steps Maximum number of continued fraction iterations
4139
* @param precision Convergence threshold
4240
* @return log(Q(a,x)) with same type as T_a
4341
*/
44-
template <typename T_a>
45-
inline auto log_q_gamma_cf(const T_a& a, const double x, int max_steps = 250,
42+
template <typename T_a, typename T_x>
43+
inline auto log_q_gamma_cf(const T_a& a, const T_x& x, int max_steps = 250,
4644
double precision = 1e-16) {
4745
using stan::math::lgamma;
4846
using stan::math::log;
4947
using stan::math::value_of;
5048
using std::fabs;
49+
using T_return = return_type_t<T_a, T_x>;
5150

52-
const auto log_prefactor = a * log(x) - x - lgamma(a);
51+
const T_return a_ret = a;
52+
const T_return x_ret = x;
53+
const auto log_prefactor = a_ret * log(x_ret) - x_ret - lgamma(a_ret);
5354

54-
auto b = x + 1.0 - a;
55-
auto C = (fabs(value_of(b)) >= EPSILON) ? b : T_a(EPSILON);
56-
auto D = T_a(0.0);
55+
auto b = x_ret + 1.0 - a_ret;
56+
auto C = (fabs(value_of(b)) >= EPSILON) ? b : T_return(EPSILON);
57+
auto D = T_return(0.0);
5758
auto f = C;
5859

5960
for (int i = 1; i <= max_steps; ++i) {
60-
auto an = -i * (i - a);
61+
auto an = -i * (i - a_ret);
6162
b += 2.0;
6263

6364
D = b + an * D;
@@ -73,15 +74,9 @@ inline auto log_q_gamma_cf(const T_a& a, const double x, int max_steps = 250,
7374
auto delta = C * D;
7475
f *= delta;
7576

76-
const double delta_m1 = fabs(value_of(delta) - 1.0);
77+
const double delta_m1 = value_of(fabs(value_of(delta) - 1.0));
7778
if (delta_m1 < precision) {
78-
if constexpr (stan::is_fvar<std::decay_t<T_a>>::value) {
79-
if (fabs(value_of(delta.d_)) < precision) {
80-
break;
81-
}
82-
} else {
83-
break;
84-
}
79+
break;
8580
}
8681
}
8782

@@ -123,20 +118,25 @@ return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
123118
const size_t N = max_size(y, alpha, beta);
124119

125120
constexpr bool need_y_beta_deriv = !is_constant_all<T_y, T_inv_scale>::value;
121+
constexpr bool any_fvar
122+
= is_fvar<scalar_type_t<T_y>>::value
123+
|| is_fvar<scalar_type_t<T_shape>>::value
124+
|| is_fvar<scalar_type_t<T_inv_scale>>::value;
125+
constexpr bool partials_fvar = is_fvar<T_partials_return>::value;
126126

127127
for (size_t n = 0; n < N; n++) {
128128
// Explicit results for extreme values
129129
// The gradients are technically ill-defined, but treated as zero
130-
const T_partials_return y_dbl = value_of(y_vec.val(n));
130+
const T_partials_return y_dbl = y_vec.val(n);
131131
if (y_dbl == 0.0) {
132132
continue;
133133
}
134134
if (y_dbl == INFTY) {
135135
return ops_partials.build(negative_infinity());
136136
}
137137

138-
const T_partials_return alpha_dbl = value_of(alpha_vec.val(n));
139-
const T_partials_return beta_dbl = value_of(beta_vec.val(n));
138+
const T_partials_return alpha_dbl = alpha_vec.val(n);
139+
const T_partials_return beta_dbl = beta_vec.val(n);
140140

141141
const T_partials_return beta_y = beta_dbl * y_dbl;
142142
if (beta_y == INFTY) {
@@ -146,20 +146,35 @@ return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
146146
bool use_cf = beta_y > alpha_dbl + 1.0;
147147
T_partials_return log_Qn;
148148
[[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0;
149-
// Extract double values for continued fraction - we handle y/beta
150-
// derivatives via hazard
151-
const double beta_y_dbl = value_of(value_of(beta_y));
152-
const double alpha_dbl_val = value_of(value_of(alpha_dbl));
149+
// Extract double values for the double-only continued fraction path.
150+
[[maybe_unused]] const double beta_y_dbl = value_of(value_of(beta_y));
151+
[[maybe_unused]] const double alpha_dbl_val = value_of(value_of(alpha_dbl));
153152

154153
if (use_cf) {
155-
if constexpr (is_autodiff_v<T_shape>) {
156-
stan::math::fvar<double> a_f(alpha_dbl_val, 1.0);
157-
const stan::math::fvar<double> logq_f
158-
= internal::log_q_gamma_cf(a_f, beta_y_dbl);
159-
log_Qn = logq_f.val_;
160-
dlogQ_dalpha = logq_f.d_;
154+
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
155+
// var-only: use analytical gradient with double inputs
156+
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
157+
log_Qn = log_q_result.log_q;
158+
dlogQ_dalpha = log_q_result.dlog_q_da;
161159
} else {
162-
log_Qn = internal::log_q_gamma_cf(alpha_dbl_val, beta_y_dbl);
160+
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
161+
if constexpr (is_autodiff_v<T_shape>) {
162+
if constexpr (partials_fvar) {
163+
auto alpha_unit = alpha_dbl;
164+
alpha_unit.d_ = 1;
165+
auto beta_unit = beta_y;
166+
beta_unit.d_ = 0;
167+
auto log_Qn_fvar
168+
= internal::log_q_gamma_cf(alpha_unit, beta_unit);
169+
dlogQ_dalpha = log_Qn_fvar.d_;
170+
} else {
171+
const T_partials_return Qn = exp(log_Qn);
172+
dlogQ_dalpha
173+
= grad_reg_inc_gamma(alpha_dbl, beta_y, tgamma(alpha_dbl),
174+
digamma(alpha_dbl))
175+
/ Qn;
176+
}
177+
}
163178
}
164179
} else {
165180
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
@@ -168,30 +183,64 @@ return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
168183
if (!std::isfinite(value_of(value_of(log_Qn)))) {
169184
use_cf = beta_y > 0.0;
170185
if (use_cf) {
171-
if constexpr (is_autodiff_v<T_shape>) {
172-
stan::math::fvar<double> a_f(alpha_dbl_val, 1.0);
173-
const stan::math::fvar<double> logq_f
174-
= internal::log_q_gamma_cf(a_f, beta_y_dbl);
175-
log_Qn = logq_f.val_;
176-
dlogQ_dalpha = logq_f.d_;
186+
// Fallback to continued fraction if log1m fails
187+
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
188+
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
189+
log_Qn = log_q_result.log_q;
190+
dlogQ_dalpha = log_q_result.dlog_q_da;
177191
} else {
178-
log_Qn = internal::log_q_gamma_cf(alpha_dbl_val, beta_y_dbl);
192+
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
193+
if constexpr (is_autodiff_v<T_shape>) {
194+
if constexpr (partials_fvar) {
195+
auto alpha_unit = alpha_dbl;
196+
alpha_unit.d_ = 1;
197+
auto beta_unit = beta_y;
198+
beta_unit.d_ = 0;
199+
auto log_Qn_fvar
200+
= internal::log_q_gamma_cf(alpha_unit, beta_unit);
201+
dlogQ_dalpha = log_Qn_fvar.d_;
202+
} else {
203+
const T_partials_return Qn = exp(log_Qn);
204+
dlogQ_dalpha
205+
= grad_reg_inc_gamma(alpha_dbl, beta_y, tgamma(alpha_dbl),
206+
digamma(alpha_dbl))
207+
/ Qn;
208+
}
209+
}
179210
}
180211
}
181212
}
182213

183214
if constexpr (is_autodiff_v<T_shape>) {
184215
if (!use_cf) {
185-
const T_partials_return Qn = exp(log_Qn);
186-
if (Qn > 0.0) {
187-
dlogQ_dalpha = -grad_reg_lower_inc_gamma(alpha_dbl, beta_y) / Qn;
216+
if constexpr (partials_fvar) {
217+
auto alpha_unit = alpha_dbl;
218+
alpha_unit.d_ = 1;
219+
auto beta_unit = beta_y;
220+
beta_unit.d_ = 0;
221+
auto log_Qn_fvar = log1m(gamma_p(alpha_unit, beta_unit));
222+
dlogQ_dalpha = log_Qn_fvar.d_;
188223
} else {
189-
stan::math::fvar<double> a_f(alpha_dbl_val, 1.0);
190-
const stan::math::fvar<double> logq_f
191-
= internal::log_q_gamma_cf(a_f, beta_y_dbl);
192-
log_Qn = logq_f.val_;
193-
dlogQ_dalpha = logq_f.d_;
194-
use_cf = true;
224+
const T_partials_return Qn = exp(log_Qn);
225+
if (Qn > 0.0) {
226+
dlogQ_dalpha = -grad_reg_lower_inc_gamma(alpha_dbl, beta_y) / Qn;
227+
} else {
228+
// Fallback to continued fraction if Q rounds to zero
229+
if constexpr (!any_fvar) {
230+
auto log_q_result
231+
= log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
232+
log_Qn = log_q_result.log_q;
233+
dlogQ_dalpha = log_q_result.dlog_q_da;
234+
} else {
235+
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
236+
const T_partials_return Qn_cf = exp(log_Qn);
237+
dlogQ_dalpha
238+
= grad_reg_inc_gamma(alpha_dbl, beta_y, tgamma(alpha_dbl),
239+
digamma(alpha_dbl))
240+
/ Qn_cf;
241+
}
242+
use_cf = true;
243+
}
195244
}
196245
}
197246
}

0 commit comments

Comments
 (0)