@@ -49,9 +49,9 @@ namespace math {
4949 * @param[in] max_steps number of steps to take
5050 */
5151template <bool grad_a1 = true , bool grad_a2 = true , bool grad_a3 = true ,
52- bool grad_b1 = true , bool grad_b2 = true , bool grad_z = true ,
53- typename T1, typename T2, typename T3, typename T4, typename T5,
54- typename T6, typename T7, typename T8 = double >
52+ bool grad_b1 = true , bool grad_b2 = true , bool grad_z = true ,
53+ typename T1, typename T2, typename T3, typename T4, typename T5,
54+ typename T6, typename T7, typename T8 = double >
5555void grad_F32 (T1* g, const T2& a1, const T3& a2, const T4& a3, const T5& b1,
5656 const T6& b2, const T7& z, const T8& precision = 1e-6 ,
5757 int max_steps = 1e5 ) {
@@ -87,47 +87,53 @@ void grad_F32(T1* g, const T2& a1, const T3& a2, const T4& a3, const T5& b1,
8787 log_t_new += log (fabs (p)) + log_z;
8888 log_t_new_sign = p >= 0.0 ? log_t_new_sign : -log_t_new_sign;
8989 if constexpr (grad_a1) {
90- term[0 ] = log_g_old_sign[0 ] * log_t_old_sign * exp (log_g_old[0 ] - log_t_old)
91- + inv (a1 + k);
90+ term[0 ]
91+ = log_g_old_sign[0 ] * log_t_old_sign * exp (log_g_old[0 ] - log_t_old)
92+ + inv (a1 + k);
9293 log_g_old[0 ] = log_t_new + log (fabs (term[0 ]));
9394 log_g_old_sign[0 ] = term[0 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
9495 g[0 ] += log_g_old_sign[0 ] * exp (log_g_old[0 ]);
9596 }
9697
9798 if constexpr (grad_a2) {
98- term[1 ] = log_g_old_sign[1 ] * log_t_old_sign * exp (log_g_old[1 ] - log_t_old)
99+ term[1 ]
100+ = log_g_old_sign[1 ] * log_t_old_sign * exp (log_g_old[1 ] - log_t_old)
99101 + inv (a2 + k);
100102 log_g_old[1 ] = log_t_new + log (fabs (term[1 ]));
101103 log_g_old_sign[1 ] = term[1 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
102104 g[1 ] += log_g_old_sign[1 ] * exp (log_g_old[1 ]);
103105 }
104106
105107 if constexpr (grad_a3) {
106- term[2 ] = log_g_old_sign[2 ] * log_t_old_sign * exp (log_g_old[2 ] - log_t_old)
108+ term[2 ]
109+ = log_g_old_sign[2 ] * log_t_old_sign * exp (log_g_old[2 ] - log_t_old)
107110 + inv (a3 + k);
108111 log_g_old[2 ] = log_t_new + log (fabs (term[2 ]));
109112 log_g_old_sign[2 ] = term[2 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
110113 g[2 ] += log_g_old_sign[2 ] * exp (log_g_old[2 ]);
111114 }
112115
113116 if constexpr (grad_b1) {
114- term[3 ] = log_g_old_sign[3 ] * log_t_old_sign * exp (log_g_old[3 ] - log_t_old)
117+ term[3 ]
118+ = log_g_old_sign[3 ] * log_t_old_sign * exp (log_g_old[3 ] - log_t_old)
115119 - inv (b1 + k);
116120 log_g_old[3 ] = log_t_new + log (fabs (term[3 ]));
117121 log_g_old_sign[3 ] = term[3 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
118122 g[3 ] += log_g_old_sign[3 ] * exp (log_g_old[3 ]);
119123 }
120124
121125 if constexpr (grad_b2) {
122- term[4 ] = log_g_old_sign[4 ] * log_t_old_sign * exp (log_g_old[4 ] - log_t_old)
126+ term[4 ]
127+ = log_g_old_sign[4 ] * log_t_old_sign * exp (log_g_old[4 ] - log_t_old)
123128 - inv (b2 + k);
124129 log_g_old[4 ] = log_t_new + log (fabs (term[4 ]));
125130 log_g_old_sign[4 ] = term[4 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
126131 g[4 ] += log_g_old_sign[4 ] * exp (log_g_old[4 ]);
127132 }
128133
129134 if constexpr (grad_z) {
130- term[5 ] = log_g_old_sign[5 ] * log_t_old_sign * exp (log_g_old[5 ] - log_t_old)
135+ term[5 ]
136+ = log_g_old_sign[5 ] * log_t_old_sign * exp (log_g_old[5 ] - log_t_old)
131137 + inv (z);
132138 log_g_old[5 ] = log_t_new + log (fabs (term[5 ]));
133139 log_g_old_sign[5 ] = term[5 ] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
0 commit comments