Skip to content

Commit 5053a4b

Browse files
spinkney and syclik
authored and committed
super stable gamma_lccdf
1 parent 5b7cef4 commit 5053a4b

3 files changed

Lines changed: 263 additions & 44 deletions

File tree

stan/math/prim/prob/gamma_lccdf.hpp

Lines changed: 143 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,84 @@
44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
66
#include <stan/math/prim/fun/constants.hpp>
7-
#include <stan/math/prim/fun/digamma.hpp>
87
#include <stan/math/prim/fun/exp.hpp>
9-
#include <stan/math/prim/fun/gamma_q.hpp>
10-
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
8+
#include <stan/math/prim/fun/fma.hpp>
9+
#include <stan/math/prim/fun/gamma_p.hpp>
10+
#include <stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
1111
#include <stan/math/prim/fun/log.hpp>
12+
#include <stan/math/prim/fun/log1m.hpp>
1213
#include <stan/math/prim/fun/max_size.hpp>
1314
#include <stan/math/prim/fun/scalar_seq_view.hpp>
1415
#include <stan/math/prim/fun/size.hpp>
1516
#include <stan/math/prim/fun/size_zero.hpp>
16-
#include <stan/math/prim/fun/tgamma.hpp>
1717
#include <stan/math/prim/fun/value_of.hpp>
1818
#include <stan/math/prim/functor/partials_propagator.hpp>
19+
#include <stan/math/fwd/fun/lgamma.hpp>
20+
#include <stan/math/fwd/fun/log.hpp>
21+
#include <stan/math/fwd/fun/value_of.hpp>
1922
#include <cmath>
2023

