@@ -156,7 +156,7 @@ namespace internal {
156156 * (x_left + x_right) / 2 is returned instead.
157157 */
158158template <typename Scalar>
159- [[nodiscard]] inline Scalar cubic_interpolation (Scalar x_left, Scalar f_left,
159+ [[nodiscard]] inline Scalar cubic_or_bisect_max (Scalar x_left, Scalar f_left,
160160 Scalar df_left, Scalar x_right,
161161 Scalar f_right,
162162 Scalar df_right) noexcept {
@@ -283,8 +283,8 @@ template <typename Scalar>
283283}
284284
285285template <typename Eval, typename Options>
286- inline auto cubic_interpolation (Eval&& low, Eval&& high, Options&& opt) {
287- auto alpha = cubic_interpolation (low.alpha (), low.obj (), low.dir (),
286+ inline auto cubic_or_bisect_max (Eval&& low, Eval&& high, Options&& opt) {
287+ auto alpha = cubic_or_bisect_max (low.alpha (), low.obj (), low.dir (),
288288 high.alpha (), high.obj (), high.dir ());
289289 const double width = high.alpha () - low.alpha ();
290290 const double guard = 1e-3 * width; // or make this an option
@@ -714,7 +714,7 @@ inline auto retry_evaluate(Update&& update, Proposal&& proposal, Curr&& curr,
714714 *
715715 * - If `low.dir()` and `high.dir()` have opposite signs and the right
716716 * endpoint `high` satisfies Armijo, a cubic interpolation of the endpoints
717- * is used (`cubic_interpolation (low, high, opt)`).
717+ * is used (`cubic_or_bisect_max (low, high, opt)`).
718718 * - Otherwise the trial is the simple bisection midpoint
719719 * \f$\tfrac{1}{2}(\alpha_\text{low} + \alpha_\text{high})\f$.
720720 *
@@ -864,6 +864,10 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
864864 Eval low{0.0 , prev.obj (), dir_deriv_init};
865865 prev.dir () = dir_deriv_init;
866866 int total_updates = 0 ;
867+ auto eval_finite = [](const Eval& e, const WolfeData& state) {
868+ return std::isfinite (e.obj ()) && std::isfinite (e.dir ())
869+ && state.theta ().allFinite () && state.theta_grad ().allFinite ();
870+ };
867871 Eval best = low; // keep the best Armijo-OK in case strong-Wolfe fails
868872 auto update_with_tick = [&total_updates, &opt, &best, &update_fun](
869873 auto && proposal, auto && curr, auto && prev,
@@ -891,6 +895,7 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
891895 = std::clamp (curr.alpha () * opt.scale_up , opt.min_alpha , opt.max_alpha );
892896 Eval high{alpha_start, curr.obj (), dir_deriv_init};
893897 WolfeStatus wolfe_check{WolfeReturn::Continue, 0 , 0 , false };
898+ bool high_has_eval = true ;
894899 // Initial check for numerical trouble
895900 {
896901 wolfe_check = update_with_tick (scratch, curr, prev, high, p);
@@ -915,6 +920,7 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
915920 if (wolfe_check.stop_ != WolfeReturn::Continue) {
916921 return wolfe_check;
917922 }
923+ high_has_eval = true ;
918924 }
919925 wolfe_check = update_with_tick (scratch, curr, prev, best, p);
920926 if (wolfe_check.stop_ != WolfeReturn::Continue) {
@@ -929,50 +935,55 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
929935 }
930936 }
931937 }
938+ bool found_right = false ;
932939 int num_backtracks = 0 ;
933940 /* *
934- * From Nocedal–Wright (2006), Algorithm 3.5:
935- * https://www.math.uci.edu/~qnie/Publications/NumericalOptimization.pdf
941+ * For each case
936942 * | armijo | wolfe | sign(g) | Action
937943 * -------+-------+---------+--------------------------------
938944 * | [1] T | T | | Accept alpha
939945 * | [2] T | F | > 0 | set low=high, expand high
940- * | [3] T | F | < 0 | Bracket found: stop
941- * | [4] F | T | | Bracket found: stop
942- * | [5] F | F | | Bracket found: stop
943- * NOTE: In an ideal case we would end up with a positive low directional gradient and
944- * negative high directional gradient. Cubic interpolation requires a bracket with directional
945- * shape like /\. This scheme does not gurantee a bracket with that shape will be found.
946- * So in the zoom we will have to check if we can do cubic or have to fallback to bisection.
946+ * | [3] T | F | < 0 | Set alpha_high <- alpha, stop
947+ * | [4] F | T | | Set alpha_high <- alpha, stop
948+ * | [5] F | F | | Set alpha_high <- alpha, stop
947949 **/
948- while (high.alpha () < opt.max_alpha ) {
950+ while (!found_right && high.alpha () < opt.max_alpha ) {
949951 num_backtracks++;
952+ // 1. Evaluate f(alpha) and g(alpha)
950953 wolfe_check = update_with_tick (scratch, curr, prev, high, p);
951954 if (wolfe_check.stop_ != WolfeReturn::Continue) {
952955 return wolfe_check;
953956 }
957+ high_has_eval = true ;
958+ const bool finite_ok = eval_finite (high, scratch);
959+ // 2. Handle numerical trouble first
960+ if (!finite_ok) { // f or g is NaN/Inf → shrink
961+ high.alpha () *= 0.5 ;
962+ high_has_eval = false ;
963+ if (high.alpha () < opt.min_alpha ) {
964+ break ;
965+ }
966+ continue ;
967+ }
954968 const bool armijo = check_armijo (high, prev, opt);
955969 const bool wolfe = check_wolfe (high, prev, opt);
956- // [1]
957- if (armijo && wolfe) {
970+ if (armijo && wolfe) { // [1]
958971 curr.update (scratch, high);
959972 return WolfeStatus{WolfeReturn::Wolfe, total_updates, num_backtracks,
960973 true };
961- } else if (armijo) {
962- if (best.obj () < high.obj ()) {
963- best = high;
964- }
965- // [2]
966- if (high.dir () > 0 ) {
967- low = high;
968- high.alpha () *= opt.scale_up ;
969- continue ;
970- }
971- // [3]
972- break ;
973974 }
974- // [3, 4, 5]
975- break ;
975+ if (armijo && best.obj () < high.obj ()) {
976+ best = high;
977+ }
978+ const bool dir_pos = high.dir () > 0 ;
979+ if (armijo && !wolfe && dir_pos) { // [2]
980+ low = high;
981+ high.alpha () *= opt.scale_up ;
982+ high_has_eval = false ;
983+ continue ;
984+ }
985+ // [3,4,5]
986+ found_right = true ;
976987 }
977988 const double grad_tol
978989 = std::max (opt.abs_grad_threshold ,
@@ -1007,6 +1018,13 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10071018 return WolfeStatus{WolfeReturn::Continue, total_updates, num_backtracks,
10081019 false };
10091020 };
1021+ if (!high_has_eval) {
1022+ wolfe_check = update_with_tick (scratch, curr, prev, high, p);
1023+ if (wolfe_check.stop_ != WolfeReturn::Continue) {
1024+ return wolfe_check;
1025+ }
1026+ high_has_eval = true ;
1027+ }
10101028 auto check_b = check_bounds (high);
10111029 if (check_b.stop_ != WolfeReturn::Continue) {
10121030 if (check_b.accept_ ) {
@@ -1018,19 +1036,7 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10181036 if (wolfe_check.stop_ != WolfeReturn::Continue) {
10191037 return wolfe_check;
10201038 }
1021- /* *
1022- * Zoom Step: (Alg 3.6, adapted to maximization via phi=-obj)
1023- *
1024- * Exit/update table (evaluated at `mid`, with `low` = best Armijo endpoint):
1025- * | Armijo? | obj(mid) >= obj(low)? | Wolfe? | dir(mid) >= 0? | Action
1026- * |---------|-----------------------|--------|----------------|--------------------------|
1027- * | T | F | T | * | accept mid [1] |
1028- * | T | T | * | * | high = mid [2] |
1029- * | T | F | F | T | high = low; low = mid [3]|
1030- * | T | F | F | F | low = mid [4] |
1031- * | F | * | * | * | high = mid [5] |
1032- * ----------------------------------------------------------------------------------------
1033- **/
1039+ // Zoom phase
10341040 while ((high.alpha () - low.alpha () > opt.min_alpha )
10351041 && high.alpha () > opt.min_alpha ) {
10361042 num_backtracks++;
@@ -1040,12 +1046,9 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10401046 const bool use_cubic = have_sign_change && high_armijo_ok;
10411047
10421048 // Choose trial alpha: cubic when bracket is good, else bisection.
1043- double alpha_mid{0 };
1044- if (use_cubic) {
1045- alpha_mid = cubic_interpolation (low, high, opt);
1046- } else {
1047- alpha_mid = 0.5 * (low.alpha () + high.alpha ());
1048- }
1049+ double alpha_mid = use_cubic ? cubic_or_bisect_max (low, high, opt)
1050+ : 0.5 * (low.alpha () + high.alpha ());
1051+
10491052 if (alpha_mid <= opt.min_alpha ) {
10501053 break ;
10511054 }
@@ -1060,7 +1063,6 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10601063 }
10611064 if (check_armijo (mid, prev, opt)) {
10621065 if (check_wolfe (mid, prev, opt)) {
1063- // [1]
10641066 curr.update (scratch, mid);
10651067 return WolfeStatus{WolfeReturn::Wolfe, total_updates, num_backtracks,
10661068 true };
@@ -1069,17 +1071,17 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10691071 if (mid.obj () > best.obj ()) {
10701072 best = mid;
10711073 }
1072- if (mid.obj () >= low.obj ()) {
1073- // [2]
1074- high = mid;
1075- } else if (mid.dir () >= 0 ) {
1076- // [3]
1077- high = low;
1078- low = mid;
1079- }
1080- // [4]
1074+ }
1075+
1076+ // Update bracket based on derivative sign
1077+ if (mid.dir () * low.dir () < 0 ) {
1078+ // sign change between low and mid -> [low, mid]
1079+ high = mid;
1080+ } else {
1081+ // otherwise shift left endpoint -> [mid, high]
10811082 low = mid;
10821083 }
1084+
10831085 // Convergence/guard-rail checks (uses prev/grad_tol/obj_tol etc.)
10841086 auto bounds_check = check_bounds (mid);
10851087 if (bounds_check.stop_ != WolfeReturn::Continue) {
@@ -1088,8 +1090,6 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10881090 }
10891091 return bounds_check;
10901092 }
1091- // [5]
1092- high = mid;
10931093 }
10941094 // On failure, use the best point we have found so far that at least satisfies
10951095 // armijo
0 commit comments