@@ -24,7 +24,20 @@ namespace math {
2424 * This power-series representation converges for all gradients
2525 * under the same conditions as the 3F2 function itself.
2626 *
27- * @tparam T type of arguments and result
27+ * @tparam grad_a1 boolean indicating if gradient with respect to a1 is required
28+ * @tparam grad_a2 boolean indicating if gradient with respect to a2 is required
29+ * @tparam grad_a3 boolean indicating if gradient with respect to a3 is required
30+ * @tparam grad_b1 boolean indicating if gradient with respect to b1 is required
31+ * @tparam grad_b2 boolean indicating if gradient with respect to b2 is required
32+ * @tparam grad_z boolean indicating if gradient with respect to z is required
33+ * @tparam T1 a scalar type
34+ * @tparam T2 a scalar type
35+ * @tparam T3 a scalar type
36+ * @tparam T4 a scalar type
37+ * @tparam T5 a scalar type
38+ * @tparam T6 a scalar type
39+ * @tparam T7 a scalar type
40+ * @tparam T8 a scalar type
2841 * @param[out] g g pointer to array of six values of type T1, result.
2942 * @param[in] a1 a1 see generalized hypergeometric function definition.
3043 * @param[in] a2 a2 see generalized hypergeometric function definition.
@@ -35,84 +48,96 @@ namespace math {
3548 * @param[in] precision precision of the infinite sum
3649 * @param[in] max_steps number of steps to take
3750 */
38- template <typename T>
39- void grad_F32 (T* g, const T& a1, const T& a2, const T& a3, const T& b1,
40- const T& b2, const T& z, const T& precision = 1e-6 ,
51+ template <bool grad_a1 = true , bool grad_a2 = true , bool grad_a3 = true ,
52+ bool grad_b1 = true , bool grad_b2 = true , bool grad_z = true ,
53+ typename T1, typename T2, typename T3, typename T4, typename T5,
54+ typename T6, typename T7, typename T8 = double >
55+ void grad_F32 (T1* g, const T2& a1, const T3& a2, const T4& a3, const T5& b1,
56+ const T6& b2, const T7& z, const T8& precision = 1e-6 ,
4157 int max_steps = 1e5 ) {
4258 check_3F2_converges (" grad_F32" , a1, a2, a3, b1, b2, z);
4359
44- using std::exp;
45- using std::fabs;
46- using std::log;
47-
4860 for (int i = 0 ; i < 6 ; ++i) {
4961 g[i] = 0.0 ;
5062 }
5163
52- T log_g_old[6 ];
64+ T1 log_g_old[6 ];
5365 for (auto & x : log_g_old) {
5466 x = NEGATIVE_INFTY;
5567 }
5668
57- T log_t_old = 0.0 ;
58- T log_t_new = 0.0 ;
69+ T1 log_t_old = 0.0 ;
70+ T1 log_t_new = 0.0 ;
5971
60- T log_z = log (z);
72+ T7 log_z = log (z);
6173
62- double log_t_new_sign = 1.0 ;
63- double log_t_old_sign = 1.0 ;
64- double log_g_old_sign[6 ];
74+ T1 log_t_new_sign = 1.0 ;
75+ T1 log_t_old_sign = 1.0 ;
76+ T1 log_g_old_sign[6 ];
6577 for (int i = 0 ; i < 6 ; ++i) {
6678 log_g_old_sign[i] = 1.0 ;
6779 }
68-
80+ std::array<T1, 6 > term{ 0 };
6981 for (int k = 0 ; k <= max_steps; ++k) {
70- T p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
82+ T1 p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
7183 if (p == 0 ) {
7284 return ;
7385 }
7486
7587 log_t_new += log (fabs (p)) + log_z;
7688 log_t_new_sign = p >= 0.0 ? log_t_new_sign : -log_t_new_sign;
89+ if constexpr (grad_a1) {
90+ term[0 ]
91+ = log_g_old_sign[0 ] * log_t_old_sign * exp (log_g_old[0 ] - log_t_old)
92+ + inv (a1 + k);
93+ log_g_old[0 ] = log_t_new + log (fabs (term[0 ]));
94+ log_g_old_sign[0 ] = term[0 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
95+ g[0 ] += log_g_old_sign[0 ] * exp (log_g_old[0 ]);
96+ }
97+
98+ if constexpr (grad_a2) {
99+ term[1 ]
100+ = log_g_old_sign[1 ] * log_t_old_sign * exp (log_g_old[1 ] - log_t_old)
101+ + inv (a2 + k);
102+ log_g_old[1 ] = log_t_new + log (fabs (term[1 ]));
103+ log_g_old_sign[1 ] = term[1 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
104+ g[1 ] += log_g_old_sign[1 ] * exp (log_g_old[1 ]);
105+ }
106+
107+ if constexpr (grad_a3) {
108+ term[2 ]
109+ = log_g_old_sign[2 ] * log_t_old_sign * exp (log_g_old[2 ] - log_t_old)
110+ + inv (a3 + k);
111+ log_g_old[2 ] = log_t_new + log (fabs (term[2 ]));
112+ log_g_old_sign[2 ] = term[2 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
113+ g[2 ] += log_g_old_sign[2 ] * exp (log_g_old[2 ]);
114+ }
115+
116+ if constexpr (grad_b1) {
117+ term[3 ]
118+ = log_g_old_sign[3 ] * log_t_old_sign * exp (log_g_old[3 ] - log_t_old)
119+ - inv (b1 + k);
120+ log_g_old[3 ] = log_t_new + log (fabs (term[3 ]));
121+ log_g_old_sign[3 ] = term[3 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
122+ g[3 ] += log_g_old_sign[3 ] * exp (log_g_old[3 ]);
123+ }
124+
125+ if constexpr (grad_b2) {
126+ term[4 ]
127+ = log_g_old_sign[4 ] * log_t_old_sign * exp (log_g_old[4 ] - log_t_old)
128+ - inv (b2 + k);
129+ log_g_old[4 ] = log_t_new + log (fabs (term[4 ]));
130+ log_g_old_sign[4 ] = term[4 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
131+ g[4 ] += log_g_old_sign[4 ] * exp (log_g_old[4 ]);
132+ }
77133
78- // g_old[0] = t_new * (g_old[0] / t_old + 1.0 / (a1 + k));
79- T term = log_g_old_sign[0 ] * log_t_old_sign * exp (log_g_old[0 ] - log_t_old)
80- + inv (a1 + k);
81- log_g_old[0 ] = log_t_new + log (fabs (term));
82- log_g_old_sign[0 ] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
83-
84- // g_old[1] = t_new * (g_old[1] / t_old + 1.0 / (a2 + k));
85- term = log_g_old_sign[1 ] * log_t_old_sign * exp (log_g_old[1 ] - log_t_old)
86- + inv (a2 + k);
87- log_g_old[1 ] = log_t_new + log (fabs (term));
88- log_g_old_sign[1 ] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
89-
90- // g_old[2] = t_new * (g_old[2] / t_old + 1.0 / (a3 + k));
91- term = log_g_old_sign[2 ] * log_t_old_sign * exp (log_g_old[2 ] - log_t_old)
92- + inv (a3 + k);
93- log_g_old[2 ] = log_t_new + log (fabs (term));
94- log_g_old_sign[2 ] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
95-
96- // g_old[3] = t_new * (g_old[3] / t_old - 1.0 / (b1 + k));
97- term = log_g_old_sign[3 ] * log_t_old_sign * exp (log_g_old[3 ] - log_t_old)
98- - inv (b1 + k);
99- log_g_old[3 ] = log_t_new + log (fabs (term));
100- log_g_old_sign[3 ] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
101-
102- // g_old[4] = t_new * (g_old[4] / t_old - 1.0 / (b2 + k));
103- term = log_g_old_sign[4 ] * log_t_old_sign * exp (log_g_old[4 ] - log_t_old)
104- - inv (b2 + k);
105- log_g_old[4 ] = log_t_new + log (fabs (term));
106- log_g_old_sign[4 ] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
107-
108- // g_old[5] = t_new * (g_old[5] / t_old + 1.0 / z);
109- term = log_g_old_sign[5 ] * log_t_old_sign * exp (log_g_old[5 ] - log_t_old)
110- + inv (z);
111- log_g_old[5 ] = log_t_new + log (fabs (term));
112- log_g_old_sign[5 ] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
113-
114- for (int i = 0 ; i < 6 ; ++i) {
115- g[i] += log_g_old_sign[i] * exp (log_g_old[i]);
134+ if constexpr (grad_z) {
135+ term[5 ]
136+ = log_g_old_sign[5 ] * log_t_old_sign * exp (log_g_old[5 ] - log_t_old)
137+ + inv (z);
138+ log_g_old[5 ] = log_t_new + log (fabs (term[5 ]));
139+ log_g_old_sign[5 ] = term[5 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
140+ g[5 ] += log_g_old_sign[5 ] * exp (log_g_old[5 ]);
116141 }
117142
118143 if (log_t_new <= log (precision)) {
0 commit comments