77#include < stan/math/prim/fun/dot_product.hpp>
88#include < stan/math/prim/fun/grad_pFq.hpp>
99#include < stan/math/prim/fun/hypergeometric_pFq.hpp>
10+ #include < stan/math/prim/fun/as_column_vector_or_scalar.hpp>
11+ #include < stan/math/prim/fun/to_ref.hpp>
1012
1113namespace stan {
1214namespace math {
@@ -30,33 +32,27 @@ template <typename Ta, typename Tb, typename Tz,
3032 bool grad_z = !is_constant<Tz>::value,
3133 require_all_vector_t <Ta, Tb>* = nullptr ,
3234 require_fvar_t <FvarT>* = nullptr >
33- inline FvarT hypergeometric_pFq (const Ta& a, const Tb& b, const Tz& z) {
34- using PartialsT = partials_type_t <FvarT>;
35- using ARefT = ref_type_t <Ta>;
36- using BRefT = ref_type_t <Tb>;
37-
38- ARefT a_ref = a;
39- BRefT b_ref = b;
35+ inline FvarT hypergeometric_pFq (Ta&& a, Tb&& b, Tz&& z) {
36+ auto && a_ref = to_ref (as_column_vector_or_scalar (a));
37+ auto && b_ref = to_ref (as_column_vector_or_scalar (b));
4038 auto && a_val = value_of (a_ref);
4139 auto && b_val = value_of (b_ref);
4240 auto && z_val = value_of (z);
43- PartialsT pfq_val = hypergeometric_pFq (a_val, b_val, z_val);
41+
42+ partials_type_t <FvarT> pfq_val = hypergeometric_pFq (a_val, b_val, z_val);
4443 auto grad_tuple
4544 = grad_pFq<grad_a, grad_b, grad_z>(pfq_val, a_val, b_val, z_val);
4645
4746 FvarT rtn = FvarT (pfq_val, 0.0 );
4847
49- if (grad_a) {
50- rtn.d_ += dot_product (forward_as<promote_scalar_t <FvarT, ARefT>>(a_ref).d (),
51- std::get<0 >(grad_tuple));
48+ if constexpr (grad_a) {
49+ rtn.d_ += dot_product (a_ref.d (), std::get<0 >(grad_tuple));
5250 }
53- if (grad_b) {
54- rtn.d_ += dot_product (forward_as<promote_scalar_t <FvarT, BRefT>>(b_ref).d (),
55- std::get<1 >(grad_tuple));
51+ if constexpr (grad_b) {
52+ rtn.d_ += dot_product (b_ref.d (), std::get<1 >(grad_tuple));
5653 }
57- if (grad_z) {
58- rtn.d_ += forward_as<promote_scalar_t <FvarT, Tz>>(z).d_
59- * std::get<2 >(grad_tuple);
54+ if constexpr (grad_z) {
55+ rtn.d_ += z.d_ * std::get<2 >(grad_tuple);
6056 }
6157
6258 return rtn;
0 commit comments