@@ -137,12 +137,13 @@ inline auto wiener4_distribution(const T_y& y, const T_a& a, const T_v& v,
137137 using ret_t = return_type_t <T_y, T_a, T_w, T_v>;
138138 const auto neg_v = -v;
139139 const auto one_m_w = 1.0 - w;
140-
140+
141141 const auto one_m_w_a_neg_v = one_m_w * a * neg_v;
142142
143143 const auto K1 = 0.5 * (fabs (neg_v) / a * y - one_m_w);
144- const auto arg
145- = fmax (0.0 , fmin (1.0 , exp (one_m_w_a_neg_v + square (neg_v) * y / 2.0 + err) / 2.0 ));
144+ const auto arg = fmax (
145+ 0.0 ,
146+ fmin (1.0 , exp (one_m_w_a_neg_v + square (neg_v) * y / 2.0 + err) / 2.0 ));
146147 const auto K2 = (arg == 0 ) ? INFTY
147148 : (arg == 1 ) ? NEGATIVE_INFTY
148149 : -sqrt (y) / 2.0 / a * inv_Phi (arg);
@@ -176,7 +177,7 @@ inline auto wiener4_distribution(const T_y& y, const T_a& a, const T_v& v,
176177 auto neg2 = dj + logMill ((rj + neg_vy) / sqrt_y);
177178 fminus = log_sum_exp (fminus, log_sum_exp (neg1, neg2));
178179 }
179- auto ans = ret_t (0.0 );
180+ auto ans = ret_t (0.0 );
180181 if (fplus > fminus) {
181182 ans = log_diff_exp (fplus, fminus);
182183 } else if (fplus < fminus) {
@@ -187,41 +188,41 @@ inline auto wiener4_distribution(const T_y& y, const T_a& a, const T_v& v,
187188 ret_t log_distribution = ans - one_m_w_a_neg_v - square (neg_v) * y / 2 ;
188189 return NaturalScale ? exp (log_distribution) : log_distribution;
189190 }
190- const auto log_a = log (a);
191- const auto log_v = log (fabs (neg_v));
192- ret_t fplus = NEGATIVE_INFTY;
193- ret_t fminus = NEGATIVE_INFTY;
194- for (auto k = K_large_value; k > 0 ; --k) {
195- auto log_k = log (k);
196- auto k_pi = k * pi ();
197- auto sin_k_pi_w = sin (k_pi * one_m_w);
198- if (sin_k_pi_w > 0 ) {
199- fplus = log_sum_exp (
200- fplus,
201- log_k - log_sum_exp (2.0 * log_v, 2.0 * (log_k + LOG_PI - log_a))
202- - 0.5 * square (k_pi / a) * y + log (sin_k_pi_w));
203- } else if (sin_k_pi_w < 0 ) {
204- fminus = log_sum_exp (
205- fminus,
206- log_k - log_sum_exp (2.0 * log_v, 2.0 * (log_k + LOG_PI - log_a))
207- - 0.5 * square (k_pi / a) * y + log (-sin_k_pi_w));
208- }
209- }
210- ret_t ans = NEGATIVE_INFTY;
211- if (fplus > fminus) {
212- ans = log_diff_exp (fplus, fminus);
213- } else if (fplus < fminus) {
214- ans = log_diff_exp (fminus, fplus);
215- }
216- auto summand_1 = log_probability_distribution (a, neg_v, one_m_w);
217- auto summand_2 = lg + (ans - one_m_w_a_neg_v - 0.5 * square (neg_v) * y);
218- ret_t log_distribution = NEGATIVE_INFTY;
219- if (summand_1 > summand_2) {
220- log_distribution = log_diff_exp (summand_1, summand_2);
221- } else if (summand_1 < summand_2) {
222- log_distribution = log_diff_exp (summand_2, summand_1);
223- }
224- return NaturalScale ? exp (log_distribution) : log_distribution;
191+ const auto log_a = log (a);
192+ const auto log_v = log (fabs (neg_v));
193+ ret_t fplus = NEGATIVE_INFTY;
194+ ret_t fminus = NEGATIVE_INFTY;
195+ for (auto k = K_large_value; k > 0 ; --k) {
196+ auto log_k = log (k);
197+ auto k_pi = k * pi ();
198+ auto sin_k_pi_w = sin (k_pi * one_m_w);
199+ if (sin_k_pi_w > 0 ) {
200+ fplus = log_sum_exp (
201+ fplus, log_k
202+ - log_sum_exp (2.0 * log_v, 2.0 * (log_k + LOG_PI - log_a))
203+ - 0.5 * square (k_pi / a) * y + log (sin_k_pi_w));
204+ } else if (sin_k_pi_w < 0 ) {
205+ fminus = log_sum_exp (
206+ fminus, log_k
207+ - log_sum_exp (2.0 * log_v, 2.0 * (log_k + LOG_PI - log_a))
208+ - 0.5 * square (k_pi / a) * y + log (-sin_k_pi_w));
209+ }
210+ }
211+ ret_t ans = NEGATIVE_INFTY;
212+ if (fplus > fminus) {
213+ ans = log_diff_exp (fplus, fminus);
214+ } else if (fplus < fminus) {
215+ ans = log_diff_exp (fminus, fplus);
216+ }
217+ auto summand_1 = log_probability_distribution (a, neg_v, one_m_w);
218+ auto summand_2 = lg + (ans - one_m_w_a_neg_v - 0.5 * square (neg_v) * y);
219+ ret_t log_distribution = NEGATIVE_INFTY;
220+ if (summand_1 > summand_2) {
221+ log_distribution = log_diff_exp (summand_1, summand_2);
222+ } else if (summand_1 < summand_2) {
223+ log_distribution = log_diff_exp (summand_2, summand_1);
224+ }
225+ return NaturalScale ? exp (log_distribution) : log_distribution;
225226}
226227
227228/* *
@@ -243,14 +244,14 @@ inline auto wiener4_cdf_grad_a(const T_y& y, const T_a& a, const T_v& v,
243244 using ret_t = return_type_t <T_y, T_a, T_w, T_v>;
244245 const auto neg_v = -v;
245246 const auto one_m_w = 1 - w;
246-
247+
247248 const auto one_m_w_neg_v = one_m_w * neg_v;
248249 const auto one_m_w_a_neg_v = one_m_w_neg_v * a;
249250
250251 const auto log_y = log (y);
251252 const auto log_a = log (a);
252- auto C1 = ret_t (LOG_TWO
253- - log_sum_exp (2.0 * log (fabs (neg_v)), 2.0 * (LOG_PI - log_a)));
253+ auto C1 = ret_t (
254+ LOG_TWO - log_sum_exp (2.0 * log (fabs (neg_v)), 2.0 * (LOG_PI - log_a)));
254255 C1 = log_sum_exp (C1, log_y);
255256 const auto factor = one_m_w_a_neg_v + square (neg_v) * y / 2.0 + err;
256257 const auto alphK = fmin (factor + LOG_PI + log_y + log_a - LOG_TWO - C1, 0.0 );
@@ -310,26 +311,26 @@ inline auto wiener4_cdf_grad_a(const T_y& y, const T_a& a, const T_v& v,
310311 const auto summands_small_y = ans / (y * F_k);
311312 return -one_m_w_neg_v * cdf + summands_small_y;
312313 }
313- ret_t ans = 0.0 ;
314- for (auto k = K_large_value; k > 0 ; --k) {
315- const auto kpi = k * pi ();
316- const auto kpia2 = square (kpi / a);
317- const auto denom = square (neg_v) + kpia2;
318- auto last = (square (kpi) / pow (a, 3 ) * (y + 2.0 / denom)) * k / denom
319- * exp (-0.5 * kpia2 * y);
320- ans -= last * sin (kpi * one_m_w);
321- }
322- const ret_t prob = fmin (exp (log_probability_distribution (a, neg_v, one_m_w)),
323- std::numeric_limits<ret_t >::max ());
324- const auto dav = log_probability_GradAV (a, neg_v, one_m_w);
325- auto dav_neg_v = dav * neg_v;
326- auto prob_deriv
327- = fabs (neg_v) == 0 ? ret_t (0.0 )
328- : is_inf (dav_neg_v) ? NEGATIVE_INFTY : dav_neg_v * prob;
329- ans = (-2.0 / a - one_m_w_neg_v) * (cdf - prob)
330- + ans * (2.0 * pi () / square (a))
331- * exp (-one_m_w_a_neg_v - 0.5 * square (neg_v) * y);
332- return prob_deriv + ans;
314+ ret_t ans = 0.0 ;
315+ for (auto k = K_large_value; k > 0 ; --k) {
316+ const auto kpi = k * pi ();
317+ const auto kpia2 = square (kpi / a);
318+ const auto denom = square (neg_v) + kpia2;
319+ auto last = (square (kpi) / pow (a, 3 ) * (y + 2.0 / denom)) * k / denom
320+ * exp (-0.5 * kpia2 * y);
321+ ans -= last * sin (kpi * one_m_w);
322+ }
323+ const ret_t prob = fmin (exp (log_probability_distribution (a, neg_v, one_m_w)),
324+ std::numeric_limits<ret_t >::max ());
325+ const auto dav = log_probability_GradAV (a, neg_v, one_m_w);
326+ auto dav_neg_v = dav * neg_v;
327+ auto prob_deriv = fabs (neg_v) == 0
328+ ? ret_t (0.0 )
329+ : is_inf (dav_neg_v) ? NEGATIVE_INFTY : dav_neg_v * prob;
330+ ans = (-2.0 / a - one_m_w_neg_v) * (cdf - prob)
331+ + ans * (2.0 * pi () / square (a))
332+ * exp (-one_m_w_a_neg_v - 0.5 * square (neg_v) * y);
333+ return prob_deriv + ans;
333334}
334335
335336/* *
@@ -351,10 +352,10 @@ inline auto wiener4_cdf_grad_v(const T_y& y, const T_a& a, const T_v& v,
351352 using ret_t = return_type_t <T_y, T_a, T_w, T_v>;
352353 const auto neg_v = -v;
353354 const auto one_m_w = 1.0 - w;
354-
355+
355356 const auto one_m_w_a = one_m_w * a;
356357 const auto one_m_w_a_neg_v = one_m_w_a * neg_v;
357-
358+
358359 const auto log_y = log (y);
359360 const auto factor = one_m_w_a_neg_v + square (neg_v) * y / 2.0 + err;
360361
@@ -419,25 +420,25 @@ inline auto wiener4_cdf_grad_v(const T_y& y, const T_a& a, const T_v& v,
419420 const auto summands_small_y = ans / F_k;
420421 return (one_m_w_a + neg_vy) * cdf - summands_small_y;
421422 }
422- ret_t ans = 0.0 ;
423- for (auto k = K_large_value; k > 0 ; --k) {
424- const auto kpi = k * pi ();
425- const auto kpia2 = square (kpi / a);
426- const auto ekpia2y = exp (-0.5 * kpia2 * y);
427- const auto denom = square (neg_v) + kpia2;
428- const auto denomk = k / denom;
429- auto last = denomk * ekpia2y / denom;
430- ans -= last * sin (kpi * one_m_w);
431- }
432- const ret_t prob = fmin (exp (log_probability_distribution (a, neg_v, one_m_w)),
433- std::numeric_limits<ret_t >::max ());
434- const auto dav = log_probability_GradAV (a, neg_v, one_m_w);
435- auto dav_a = dav * a;
436- auto prob_deriv = is_inf (dav_a) ? ret_t (NEGATIVE_INFTY) : dav_a * prob;
437- ans = (-one_m_w_a + v * y) * (cdf - prob)
438- + ans * 4.0 * v * pi () / square (a)
439- * exp (-one_m_w_a_neg_v - 0.5 * square (neg_v) * y);
440- return -(prob_deriv + ans);
423+ ret_t ans = 0.0 ;
424+ for (auto k = K_large_value; k > 0 ; --k) {
425+ const auto kpi = k * pi ();
426+ const auto kpia2 = square (kpi / a);
427+ const auto ekpia2y = exp (-0.5 * kpia2 * y);
428+ const auto denom = square (neg_v) + kpia2;
429+ const auto denomk = k / denom;
430+ auto last = denomk * ekpia2y / denom;
431+ ans -= last * sin (kpi * one_m_w);
432+ }
433+ const ret_t prob = fmin (exp (log_probability_distribution (a, neg_v, one_m_w)),
434+ std::numeric_limits<ret_t >::max ());
435+ const auto dav = log_probability_GradAV (a, neg_v, one_m_w);
436+ auto dav_a = dav * a;
437+ auto prob_deriv = is_inf (dav_a) ? ret_t (NEGATIVE_INFTY) : dav_a * prob;
438+ ans = (-one_m_w_a + v * y) * (cdf - prob)
439+ + ans * 4.0 * v * pi () / square (a)
440+ * exp (-one_m_w_a_neg_v - 0.5 * square (neg_v) * y);
441+ return -(prob_deriv + ans);
441442}
442443
443444/* *
@@ -459,18 +460,17 @@ inline auto wiener4_cdf_grad_w(const T_y& y, const T_a& a, const T_v& v,
459460 using ret_t = return_type_t <T_y, T_a, T_w, T_v>;
460461 const auto neg_v = -v;
461462 const auto one_m_w = 1 - w;
462-
463+
463464 const auto one_m_w_a_neg_v = one_m_w * a * neg_v;
464-
465+
465466 const auto factor = one_m_w_a_neg_v + square (neg_v) * y / 2.0 + err;
466467
467468 const auto log_y = log (y);
468469 const auto log_a = log (a);
469470 const auto temp = -fmin (exp (log_a - LOG_PI - 0.5 * log_y),
470471 std::numeric_limits<ret_t >::max ());
471472 auto alphK_large
472- = fmin (exp (factor + 0.5 * (LOG_PI + log_y) - 1.5 * LOG_TWO - log_a),
473- 1.0 );
473+ = fmin (exp (factor + 0.5 * (LOG_PI + log_y) - 1.5 * LOG_TWO - log_a), 1.0 );
474474 alphK_large = fmax (0.0 , alphK_large);
475475 const auto K_large_value
476476 = fmax (ceil ((alphK_large == 0 )
@@ -484,8 +484,7 @@ inline auto wiener4_cdf_grad_w(const T_y& y, const T_a& a, const T_v& v,
484484 const auto K_large = fabs (neg_v) / a * y - wdash;
485485 const auto lv = log1p (square (neg_v) * y);
486486 const auto alphK_small = factor - LOG_TWO - lv;
487- const auto arg
488- = fmin (exp (alphK_small), 1.0 );
487+ const auto arg = fmin (exp (alphK_small), 1.0 );
489488 const auto K_small
490489 = (arg == 0 )
491490 ? INFTY
@@ -535,43 +534,43 @@ inline auto wiener4_cdf_grad_w(const T_y& y, const T_a& a, const T_v& v,
535534 const auto summands_small_y = ans / (y * F_k);
536535 return neg_v * a * cdf - summands_small_y;
537536 }
538- ret_t ans = 0.0 ;
539- for (auto k = K_large_value; k > 0 ; --k) {
540- const auto kpi = k * pi ();
541- const auto kpia2 = square (kpi / a);
542- const auto ekpia2y = exp (-0.5 * kpia2 * y);
543- const auto denom = square (neg_v) + kpia2;
544- const auto denomk = k / denom;
545- auto last = kpi;
546- last *= denomk * ekpia2y;
547- ans -= last * cos (kpi * one_m_w);
548- }
549- const auto evaw = exp (-one_m_w_a_neg_v - 0.5 * square (neg_v) * y);
550- const ret_t prob = fmin (exp (log_probability_distribution (a, neg_v, one_m_w)),
551- std::numeric_limits<ret_t >::max ());
537+ ret_t ans = 0.0 ;
538+ for (auto k = K_large_value; k > 0 ; --k) {
539+ const auto kpi = k * pi ();
540+ const auto kpia2 = square (kpi / a);
541+ const auto ekpia2y = exp (-0.5 * kpia2 * y);
542+ const auto denom = square (neg_v) + kpia2;
543+ const auto denomk = k / denom;
544+ auto last = kpi;
545+ last *= denomk * ekpia2y;
546+ ans -= last * cos (kpi * one_m_w);
547+ }
548+ const auto evaw = exp (-one_m_w_a_neg_v - 0.5 * square (neg_v) * y);
549+ const ret_t prob = fmin (exp (log_probability_distribution (a, neg_v, one_m_w)),
550+ std::numeric_limits<ret_t >::max ());
552551
553- // Calculate the probability term 'P' on log scale
554- auto dav = ret_t (-1 / w);
555- if (neg_v != 0 ) {
556- auto nearly_one = ret_t (1.0 - 1.0e-6 );
557- const auto sign_v = (neg_v < 0 ) ? 1 : -1 ;
558- const auto sign_two_va_one_minus_w = sign_v * (2.0 * neg_v * a * w);
559- const auto exp_arg = exp (sign_two_va_one_minus_w);
560- if (exp_arg >= nearly_one) {
561- dav = -1.0 / w;
562- } else {
563- auto prob = LOG_TWO + log (fabs (neg_v)) + log (a) - log1m (exp_arg);
564- if (neg_v < 0 ) {
565- prob += sign_two_va_one_minus_w;
566- }
567- dav = -exp (prob);
552+ // Calculate the probability term 'P' on log scale
553+ auto dav = ret_t (-1 / w);
554+ if (neg_v != 0 ) {
555+ auto nearly_one = ret_t (1.0 - 1.0e-6 );
556+ const auto sign_v = (neg_v < 0 ) ? 1 : -1 ;
557+ const auto sign_two_va_one_minus_w = sign_v * (2.0 * neg_v * a * w);
558+ const auto exp_arg = exp (sign_two_va_one_minus_w);
559+ if (exp_arg >= nearly_one) {
560+ dav = -1.0 / w;
561+ } else {
562+ auto prob = LOG_TWO + log (fabs (neg_v)) + log (a) - log1m (exp_arg);
563+ if (neg_v < 0 ) {
564+ prob += sign_two_va_one_minus_w;
568565 }
566+ dav = -exp (prob);
569567 }
568+ }
570569
571- const auto pia2 = 2.0 * pi () / square (a);
572- auto prob_deriv = dav * prob;
573- ans = v * a * (cdf - prob) + ans * pia2 * evaw;
574- return -(prob_deriv + ans);
570+ const auto pia2 = 2.0 * pi () / square (a);
571+ auto prob_deriv = dav * prob;
572+ ans = v * a * (cdf - prob) + ans * pia2 * evaw;
573+ return -(prob_deriv + ans);
575574}
576575} // namespace internal
577576
0 commit comments