@@ -43,9 +43,9 @@ namespace math {
4343 * @param[in] max_steps number of steps to take
4444 */
4545template <bool grad_a1 = true , bool grad_a2 = true , bool grad_a3 = true ,
46- bool grad_b1 = true , bool grad_b2 = true , bool grad_z = true ,
47- typename T1, typename T2, typename T3, typename T4, typename T5,
48- typename T6, typename T7, typename T8 = double >
46+ bool grad_b1 = true , bool grad_b2 = true , bool grad_z = true ,
47+ typename T1, typename T2, typename T3, typename T4, typename T5,
48+ typename T6, typename T7, typename T8 = double >
4949void grad_F32 (T1* g, const T2& a1, const T3& a2, const T4& a3, const T5& b1,
5050 const T6& b2, const T7& z, const T8& precision = 1e-6 ,
5151 int max_steps = 1e5 ) {
@@ -81,47 +81,53 @@ void grad_F32(T1* g, const T2& a1, const T3& a2, const T4& a3, const T5& b1,
8181 log_t_new += log (fabs (p)) + log_z;
8282 log_t_new_sign = p >= 0.0 ? log_t_new_sign : -log_t_new_sign;
8383 if constexpr (grad_a1) {
84- term[0 ] = log_g_old_sign[0 ] * log_t_old_sign * exp (log_g_old[0 ] - log_t_old)
85- + inv (a1 + k);
84+ term[0 ]
85+ = log_g_old_sign[0 ] * log_t_old_sign * exp (log_g_old[0 ] - log_t_old)
86+ + inv (a1 + k);
8687 log_g_old[0 ] = log_t_new + log (fabs (term[0 ]));
8788 log_g_old_sign[0 ] = term[0 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
8889 g[0 ] += log_g_old_sign[0 ] * exp (log_g_old[0 ]);
8990 }
9091
9192 if constexpr (grad_a2) {
92- term[1 ] = log_g_old_sign[1 ] * log_t_old_sign * exp (log_g_old[1 ] - log_t_old)
93+ term[1 ]
94+ = log_g_old_sign[1 ] * log_t_old_sign * exp (log_g_old[1 ] - log_t_old)
9395 + inv (a2 + k);
9496 log_g_old[1 ] = log_t_new + log (fabs (term[1 ]));
9597 log_g_old_sign[1 ] = term[1 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
9698 g[1 ] += log_g_old_sign[1 ] * exp (log_g_old[1 ]);
9799 }
98100
99101 if constexpr (grad_a3) {
100- term[2 ] = log_g_old_sign[2 ] * log_t_old_sign * exp (log_g_old[2 ] - log_t_old)
102+ term[2 ]
103+ = log_g_old_sign[2 ] * log_t_old_sign * exp (log_g_old[2 ] - log_t_old)
101104 + inv (a3 + k);
102105 log_g_old[2 ] = log_t_new + log (fabs (term[2 ]));
103106 log_g_old_sign[2 ] = term[2 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
104107 g[2 ] += log_g_old_sign[2 ] * exp (log_g_old[2 ]);
105108 }
106109
107110 if constexpr (grad_b1) {
108- term[3 ] = log_g_old_sign[3 ] * log_t_old_sign * exp (log_g_old[3 ] - log_t_old)
111+ term[3 ]
112+ = log_g_old_sign[3 ] * log_t_old_sign * exp (log_g_old[3 ] - log_t_old)
109113 - inv (b1 + k);
110114 log_g_old[3 ] = log_t_new + log (fabs (term[3 ]));
111115 log_g_old_sign[3 ] = term[3 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
112116 g[3 ] += log_g_old_sign[3 ] * exp (log_g_old[3 ]);
113117 }
114118
115119 if constexpr (grad_b2) {
116- term[4 ] = log_g_old_sign[4 ] * log_t_old_sign * exp (log_g_old[4 ] - log_t_old)
120+ term[4 ]
121+ = log_g_old_sign[4 ] * log_t_old_sign * exp (log_g_old[4 ] - log_t_old)
117122 - inv (b2 + k);
118123 log_g_old[4 ] = log_t_new + log (fabs (term[4 ]));
119124 log_g_old_sign[4 ] = term[4 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
120125 g[4 ] += log_g_old_sign[4 ] * exp (log_g_old[4 ]);
121126 }
122127
123128 if constexpr (grad_z) {
124- term[5 ] = log_g_old_sign[5 ] * log_t_old_sign * exp (log_g_old[5 ] - log_t_old)
129+ term[5 ]
130+ = log_g_old_sign[5 ] * log_t_old_sign * exp (log_g_old[5 ] - log_t_old)
125131 + inv (z);
126132 log_g_old[5 ] = log_t_new + log (fabs (term[5 ]));
127133 log_g_old_sign[5 ] = term[5 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
0 commit comments