33
44#include < stan/math/rev/core.hpp>
55#include < stan/math/rev/meta.hpp>
6+ #include < stan/math/prim/fun/as_column_vector_or_scalar.hpp>
67#include < stan/math/prim/fun/grad_pFq.hpp>
78#include < stan/math/prim/fun/hypergeometric_pFq.hpp>
89
@@ -25,27 +26,24 @@ template <typename Ta, typename Tb, typename Tz,
2526 bool grad_a = !is_constant<Ta>::value,
2627 bool grad_b = !is_constant<Tb>::value,
2728 bool grad_z = !is_constant<Tz>::value,
28- require_all_matrix_t <Ta, Tb>* = nullptr ,
29+ require_all_vector_t <Ta, Tb>* = nullptr ,
2930 require_return_type_t <is_var, Ta, Tb, Tz>* = nullptr >
30- inline var hypergeometric_pFq (const Ta& a, const Tb& b, const Tz & z) {
31- arena_t <Ta> arena_a = a ;
32- arena_t <Tb> arena_b = b ;
33- auto pfq_val = hypergeometric_pFq (a .val (), b .val (), value_of (z));
31+ inline var hypergeometric_pFq (Ta&& a, Tb&& b, Tz& & z) {
32+ auto && arena_a = to_arena ( as_column_vector_or_scalar (std::forward<Ta>(a))) ;
33+ auto && arena_b = to_arena ( as_column_vector_or_scalar (std::forward<Tb>(b))) ;
34+ auto pfq_val = hypergeometric_pFq (arena_a .val (), arena_b .val (), value_of (z));
3435 return make_callback_var (
3536 pfq_val, [arena_a, arena_b, z, pfq_val](auto & vi) mutable {
3637 auto grad_tuple = grad_pFq<grad_a, grad_b, grad_z>(
3738 pfq_val, arena_a.val (), arena_b.val (), value_of (z));
3839 if constexpr (grad_a) {
39- forward_as<promote_scalar_t <var, Ta>>(arena_a).adj ()
40- += vi.adj () * std::get<0 >(grad_tuple);
40+ arena_a.adj () += vi.adj () * std::get<0 >(grad_tuple);
4141 }
4242 if constexpr (grad_b) {
43- forward_as<promote_scalar_t <var, Tb>>(arena_b).adj ()
44- += vi.adj () * std::get<1 >(grad_tuple);
43+ arena_b.adj () += vi.adj () * std::get<1 >(grad_tuple);
4544 }
4645 if constexpr (grad_z) {
47- forward_as<promote_scalar_t <var, Tz>>(z).adj ()
48- += vi.adj () * std::get<2 >(grad_tuple);
46+ z.adj () += vi.adj () * std::get<2 >(grad_tuple);
4947 }
5048 });
5149}
0 commit comments