Skip to content

Commit 87d8cd9

Browse files
committed
use pf on signatures for reverse mode bessel_second_kind
1 parent 58a6d4e commit 87d8cd9

3 files changed

Lines changed: 36 additions & 21 deletions

File tree

stan/math/prim/fun/bessel_second_kind.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ inline T2 bessel_second_kind(int v, const T2 z) {
5353
* @return Bessel second kind function applied to the two inputs.
5454
*/
5555
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
56-
require_not_var_t<T2>* = nullptr>
56+
require_not_var_matrix_t<T2>* = nullptr>
5757
inline auto bessel_second_kind(T1&& a, T2&& b) {
5858
return apply_scalar_binary(
5959
[](auto&& c, auto&& d) {

stan/math/rev/fun/bessel_second_kind.hpp

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,34 @@
88
namespace stan {
99
namespace math {
1010

11+
template <typename T1, typename T2, require_integral_t<T1>* = nullptr,
12+
require_var_t<T2>* = nullptr,
13+
require_stan_scalar_t<T2>* = nullptr>
14+
inline var bessel_second_kind(T1&& v, T2&& a) {
15+
double ret_val = bessel_second_kind(v, a.val());
16+
auto precomp_bessel
17+
= v * ret_val / a.val() - bessel_second_kind(v + 1, a.val());
18+
return make_callback_var(ret_val, [precomp_bessel, a](auto& vi) mutable {
19+
a.adj() += vi.adj() * precomp_bessel;
20+
});
21+
}
22+
23+
/**
24+
* Overload with `var_value<Matrix>` for `int`, `std::vector<int>`, and
25+
* `std::vector<std::vector<int>>`
26+
*/
1127
template <typename T1, typename T2, require_st_integral<T1>* = nullptr,
12-
require_var_t<T2>* = nullptr>
28+
require_var_matrix_t<T2>* = nullptr>
1329
inline auto bessel_second_kind(T1&& v, T2&& a) {
14-
if constexpr (is_stan_scalar_v<T2>) {
15-
double ret_val = bessel_second_kind(v, a.val());
16-
auto precomp_bessel
17-
= v * ret_val / a.val() - bessel_second_kind(v + 1, a.val());
18-
return make_callback_var(ret_val, [precomp_bessel, a](auto& vi) mutable {
19-
a.adj() += vi.adj() * precomp_bessel;
20-
});
21-
} else {
22-
auto ret_val = bessel_second_kind(v, a.val()).array().eval();
23-
auto v_map = as_array_or_scalar(v);
24-
auto precomp_bessel
25-
= to_arena(v_map * ret_val / a.val().array()
26-
- bessel_second_kind(v_map + 1, a.val().array()));
27-
return make_callback_var(
28-
ret_val.matrix(), [precomp_bessel, a](const auto& vi) mutable {
29-
a.adj().array() += vi.adj().array() * precomp_bessel;
30-
});
31-
}
30+
auto ret_val = bessel_second_kind(v, a.val()).array().eval();
31+
auto v_map = as_array_or_scalar(v);
32+
auto precomp_bessel
33+
= to_arena(v_map * ret_val / a.val().array()
34+
- bessel_second_kind(v_map + 1, a.val().array()));
35+
return make_callback_var(
36+
ret_val.matrix(), [precomp_bessel, a](const auto& vi) mutable {
37+
a.adj().array() += vi.adj().array() * precomp_bessel;
38+
});
3239
}
3340

3441
} // namespace math

test/unit/math/mix/fun/bessel_second_kind_test.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
TEST(mathMixScalFun, besselSecondKind) {
55
// bind integer arg because can't autodiff through
6-
auto f = [](const int x1) {
6+
auto f = [](const auto& x1) {
77
return
88
[=](const auto& x2) { return stan::math::bessel_second_kind(x1, x2); };
99
};
@@ -13,6 +13,13 @@ TEST(mathMixScalFun, besselSecondKind) {
1313
stan::test::expect_ad(f(1), 3.0);
1414
stan::test::expect_ad(f(1), std::numeric_limits<double>::quiet_NaN());
1515
stan::test::expect_ad(f(2), 2.79);
16+
std::vector<int> std_in1{3, 1};
17+
stan::test::expect_ad(f(std_in1), 4.0);
18+
stan::test::expect_ad(f(std_in1), std::numeric_limits<double>::quiet_NaN());
19+
stan::test::expect_ad(f(std_in1), -3.0);
20+
stan::test::expect_ad(f(std_in1), 3.0);
21+
stan::test::expect_ad(f(std_in1), std::numeric_limits<double>::quiet_NaN());
22+
stan::test::expect_ad(f(std_in1), 2.79);
1623
}
1724

1825
TEST(mathMixScalFun, besselSecondKind_vec) {
@@ -31,6 +38,7 @@ TEST(mathMixScalFun, besselSecondKind_vec) {
3138
Eigen::MatrixXd mat_in2 = in2.replicate(1, 2);
3239
stan::test::expect_ad_vectorized_binary(f, std_std_in1, mat_in2);
3340
stan::test::expect_ad_matvar(f, std_std_in1, mat_in2);
41+
3442
}
3543

3644
TEST(mathMixScalFun, besselSecondKind_matvec) {

0 commit comments

Comments
 (0)