Skip to content

Commit 939f03c

Browse files
committed
Adds template parameters to grad_F32 so that only the gradients that are needed are calculated. Small cleanup for beta_neg_binomial_lccdf.
1 parent 7fc9aab commit 939f03c

4 files changed

Lines changed: 119 additions & 112 deletions

File tree

stan/math/prim/fun/grad_F32.hpp

Lines changed: 68 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,14 @@ namespace math {
2424
* This power-series representation converges for all gradients
2525
* under the same conditions as the 3F2 function itself.
2626
*
27-
* @tparam T type of arguments and result
27+
* @tparam T1 type of g
28+
* @tparam T2 type of a1
29+
* @tparam T3 type of a2
30+
* @tparam T4 type of a3
31+
* @tparam T5 type of b1
32+
* @tparam T6 type of b2
33+
* @tparam T7 type of z
34+
* @tparam T8 type of precision
2835
* @param[out] g pointer to array of six values of type T1 that receives the result.
2936
* @param[in] a1 a1 see generalized hypergeometric function definition.
3037
* @param[in] a2 a2 see generalized hypergeometric function definition.
@@ -35,84 +42,90 @@ namespace math {
3542
* @param[in] precision precision of the infinite sum
3643
* @param[in] max_steps number of steps to take
3744
*/
38-
template <typename T>
39-
void grad_F32(T* g, const T& a1, const T& a2, const T& a3, const T& b1,
40-
const T& b2, const T& z, const T& precision = 1e-6,
45+
template <bool grad_a1 = true, bool grad_a2 = true, bool grad_a3 = true,
46+
bool grad_b1 = true, bool grad_b2 = true, bool grad_z = true,
47+
typename T1, typename T2, typename T3, typename T4, typename T5,
48+
typename T6, typename T7, typename T8 = double>
49+
void grad_F32(T1* g, const T2& a1, const T3& a2, const T4& a3, const T5& b1,
50+
const T6& b2, const T7& z, const T8& precision = 1e-6,
4151
int max_steps = 1e5) {
4252
check_3F2_converges("grad_F32", a1, a2, a3, b1, b2, z);
4353

44-
using std::exp;
45-
using std::fabs;
46-
using std::log;
47-
4854
for (int i = 0; i < 6; ++i) {
4955
g[i] = 0.0;
5056
}
5157

52-
T log_g_old[6];
58+
T1 log_g_old[6];
5359
for (auto& x : log_g_old) {
5460
x = NEGATIVE_INFTY;
5561
}
5662

57-
T log_t_old = 0.0;
58-
T log_t_new = 0.0;
63+
T1 log_t_old = 0.0;
64+
T1 log_t_new = 0.0;
5965

60-
T log_z = log(z);
66+
T7 log_z = log(z);
6167

62-
double log_t_new_sign = 1.0;
63-
double log_t_old_sign = 1.0;
64-
double log_g_old_sign[6];
68+
T1 log_t_new_sign = 1.0;
69+
T1 log_t_old_sign = 1.0;
70+
T1 log_g_old_sign[6];
6571
for (int i = 0; i < 6; ++i) {
6672
log_g_old_sign[i] = 1.0;
6773
}
68-
74+
std::array<T1, 6> term{0};
6975
for (int k = 0; k <= max_steps; ++k) {
70-
T p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
76+
T1 p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
7177
if (p == 0) {
7278
return;
7379
}
7480

7581
log_t_new += log(fabs(p)) + log_z;
7682
log_t_new_sign = p >= 0.0 ? log_t_new_sign : -log_t_new_sign;
83+
if constexpr (grad_a1) {
84+
term[0] = log_g_old_sign[0] * log_t_old_sign * exp(log_g_old[0] - log_t_old)
85+
+ inv(a1 + k);
86+
log_g_old[0] = log_t_new + log(fabs(term[0]));
87+
log_g_old_sign[0] = term[0] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
88+
g[0] += log_g_old_sign[0] * exp(log_g_old[0]);
89+
}
90+
91+
if constexpr (grad_a2) {
92+
term[1] = log_g_old_sign[1] * log_t_old_sign * exp(log_g_old[1] - log_t_old)
93+
+ inv(a2 + k);
94+
log_g_old[1] = log_t_new + log(fabs(term[1]));
95+
log_g_old_sign[1] = term[1] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
96+
g[1] += log_g_old_sign[1] * exp(log_g_old[1]);
97+
}
98+
99+
if constexpr (grad_a3) {
100+
term[2] = log_g_old_sign[2] * log_t_old_sign * exp(log_g_old[2] - log_t_old)
101+
+ inv(a3 + k);
102+
log_g_old[2] = log_t_new + log(fabs(term[2]));
103+
log_g_old_sign[2] = term[2] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
104+
g[2] += log_g_old_sign[2] * exp(log_g_old[2]);
105+
}
106+
107+
if constexpr (grad_b1) {
108+
term[3] = log_g_old_sign[3] * log_t_old_sign * exp(log_g_old[3] - log_t_old)
109+
- inv(b1 + k);
110+
log_g_old[3] = log_t_new + log(fabs(term[3]));
111+
log_g_old_sign[3] = term[3] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
112+
g[3] += log_g_old_sign[3] * exp(log_g_old[3]);
113+
}
114+
115+
if constexpr (grad_b2) {
116+
term[4] = log_g_old_sign[4] * log_t_old_sign * exp(log_g_old[4] - log_t_old)
117+
- inv(b2 + k);
118+
log_g_old[4] = log_t_new + log(fabs(term[4]));
119+
log_g_old_sign[4] = term[4] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
120+
g[4] += log_g_old_sign[4] * exp(log_g_old[4]);
121+
}
77122

78-
// g_old[0] = t_new * (g_old[0] / t_old + 1.0 / (a1 + k));
79-
T term = log_g_old_sign[0] * log_t_old_sign * exp(log_g_old[0] - log_t_old)
80-
+ inv(a1 + k);
81-
log_g_old[0] = log_t_new + log(fabs(term));
82-
log_g_old_sign[0] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
83-
84-
// g_old[1] = t_new * (g_old[1] / t_old + 1.0 / (a2 + k));
85-
term = log_g_old_sign[1] * log_t_old_sign * exp(log_g_old[1] - log_t_old)
86-
+ inv(a2 + k);
87-
log_g_old[1] = log_t_new + log(fabs(term));
88-
log_g_old_sign[1] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
89-
90-
// g_old[2] = t_new * (g_old[2] / t_old + 1.0 / (a3 + k));
91-
term = log_g_old_sign[2] * log_t_old_sign * exp(log_g_old[2] - log_t_old)
92-
+ inv(a3 + k);
93-
log_g_old[2] = log_t_new + log(fabs(term));
94-
log_g_old_sign[2] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
95-
96-
// g_old[3] = t_new * (g_old[3] / t_old - 1.0 / (b1 + k));
97-
term = log_g_old_sign[3] * log_t_old_sign * exp(log_g_old[3] - log_t_old)
98-
- inv(b1 + k);
99-
log_g_old[3] = log_t_new + log(fabs(term));
100-
log_g_old_sign[3] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
101-
102-
// g_old[4] = t_new * (g_old[4] / t_old - 1.0 / (b2 + k));
103-
term = log_g_old_sign[4] * log_t_old_sign * exp(log_g_old[4] - log_t_old)
104-
- inv(b2 + k);
105-
log_g_old[4] = log_t_new + log(fabs(term));
106-
log_g_old_sign[4] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
107-
108-
// g_old[5] = t_new * (g_old[5] / t_old + 1.0 / z);
109-
term = log_g_old_sign[5] * log_t_old_sign * exp(log_g_old[5] - log_t_old)
110-
+ inv(z);
111-
log_g_old[5] = log_t_new + log(fabs(term));
112-
log_g_old_sign[5] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
113-
114-
for (int i = 0; i < 6; ++i) {
115-
g[i] += log_g_old_sign[i] * exp(log_g_old[i]);
123+
if constexpr (grad_z) {
124+
term[5] = log_g_old_sign[5] * log_t_old_sign * exp(log_g_old[5] - log_t_old)
125+
+ inv(z);
126+
log_g_old[5] = log_t_new + log(fabs(term[5]));
127+
log_g_old_sign[5] = term[5] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
128+
g[5] += log_g_old_sign[5] * exp(log_g_old[5]);
116129
}
117130

118131
if (log_t_new <= log(precision)) {

stan/math/prim/fun/grad_pFq.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ template <bool calc_a = true, bool calc_b = true, bool calc_z = true,
8989
typename T_Rtn = return_type_t<Ta, Tb, Tz>,
9090
typename Ta_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Ta>>,
9191
typename Tb_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Tb>>>
92-
std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> grad_pFq(const TpFq& pfq_val, const Ta& a,
92+
inline std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> grad_pFq(const TpFq& pfq_val, const Ta& a,
9393
const Tb& b, const Tz& z,
9494
double precision = 1e-14,
9595
int max_steps = 1e6) {

stan/math/prim/fun/hypergeometric_3F2.hpp

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,33 @@ namespace stan {
1717
namespace math {
1818
namespace internal {
1919
template <typename Ta, typename Tb, typename Tz,
20-
typename T_return = return_type_t<Ta, Tb, Tz>,
21-
typename ArrayAT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
22-
typename ArrayBT = Eigen::Array<scalar_type_t<Ta>, 3, 1>,
2320
require_all_vector_t<Ta, Tb>* = nullptr,
2421
require_stan_scalar_t<Tz>* = nullptr>
25-
T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
22+
inline return_type_t<Ta, Tb, Tz> hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
2623
double precision = 1e-6,
2724
int max_steps = 1e5) {
28-
ArrayAT a_array = as_array_or_scalar(a);
29-
ArrayBT b_array = append_row(as_array_or_scalar(b), 1.0);
25+
using T_return = return_type_t<Ta, Tb, Tz>;
26+
Eigen::Array<scalar_type_t<Ta>, 3, 1> a_array = as_array_or_scalar(a);
27+
Eigen::Array<scalar_type_t<Tb>, 3, 1> b_array = append_row(as_array_or_scalar(b), 1.0);
3028
check_3F2_converges("hypergeometric_3F2", a_array[0], a_array[1], a_array[2],
3129
b_array[0], b_array[1], z);
3230

3331
T_return t_acc = 1.0;
3432
T_return log_t = 0.0;
35-
T_return log_z = log(fabs(z));
36-
Eigen::ArrayXi a_signs = sign(value_of_rec(a_array));
37-
Eigen::ArrayXi b_signs = sign(value_of_rec(b_array));
38-
plain_type_t<decltype(a_array)> apk = a_array;
39-
plain_type_t<decltype(b_array)> bpk = b_array;
33+
auto log_z = log(fabs(z));
34+
Eigen::Array<int, 3, 1> a_signs = sign(value_of_rec(a_array));
35+
Eigen::Array<int, 3, 1> b_signs = sign(value_of_rec(b_array));
4036
int z_sign = sign(value_of_rec(z));
4137
int t_sign = z_sign * a_signs.prod() * b_signs.prod();
4238

4339
int k = 0;
44-
while (k <= max_steps && log_t >= log(precision)) {
40+
const double log_precision = log(precision);
41+
while (k <= max_steps && log_t >= log_precision) {
4542
// Replace zero values with 1 prior to taking the log so that we accumulate
4643
// 0.0 rather than -inf
47-
const auto& abs_apk = math::fabs((apk == 0).select(1.0, apk));
48-
const auto& abs_bpk = math::fabs((bpk == 0).select(1.0, bpk));
49-
T_return p = sum(log(abs_apk)) - sum(log(abs_bpk));
44+
const auto& abs_apk = math::fabs((a_array == 0).select(1.0, a_array));
45+
const auto& abs_bpk = math::fabs((b_array == 0).select(1.0, b_array));
46+
auto p = sum(log(abs_apk)) - sum(log(abs_bpk));
5047
if (p == NEGATIVE_INFTY) {
5148
return t_acc;
5249
}
@@ -59,10 +56,10 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
5956
"overflow hypergeometric function did not converge.");
6057
}
6158
k++;
62-
apk.array() += 1.0;
63-
bpk.array() += 1.0;
64-
a_signs = sign(value_of_rec(apk));
65-
b_signs = sign(value_of_rec(bpk));
59+
a_array += 1.0;
60+
b_array += 1.0;
61+
a_signs = sign(value_of_rec(a_array));
62+
b_signs = sign(value_of_rec(b_array));
6663
t_sign = a_signs.prod() * b_signs.prod() * t_sign;
6764
}
6865
if (k == max_steps) {
@@ -115,7 +112,7 @@ T_return hypergeometric_3F2_infsum(const Ta& a, const Tb& b, const Tz& z,
115112
template <typename Ta, typename Tb, typename Tz,
116113
require_all_vector_t<Ta, Tb>* = nullptr,
117114
require_stan_scalar_t<Tz>* = nullptr>
118-
auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
115+
inline auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
119116
check_3F2_converges("hypergeometric_3F2", a[0], a[1], a[2], b[0], b[1], z);
120117
// Boost's pFq throws convergence errors in some cases, so fall back to the naive
121118
// infinite-sum approach (tests pass for these)
@@ -143,7 +140,7 @@ auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
143140
*/
144141
template <typename Ta, typename Tb, typename Tz,
145142
require_all_stan_scalar_t<Ta, Tb, Tz>* = nullptr>
146-
auto hypergeometric_3F2(const std::initializer_list<Ta>& a,
143+
inline auto hypergeometric_3F2(const std::initializer_list<Ta>& a,
147144
const std::initializer_list<Tb>& b, const Tz& z) {
148145
return hypergeometric_3F2(std::vector<Ta>(a), std::vector<Tb>(b), z);
149146
}

stan/math/prim/prob/beta_neg_binomial_lccdf.hpp

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,6 @@ template <typename T_n, typename T_r, typename T_alpha, typename T_beta>
4343
inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_lccdf(
4444
const T_n& n, const T_r& r, const T_alpha& alpha, const T_beta& beta,
4545
const double precision = 1e-8, const int max_steps = 1e6) {
46-
using std::exp;
47-
using std::log;
48-
using T_partials_return = partials_return_t<T_n, T_r, T_alpha, T_beta>;
49-
using T_r_ref = ref_type_t<T_r>;
50-
using T_alpha_ref = ref_type_t<T_alpha>;
51-
using T_beta_ref = ref_type_t<T_beta>;
5246
static constexpr const char* function = "beta_neg_binomial_lccdf";
5347
check_consistent_sizes(
5448
function, "Failures variable", n, "Number of successes parameter", r,
@@ -57,67 +51,70 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_lccdf(
5751
return 0;
5852
}
5953

54+
using T_r_ref = ref_type_t<T_r>;
6055
T_r_ref r_ref = r;
56+
using T_alpha_ref = ref_type_t<T_alpha>;
6157
T_alpha_ref alpha_ref = alpha;
58+
using T_beta_ref = ref_type_t<T_beta>;
6259
T_beta_ref beta_ref = beta;
6360
check_positive_finite(function, "Number of successes parameter", r_ref);
6461
check_positive_finite(function, "Prior success parameter", alpha_ref);
6562
check_positive_finite(function, "Prior failure parameter", beta_ref);
6663

67-
T_partials_return log_ccdf(0.0);
68-
auto ops_partials = make_partials_propagator(r_ref, alpha_ref, beta_ref);
6964

7065
scalar_seq_view<T_n> n_vec(n);
7166
scalar_seq_view<T_r_ref> r_vec(r_ref);
7267
scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref);
7368
scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
74-
size_t size_n = stan::math::size(n);
69+
int size_n = stan::math::size(n);
7570
size_t max_size_seq_view = max_size(n, r, alpha, beta);
7671

7772
// Explicit return for extreme values
7873
// The gradients are technically ill-defined, but treated as zero
79-
for (size_t i = 0; i < size_n; i++) {
74+
for (int i = 0; i < size_n; i++) {
8075
if (n_vec.val(i) < 0) {
81-
return ops_partials.build(0.0);
76+
return 0.0;
8277
}
8378
}
8479

80+
using T_partials_return = partials_return_t<T_n, T_r, T_alpha, T_beta>;
81+
T_partials_return log_ccdf(0.0);
82+
auto ops_partials = make_partials_propagator(r_ref, alpha_ref, beta_ref);
8583
for (size_t i = 0; i < max_size_seq_view; i++) {
8684
// Explicit return for extreme values
8785
// The gradients are technically ill-defined, but treated as zero
8886
if (n_vec.val(i) == std::numeric_limits<int>::max()) {
8987
return ops_partials.build(negative_infinity());
9088
}
91-
T_partials_return n_dbl = n_vec.val(i);
92-
T_partials_return r_dbl = r_vec.val(i);
93-
T_partials_return alpha_dbl = alpha_vec.val(i);
94-
T_partials_return beta_dbl = beta_vec.val(i);
95-
T_partials_return b_plus_n = beta_dbl + n_dbl;
96-
T_partials_return r_plus_n = r_dbl + n_dbl;
97-
T_partials_return a_plus_r = alpha_dbl + r_dbl;
98-
T_partials_return one = 1;
99-
T_partials_return precision_t
100-
= precision; // default -6, set -8 to pass all tests
101-
102-
T_partials_return F
103-
= hypergeometric_3F2({one, b_plus_n + 1, r_plus_n + 1},
104-
{n_dbl + 2, a_plus_r + b_plus_n + 1}, one);
105-
T_partials_return C = lgamma(r_plus_n + 1) + lbeta(a_plus_r, b_plus_n + 1)
89+
auto n_dbl = n_vec.val(i);
90+
auto r_dbl = r_vec.val(i);
91+
auto alpha_dbl = alpha_vec.val(i);
92+
auto beta_dbl = beta_vec.val(i);
93+
auto b_plus_n = beta_dbl + n_dbl;
94+
auto r_plus_n = r_dbl + n_dbl;
95+
auto a_plus_r = alpha_dbl + r_dbl;
96+
using a_t = return_type_t<decltype(b_plus_n), decltype(r_plus_n)>;
97+
using b_t = return_type_t<decltype(n_dbl), decltype(a_plus_r), decltype(b_plus_n)>;
98+
auto F
99+
= hypergeometric_3F2(
100+
std::initializer_list<a_t>{1.0, b_plus_n + 1.0, r_plus_n + 1.0},
101+
std::initializer_list<b_t>{n_dbl + 2.0, a_plus_r + b_plus_n + 1.0}, 1.0);
102+
auto C = lgamma(r_plus_n + 1.0) + lbeta(a_plus_r, b_plus_n + 1.0)
106103
- lgamma(r_dbl) - lbeta(alpha_dbl, beta_dbl)
107104
- lgamma(n_dbl + 2);
108-
T_partials_return ccdf = exp(C) * F;
109-
T_partials_return log_ccdf_i = log(ccdf);
110-
log_ccdf += log_ccdf_i;
105+
log_ccdf += C + stan::math::log(F);
111106

112107
if constexpr (!is_constant_all<T_r, T_alpha, T_beta>::value) {
113-
T_partials_return digamma_n_r_alpha_beta
114-
= digamma(a_plus_r + b_plus_n + 1);
108+
auto digamma_n_r_alpha_beta
109+
= digamma(a_plus_r + b_plus_n + 1.0);
115110
T_partials_return dF[6];
116-
grad_F32(dF, one, b_plus_n + 1, r_plus_n + 1, n_dbl + 2,
117-
a_plus_r + b_plus_n + 1, one, precision_t, max_steps);
111+
grad_F32<false, !is_constant<T_beta>::value,
112+
!is_constant_all<T_r>::value, false, true, false>(dF, 1.0,
113+
b_plus_n + 1.0, r_plus_n + 1.0, n_dbl + 2.0,
114+
a_plus_r + b_plus_n + 1.0, 1.0, precision, max_steps);
118115

119116
if constexpr (!is_constant<T_r>::value || !is_constant<T_alpha>::value) {
120-
T_partials_return digamma_r_alpha = digamma(a_plus_r);
117+
auto digamma_r_alpha = digamma(a_plus_r);
121118
if constexpr (!is_constant_all<T_r>::value) {
122119
partials<0>(ops_partials)[i]
123120
+= digamma(r_plus_n + 1)
@@ -133,7 +130,7 @@ inline return_type_t<T_r, T_alpha, T_beta> beta_neg_binomial_lccdf(
133130

134131
if constexpr (!is_constant<T_alpha>::value
135132
|| !is_constant<T_beta>::value) {
136-
T_partials_return digamma_alpha_beta = digamma(alpha_dbl + beta_dbl);
133+
auto digamma_alpha_beta = digamma(alpha_dbl + beta_dbl);
137134
if constexpr (!is_constant<T_alpha>::value) {
138135
partials<1>(ops_partials)[i] += digamma_alpha_beta;
139136
}

0 commit comments

Comments
 (0)