2124
namespace stan {
2225
namespace math {
2326

27+
namespace internal {
28+
29+
/**
 * Evaluates a continued fraction in `a` and `x` and returns
 * `a * log(x) - x - lgamma(a) - log(f)`, where `f` is the accumulated
 * continued-fraction value.
 *
 * `gamma_lccdf` stores the result in `log_Qn`, the same quantity it
 * otherwise computes as `log1m(gamma_p(alpha, beta * y))`.
 *
 * @tparam T_a type of `a`; the convergence test has a dedicated branch
 *   for `fvar` types that also inspects the tangent `delta.d_`
 * @param a first argument (`gamma_lccdf` passes the shape `alpha`)
 * @param x second argument (`gamma_lccdf` passes `beta * y`)
 * @param max_steps upper bound on the number of loop iterations
 * @param precision tolerance applied to `|delta - 1|` (and to the
 *   tangent of `delta` when `T_a` is an `fvar`)
 * @return `log_prefactor - log(f)`
 */
template <typename T_a>
30+
inline auto log_q_gamma_cf(const T_a& a, const double x,
31+
int max_steps = 250, double precision = 1e-16) {
32+
using std::fabs;
33+
using stan::math::lgamma;
34+
using stan::math::log;
35+
using stan::math::value_of;
36+
37+
// Log of the factor multiplying the reciprocal of the continued fraction.
const auto log_prefactor = a * log(x) - x - lgamma(a);
38+
39+
auto b = x + 1.0 - a;
40+
// Start from b, substituting EPSILON when |b| is below EPSILON so the
// later division by C cannot divide by (near) zero.
auto C = (fabs(value_of(b)) >= EPSILON) ? b : T_a(EPSILON);
41+
auto D = T_a(0.0);
42+
auto f = C;
43+
44+
// NOTE(review): the C/D recurrence below appears to follow the modified
// Lentz scheme for continued fractions -- confirm against the reference
// the author used.
for (int i = 1; i <= max_steps; ++i) {
45+
auto an = -i * (i - a);
46+
b += 2.0;
47+
48+
D = b + an * D;
49+
// Clamp tiny magnitudes to EPSILON before D is inverted below.
if (fabs(value_of(D)) < EPSILON) {
50+
D = T_a(EPSILON);
51+
}
52+
C = b + an / C;
53+
// Same clamp for C, which is the divisor in the next iteration.
if (fabs(value_of(C)) < EPSILON) {
54+
C = T_a(EPSILON);
55+
}
56+
57+
D = 1.0 / D;
58+
auto delta = C * D;
59+
f *= delta;
60+
61+
// Stop once the multiplicative update is within `precision` of 1; for
// fvar inputs the tangent of delta must also be below `precision`.
// NOTE(review): the default 1e-16 is below double machine epsilon
// (~2.2e-16), so this test looks satisfiable only when delta == 1.0
// exactly -- confirm that is intended.
const double delta_m1 = fabs(value_of(delta) - 1.0);
62+
if (delta_m1 < precision) {
63+
if constexpr (stan::is_fvar<std::decay_t<T_a>>::value) {
64+
if (fabs(value_of(delta.d_)) < precision) {
65+
break;
66+
}
67+
} else {
68+
break;
69+
}
70+
}
71+
}
72+
73+
// NOTE(review): if max_steps is exhausted the loop simply ends and the
// current f is used; no error or warning is raised here.
return log_prefactor - log(f);
74+
}
75+
76+
} // namespace internal
77+
2478
template <typename T_y, typename T_shape, typename T_inv_scale>
25-
inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
26-
const T_y& y, const T_shape& alpha, const T_inv_scale& beta) {
79+
return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
80+
const T_shape& alpha,
81+
const T_inv_scale& beta) {
2782
using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
2883
using std::exp;
2984
using std::log;
30-
using std::pow;
3185
using T_y_ref = ref_type_t<T_y>;
3286
using T_alpha_ref = ref_type_t<T_shape>;
3387
using T_beta_ref = ref_type_t<T_inv_scale>;
@@ -51,61 +105,106 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
51105
scalar_seq_view<T_y_ref> y_vec(y_ref);
52106
scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref);
53107
scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
54-
size_t N = max_size(y, alpha, beta);
55-
56-
// Explicit return for extreme values
57-
// The gradients are technically ill-defined, but treated as zero
58-
for (size_t i = 0; i < stan::math::size(y); i++) {
59-
if (y_vec.val(i) == 0) {
60-
// LCCDF(0) = log(P(Y > 0)) = log(1) = 0
61-
return ops_partials.build(0.0);
62-
}
63-
}
108+
const size_t N = max_size(y, alpha, beta);
109+
110+
constexpr bool need_y_beta_deriv = !is_constant_all<T_y, T_inv_scale>::value;
64111

65112
for (size_t n = 0; n < N; n++) {
66113
// Explicit results for extreme values
67114
// The gradients are technically ill-defined, but treated as zero
68-
if (y_vec.val(n) == INFTY) {
69-
// LCCDF(∞) = log(P(Y > ∞)) = log(0) = -∞
115+
const T_partials_return y_dbl = value_of(y_vec.val(n));
116+
if (y_dbl == 0.0) {
117+
continue;
118+
}
119+
if (y_dbl == INFTY) {
70120
return ops_partials.build(negative_infinity());
71121
}
72122

73-
const T_partials_return y_dbl = y_vec.val(n);
74-
const T_partials_return alpha_dbl = alpha_vec.val(n);
75-
const T_partials_return beta_dbl = beta_vec.val(n);
76-
const T_partials_return beta_y_dbl = beta_dbl * y_dbl;
123+
const T_partials_return alpha_dbl = value_of(alpha_vec.val(n));
124+
const T_partials_return beta_dbl = value_of(beta_vec.val(n));
125+
126+
const T_partials_return beta_y = beta_dbl * y_dbl;
127+
if (beta_y == INFTY) {
128+
return ops_partials.build(negative_infinity());
129+
}
77130

78-
// Qn = 1 - Pn
79-
const T_partials_return Qn = gamma_q(alpha_dbl, beta_y_dbl);
80-
const T_partials_return log_Qn = log(Qn);
131+
bool use_cf = beta_y > alpha_dbl + 1.0;
132+
T_partials_return log_Qn;
133+
[[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0;
134+
// Extract double values for continued fraction - we handle y/beta derivatives via hazard
135+
const double beta_y_dbl = value_of(value_of(beta_y));
136+
const double alpha_dbl_val = value_of(value_of(alpha_dbl));
137+
138+
if (use_cf) {
139+
if constexpr (is_autodiff_v<T_shape>) {
140+
stan::math::fvar<double> a_f(alpha_dbl_val, 1.0);
141+
const stan::math::fvar<double> logq_f
142+
= internal::log_q_gamma_cf(a_f, beta_y_dbl);
143+
log_Qn = logq_f.val_;
144+
dlogQ_dalpha = logq_f.d_;
145+
} else {
146+
log_Qn = internal::log_q_gamma_cf(alpha_dbl_val, beta_y_dbl);
147+
}
148+
} else {
149+
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
150+
log_Qn = log1m(Pn);
151+
152+
if (!std::isfinite(value_of(value_of(log_Qn)))) {
153+
use_cf = beta_y > 0.0;
154+
if (use_cf) {
155+
if constexpr (is_autodiff_v<T_shape>) {
156+
stan::math::fvar<double> a_f(alpha_dbl_val, 1.0);
157+
const stan::math::fvar<double> logq_f
158+
= internal::log_q_gamma_cf(a_f, beta_y_dbl);
159+
log_Qn = logq_f.val_;
160+
dlogQ_dalpha = logq_f.d_;
161+
} else {
162+
log_Qn = internal::log_q_gamma_cf(alpha_dbl_val, beta_y_dbl);
163+
}
164+
}
165+
}
81166

167+
if constexpr (is_autodiff_v<T_shape>) {
168+
if (!use_cf) {
169+
const T_partials_return Qn = exp(log_Qn);
170+
if (Qn > 0.0) {
171+
dlogQ_dalpha = -grad_reg_lower_inc_gamma(alpha_dbl, beta_y) / Qn;
172+
} else {
173+
stan::math::fvar<double> a_f(alpha_dbl_val, 1.0);
174+
const stan::math::fvar<double> logq_f
175+
= internal::log_q_gamma_cf(a_f, beta_y_dbl);
176+
log_Qn = logq_f.val_;
177+
dlogQ_dalpha = logq_f.d_;
178+
use_cf = true;
179+
}
180+
}
181+
}
182+
}
183+
if (!std::isfinite(value_of(value_of(log_Qn)))) {
184+
return ops_partials.build(negative_infinity());
185+
}
82186
P += log_Qn;
83187

84-
if constexpr (is_any_autodiff_v<T_y, T_inv_scale>) {
85-
const T_partials_return log_y_dbl = log(y_dbl);
86-
const T_partials_return log_beta_dbl = log(beta_dbl);
87-
const T_partials_return log_pdf
88-
= alpha_dbl * log_beta_dbl - lgamma(alpha_dbl)
89-
+ (alpha_dbl - 1.0) * log_y_dbl - beta_y_dbl;
90-
const T_partials_return common_term = exp(log_pdf - log_Qn);
188+
if constexpr (need_y_beta_deriv) {
189+
const T_partials_return log_y = log(y_dbl);
190+
const T_partials_return log_beta = log(beta_dbl);
191+
const T_partials_return lgamma_alpha = lgamma(alpha_dbl);
192+
const T_partials_return alpha_minus_one = fma(alpha_dbl, log_y, -log_y);
193+
194+
const T_partials_return log_pdf = alpha_dbl * log_beta - lgamma_alpha
195+
+ alpha_minus_one - beta_y;
196+
197+
const T_partials_return hazard = exp(log_pdf - log_Qn); // f/Q
91198

92199
if constexpr (is_autodiff_v<T_y>) {
93-
// d/dy log(1-F(y)) = -f(y)/(1-F(y))
94-
partials<0>(ops_partials)[n] -= common_term;
200+
partials<0>(ops_partials)[n] -= hazard;
95201
}
96202
if constexpr (is_autodiff_v<T_inv_scale>) {
97-
// d/dbeta log(1-F(y)) = -y*f(y)/(beta*(1-F(y)))
98-
partials<2>(ops_partials)[n] -= y_dbl / beta_dbl * common_term;
203+
partials<2>(ops_partials)[n] -= (y_dbl / beta_dbl) * hazard;
99204
}
100205
}
101-
102206
if constexpr (is_autodiff_v<T_shape>) {
103-
const T_partials_return digamma_val = digamma(alpha_dbl);
104-
const T_partials_return gamma_val = tgamma(alpha_dbl);
105-
// d/dalpha log(1-F(y)) = grad_upper_inc_gamma / (1-F(y))
106-
partials<1>(ops_partials)[n]
107-
+= grad_reg_inc_gamma(alpha_dbl, beta_y_dbl, gamma_val, digamma_val)
108-
/ Qn;
207+
partials<1>(ops_partials)[n] += dlogQ_dalpha;
109208
}
110209
}
111210
return ops_partials.build(P);

test/unit/math/prim/prob/gamma_lccdf_test.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,31 @@ TEST(ProbGamma, lccdf_small_alpha_small_y) {
6666
EXPECT_LT(result, 0.0);
6767
}
6868

69+
TEST(ProbGamma, lccdf_alpha_gt_30_small_y_old_code_rounds_to_zero) {
70+
using stan::math::gamma_lccdf;
71+
using stan::math::gamma_p;
72+
using stan::math::gamma_q;
73+
using stan::math::log1m;
74+
75+
// For large alpha and very small y, the CCDF is extremely close to 1.
76+
// The old implementation computed `log(gamma_q(alpha, beta * y))`, which can
77+
// round to `log(1) == 0`. The updated implementation uses `log1m(gamma_p)`,
78+
// which preserves the tiny negative value.
79+
double y = 1e-8;
80+
double alpha = 31.25;
81+
double beta = 1.0;
82+
83+
double new_val = gamma_lccdf(y, alpha, beta);
84+
double expected = log1m(gamma_p(alpha, beta * y));
85+
86+
// Old code: log(gamma_q(alpha, beta * y))
87+
double old_val = std::log(gamma_q(alpha, beta * y));
88+
89+
EXPECT_EQ(old_val, 0.0);
90+
EXPECT_LT(new_val, 0.0);
91+
EXPECT_DOUBLE_EQ(new_val, expected);
92+
}
93+
6994
TEST(ProbGamma, lccdf_large_alpha_large_y) {
7095
using stan::math::gamma_lccdf;
7196

@@ -154,6 +179,29 @@ TEST(ProbGamma, lccdf_extreme_large_alpha) {
154179
EXPECT_TRUE(std::isfinite(result));
155180
}
156181

182+
TEST(ProbGamma, lccdf_large_alpha_1000_beta_3) {
183+
using stan::math::gamma_lccdf;
184+
185+
// Large alpha = 1000, beta = 3
186+
double alpha = 1000.0;
187+
double beta = 3.0;
188+
189+
// Test various y values
190+
std::vector<double> y_values = {100.0, 300.0, 333.333, 400.0, 500.0};
191+
192+
for (double y : y_values) {
193+
double result = gamma_lccdf(y, alpha, beta);
194+
195+
// Result should be finite
196+
EXPECT_TRUE(std::isfinite(result))
197+
<< "Failed for y=" << y << ", alpha=" << alpha << ", beta=" << beta;
198+
199+
// Result should be <= 0 (log of probability)
200+
EXPECT_LE(result, 0.0) << "Positive value for y=" << y << ", alpha="
201+
<< alpha << ", beta=" << beta;
202+
}
203+
}
204+
157205
TEST(ProbGamma, lccdf_monotonic_in_y) {
158206
using stan::math::gamma_lccdf;
159207

test/unit/math/rev/prob/gamma_lccdf_test.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,48 @@ TEST(ProbDistributionsGamma, lccdf_extreme_values_small) {
230230
}
231231
}
232232

233+
TEST(ProbDistributionsGamma,
234+
lccdf_alpha_gt_30_small_y_old_code_rounds_to_zero) {
235+
using stan::math::gamma_lccdf;
236+
using stan::math::gamma_p;
237+
using stan::math::gamma_q;
238+
using stan::math::log1m;
239+
using stan::math::var;
240+
241+
// Same comparison as the prim test, but also exercises autodiff for
242+
// alpha > 30.
243+
double y_d = 1e-8;
244+
double alpha_d = 31.25;
245+
double beta_d = 1.0;
246+
247+
var y_v = y_d;
248+
var alpha_v = alpha_d;
249+
var beta_v = beta_d;
250+
251+
var lccdf_var = gamma_lccdf(y_v, alpha_v, beta_v);
252+
253+
// Old code: log(gamma_q(alpha, beta * y))
254+
double old_val = std::log(gamma_q(alpha_d, beta_d * y_d));
255+
double expected = log1m(gamma_p(alpha_d, beta_d * y_d));
256+
257+
EXPECT_EQ(old_val, 0.0);
258+
EXPECT_LT(lccdf_var.val(), 0.0);
259+
EXPECT_DOUBLE_EQ(lccdf_var.val(), expected);
260+
261+
std::vector<var> vars = {y_v, alpha_v, beta_v};
262+
std::vector<double> grads;
263+
lccdf_var.grad(vars, grads);
264+
265+
for (size_t i = 0; i < grads.size(); ++i) {
266+
EXPECT_FALSE(std::isnan(grads[i])) << "Gradient " << i << " is NaN";
267+
EXPECT_TRUE(std::isfinite(grads[i]))
268+
<< "Gradient " << i << " is not finite";
269+
}
270+
271+
// d/dy log(CCDF) should be <= 0 (can underflow to -0)
272+
EXPECT_LE(grads[0], 0.0);
273+
}
274+
233275
TEST(ProbDistributionsGamma, lccdf_extreme_values_large) {
234276
using stan::math::gamma_lccdf;
235277
using stan::math::var;
@@ -258,6 +300,36 @@ TEST(ProbDistributionsGamma, lccdf_extreme_values_large) {
258300
}
259301
}
260302

303+
TEST(ProbDistributionsGamma, lccdf_large_alpha_1000_beta_3) {
304+
using stan::math::gamma_lccdf;
305+
using stan::math::var;
306+
307+
// Large alpha = 1000, beta = 3
308+
// Note: This test only checks values, not gradients, as large alpha values
309+
// can cause numerical issues with gradient computation
310+
double alpha_d = 1000.0;
311+
double beta_d = 3.0;
312+
313+
// Test various y values
314+
std::vector<double> y_values = {100.0, 300.0, 333.333, 400.0, 500.0};
315+
316+
for (double y_d : y_values) {
317+
var y_v = y_d;
318+
var alpha_v = alpha_d;
319+
var beta_v = beta_d;
320+
321+
var lccdf_var = gamma_lccdf(y_v, alpha_v, beta_v);
322+
323+
// Value should be finite and <= 0
324+
EXPECT_TRUE(std::isfinite(lccdf_var.val()))
325+
<< "Failed for y=" << y_d << ", alpha=" << alpha_d << ", beta="
326+
<< beta_d;
327+
EXPECT_LE(lccdf_var.val(), 0.0)
328+
<< "Positive value for y=" << y_d << ", alpha=" << alpha_d
329+
<< ", beta=" << beta_d;
330+
}
331+
}
332+
261333
TEST(ProbDistributionsGamma, lccdf_alpha_one_derivatives) {
262334
using stan::math::gamma_lccdf;
263335
using stan::math::var;

0 commit comments

Comments
 (0)