Skip to content

Commit 153193c

Browse files
committed
update bessel so that reverse mode is chosen over the perfect forward specialization
1 parent acd3426 commit 153193c

3 files changed

Lines changed: 24 additions & 26 deletions

File tree

stan/math/prim/fun/bessel_second_kind.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ inline T2 bessel_second_kind(int v, const T2 z) {
5252
* @param b Second input
5353
* @return Bessel second kind function applied to the two inputs.
5454
*/
55-
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
55+
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
56+
require_not_var_matrix_t<T2>* = nullptr>
5657
inline auto bessel_second_kind(T1&& a, T2&& b) {
5758
return apply_scalar_binary(
5859
[](auto&& c, auto&& d) {

stan/math/rev/fun/bessel_second_kind.hpp

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,26 @@
88
namespace stan {
99
namespace math {
1010

11-
inline var bessel_second_kind(int v, const var& a) {
12-
double ret_val = bessel_second_kind(v, a.val());
13-
auto precomp_bessel
14-
= v * ret_val / a.val() - bessel_second_kind(v + 1, a.val());
15-
return make_callback_var(ret_val, [precomp_bessel, a](auto& vi) mutable {
16-
a.adj() += vi.adj() * precomp_bessel;
17-
});
18-
}
19-
20-
/**
21-
* Overload with `var_value<Matrix>` for `int`, `std::vector<int>`, and
22-
* `std::vector<std::vector<int>>`
23-
*/
24-
template <typename T1, typename T2, require_st_integral<T1>* = nullptr,
25-
require_eigen_t<T2>* = nullptr>
26-
inline auto bessel_second_kind(const T1& v, const var_value<T2>& a) {
27-
auto ret_val = bessel_second_kind(v, a.val()).array().eval();
28-
auto v_map = as_array_or_scalar(v);
29-
auto precomp_bessel
30-
= to_arena(v_map * ret_val / a.val().array()
31-
- bessel_second_kind(v_map + 1, a.val().array()));
32-
return make_callback_var(
33-
ret_val.matrix(), [precomp_bessel, a](const auto& vi) mutable {
34-
a.adj().array() += vi.adj().array() * precomp_bessel;
35-
});
11+
template <typename T1, typename T2, require_st_integral<T1>* = nullptr, require_var_t<T2>* = nullptr>
12+
inline auto bessel_second_kind(T1&& v, T2&& a) {
13+
if constexpr (is_stan_scalar_v<T2>) {
14+
double ret_val = bessel_second_kind(v, a.val());
15+
auto precomp_bessel
16+
= v * ret_val / a.val() - bessel_second_kind(v + 1, a.val());
17+
return make_callback_var(ret_val, [precomp_bessel, a](auto& vi) mutable {
18+
a.adj() += vi.adj() * precomp_bessel;
19+
});
20+
} else {
21+
auto ret_val = bessel_second_kind(v, a.val()).array().eval();
22+
auto v_map = as_array_or_scalar(v);
23+
auto precomp_bessel
24+
= to_arena(v_map * ret_val / a.val().array()
25+
- bessel_second_kind(v_map + 1, a.val().array()));
26+
return make_callback_var(
27+
ret_val.matrix(), [precomp_bessel, a](const auto& vi) mutable {
28+
a.adj().array() += vi.adj().array() * precomp_bessel;
29+
});
30+
}
3631
}
3732

3833
} // namespace math

test/unit/math/mix/fun/bessel_second_kind_test.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@ TEST(mathMixScalFun, besselSecondKind_vec) {
2525
Eigen::VectorXd in2(2);
2626
in2 << 0.5, 3.4;
2727
stan::test::expect_ad_vectorized_binary(f, std_in1, in2);
28+
stan::test::expect_ad_matvar(f, std_in1, in2);
2829

2930
std::vector<std::vector<int>> std_std_in1{std_in1, std_in1};
3031
Eigen::MatrixXd mat_in2 = in2.replicate(1, 2);
3132
stan::test::expect_ad_vectorized_binary(f, std_std_in1, mat_in2);
33+
stan::test::expect_ad_matvar(f, std_std_in1, mat_in2);
3234
}
3335

3436
TEST(mathMixScalFun, besselSecondKind_matvec) {

0 commit comments

Comments
 (0)