Skip to content

Commit cd363fe

Browse files
committed
Revert "update with last review"
This reverts commit eab35ec.
1 parent 237f2a4 commit cd363fe

2 files changed

Lines changed: 78 additions & 79 deletions

File tree

stan/math/mix/functor/laplace_marginal_density_estimator.hpp

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include <stan/math/prim/functor/iter_tuple_nested.hpp>
1414
#include <unsupported/Eigen/MatrixFunctions>
1515
#include <cmath>
16-
#include <sstream>
1716

1817
/**
1918
* @file
@@ -444,12 +443,8 @@ inline void llt_with_jitter(LLT& llt_B, B_t& B, double min_jitter = 1e-10,
444443
}
445444
}
446445
if (llt_B.info() != Eigen::Success) {
447-
std::stringstream msg;
448-
msg << "laplace_marginal_density: Cholesky decomposition failed on "
449-
<< "Hessian matrix B after attempting jitter values from "
450-
<< min_jitter << " to " << max_jitter
451-
<< ". Matrix may not be positive definite.";
452-
throw std::domain_error(msg.str());
446+
throw std::domain_error(
447+
"laplace_marginal_density: Cholesky (Diag) failed");
453448
}
454449
}
455450
}
@@ -947,13 +942,16 @@ inline auto run_newton_loop(SolverPolicy& solver, NewtonStateT& state,
947942
scratch.alpha() = 1.0;
948943
update_fun(scratch, state.curr(), state.prev(), scratch.eval_,
949944
state.wolfe_info.p_);
950-
bool force_finish = false;
945+
bool run_convergence_check = true;
951946
if (scratch.alpha() <= options.line_search.min_alpha) {
952947
state.wolfe_status.accept_ = false;
953-
force_finish = true;
948+
finish_update = true;
949+
run_convergence_check = false;
954950
} else if (options.line_search.max_iterations == 0) {
955951
state.curr().update(scratch);
956952
state.wolfe_status.accept_ = true;
953+
finish_update = false;
954+
run_convergence_check = false;
957955
} else {
958956
Eigen::VectorXd s = scratch.a() - state.prev().a();
959957
auto full_step_grad
@@ -966,15 +964,16 @@ inline auto run_newton_loop(SolverPolicy& solver, NewtonStateT& state,
966964
state.wolfe_status = internal::wolfe_line_search(
967965
state.wolfe_info, update_fun, options.line_search, msgs);
968966
}
969-
/**
970-
* Stop when objective change is small, or when a rejected Wolfe step
971-
* fails to improve; finish_update then exits the Newton loop.
972-
*/
973-
const bool obj_below_tol = std::abs(state.curr().obj() - state.prev().obj()) <
974-
options.tolerance;
975-
const bool wolfe_failed = !state.wolfe_status.accept_
976-
&& state.curr().obj() <= state.prev().obj();
977-
finish_update = force_finish || obj_below_tol || wolfe_failed;
967+
if (run_convergence_check) {
968+
/**
969+
* Stop when objective change is small, or when a rejected Wolfe step
970+
* fails to improve; finish_update then exits the Newton loop.
971+
*/
972+
finish_update = std::abs(state.curr().obj() - state.prev().obj())
973+
< options.tolerance
974+
|| (!state.wolfe_status.accept_
975+
&& state.curr().obj() <= state.prev().obj());
976+
}
978977
}
979978
if (finish_update) {
980979
if (!state.final_loop && state.wolfe_status.accept_) {

stan/math/mix/functor/wolfe_line_search.hpp

Lines changed: 61 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ namespace internal {
156156
* (x_left + x_right) / 2 is returned instead.
157157
*/
158158
template <typename Scalar>
159-
[[nodiscard]] inline Scalar cubic_interpolation(Scalar x_left, Scalar f_left,
159+
[[nodiscard]] inline Scalar cubic_or_bisect_max(Scalar x_left, Scalar f_left,
160160
Scalar df_left, Scalar x_right,
161161
Scalar f_right,
162162
Scalar df_right) noexcept {
@@ -283,8 +283,8 @@ template <typename Scalar>
283283
}
284284

285285
template <typename Eval, typename Options>
286-
inline auto cubic_interpolation(Eval&& low, Eval&& high, Options&& opt) {
287-
auto alpha = cubic_interpolation(low.alpha(), low.obj(), low.dir(),
286+
inline auto cubic_or_bisect_max(Eval&& low, Eval&& high, Options&& opt) {
287+
auto alpha = cubic_or_bisect_max(low.alpha(), low.obj(), low.dir(),
288288
high.alpha(), high.obj(), high.dir());
289289
const double width = high.alpha() - low.alpha();
290290
const double guard = 1e-3 * width; // or make this an option
@@ -714,7 +714,7 @@ inline auto retry_evaluate(Update&& update, Proposal&& proposal, Curr&& curr,
714714
*
715715
* - If `low.dir()` and `high.dir()` have opposite signs and the right
716716
* endpoint `high` satisfies Armijo, a cubic interpolation of the endpoints
717-
* is used (`cubic_interpolation(low, high, opt)`).
717+
* is used (`cubic_or_bisect_max(low, high, opt)`).
718718
* - Otherwise the trial is the simple bisection midpoint
719719
* \f$\tfrac{1}{2}(\alpha_\text{low} + \alpha_\text{high})\f$.
720720
*
@@ -864,6 +864,10 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
864864
Eval low{0.0, prev.obj(), dir_deriv_init};
865865
prev.dir() = dir_deriv_init;
866866
int total_updates = 0;
867+
auto eval_finite = [](const Eval& e, const WolfeData& state) {
868+
return std::isfinite(e.obj()) && std::isfinite(e.dir())
869+
&& state.theta().allFinite() && state.theta_grad().allFinite();
870+
};
867871
Eval best = low; // keep the best Armijo-OK in case strong-Wolfe fails
868872
auto update_with_tick = [&total_updates, &opt, &best, &update_fun](
869873
auto&& proposal, auto&& curr, auto&& prev,
@@ -891,6 +895,7 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
891895
= std::clamp(curr.alpha() * opt.scale_up, opt.min_alpha, opt.max_alpha);
892896
Eval high{alpha_start, curr.obj(), dir_deriv_init};
893897
WolfeStatus wolfe_check{WolfeReturn::Continue, 0, 0, false};
898+
bool high_has_eval = true;
894899
// Initial check for numerical trouble
895900
{
896901
wolfe_check = update_with_tick(scratch, curr, prev, high, p);
@@ -915,6 +920,7 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
915920
if (wolfe_check.stop_ != WolfeReturn::Continue) {
916921
return wolfe_check;
917922
}
923+
high_has_eval = true;
918924
}
919925
wolfe_check = update_with_tick(scratch, curr, prev, best, p);
920926
if (wolfe_check.stop_ != WolfeReturn::Continue) {
@@ -929,50 +935,55 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
929935
}
930936
}
931937
}
938+
bool found_right = false;
932939
int num_backtracks = 0;
933940
/**
934-
* From Nocedal–Wright (2006), Algorithm 3.5:
935-
* https://www.math.uci.edu/~qnie/Publications/NumericalOptimization.pdf
941+
* For each case
936942
* | armijo | wolfe | sign(g) | Action
937943
* -------+-------+---------+--------------------------------
938944
* | [1] T | T | | Accept alpha
939945
* | [2] T | F | > 0 | set low=high, expand high
940-
* | [3] T | F | < 0 | Bracket found: stop
941-
* | [4] F | T | | Bracket found: stop
942-
* | [5] F | F | | Bracket found: stop
943-
* NOTE: In an ideal case we would end up with a positive low directional gradient and
944-
* negative high directional gradient. Cubic interpolation requires a bracket with directional
945-
* shape like /\. This scheme does not guarantee a bracket with that shape will be found.
946-
* So in the zoom we will have to check if we can do cubic or have to fall back to bisection.
946+
* | [3] T | F | < 0 | Set alpha_high <- alpha, stop
947+
* | [4] F | T | | Set alpha_high <- alpha, stop
948+
* | [5] F | F | | Set alpha_high <- alpha, stop
947949
**/
948-
while (high.alpha() < opt.max_alpha) {
950+
while (!found_right && high.alpha() < opt.max_alpha) {
949951
num_backtracks++;
952+
// 1. Evaluate f(alpha) and g(alpha)
950953
wolfe_check = update_with_tick(scratch, curr, prev, high, p);
951954
if (wolfe_check.stop_ != WolfeReturn::Continue) {
952955
return wolfe_check;
953956
}
957+
high_has_eval = true;
958+
const bool finite_ok = eval_finite(high, scratch);
959+
// 2. Handle numerical trouble first
960+
if (!finite_ok) { // f or g is NaN/Inf → shrink
961+
high.alpha() *= 0.5;
962+
high_has_eval = false;
963+
if (high.alpha() < opt.min_alpha) {
964+
break;
965+
}
966+
continue;
967+
}
954968
const bool armijo = check_armijo(high, prev, opt);
955969
const bool wolfe = check_wolfe(high, prev, opt);
956-
// [1]
957-
if (armijo && wolfe) {
970+
if (armijo && wolfe) { // [1]
958971
curr.update(scratch, high);
959972
return WolfeStatus{WolfeReturn::Wolfe, total_updates, num_backtracks,
960973
true};
961-
} else if (armijo) {
962-
if (best.obj() < high.obj()) {
963-
best = high;
964-
}
965-
// [2]
966-
if (high.dir() > 0) {
967-
low = high;
968-
high.alpha() *= opt.scale_up;
969-
continue;
970-
}
971-
// [3]
972-
break;
973974
}
974-
// [3, 4, 5]
975-
break;
975+
if (armijo && best.obj() < high.obj()) {
976+
best = high;
977+
}
978+
const bool dir_pos = high.dir() > 0;
979+
if (armijo && !wolfe && dir_pos) { // [2]
980+
low = high;
981+
high.alpha() *= opt.scale_up;
982+
high_has_eval = false;
983+
continue;
984+
}
985+
// [3,4,5]
986+
found_right = true;
976987
}
977988
const double grad_tol
978989
= std::max(opt.abs_grad_threshold,
@@ -1007,6 +1018,13 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10071018
return WolfeStatus{WolfeReturn::Continue, total_updates, num_backtracks,
10081019
false};
10091020
};
1021+
if (!high_has_eval) {
1022+
wolfe_check = update_with_tick(scratch, curr, prev, high, p);
1023+
if (wolfe_check.stop_ != WolfeReturn::Continue) {
1024+
return wolfe_check;
1025+
}
1026+
high_has_eval = true;
1027+
}
10101028
auto check_b = check_bounds(high);
10111029
if (check_b.stop_ != WolfeReturn::Continue) {
10121030
if (check_b.accept_) {
@@ -1018,19 +1036,7 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10181036
if (wolfe_check.stop_ != WolfeReturn::Continue) {
10191037
return wolfe_check;
10201038
}
1021-
/**
1022-
* Zoom Step: (Alg 3.6, adapted to maximization via phi=-obj)
1023-
*
1024-
* Exit/update table (evaluated at `mid`, with `low` = best Armijo endpoint):
1025-
* | Armijo? | obj(mid) >= obj(low)? | Wolfe? | dir(mid) >= 0? | Action
1026-
* |---------|-----------------------|--------|----------------|--------------------------|
1027-
* | T | F | T | * | accept mid [1] |
1028-
* | T | T | * | * | high = mid [2] |
1029-
* | T | F | F | T | high = low; low = mid [3]|
1030-
* | T | F | F | F | low = mid [4] |
1031-
* | F | * | * | * | high = mid [5] |
1032-
* ----------------------------------------------------------------------------------------
1033-
**/
1039+
// Zoom phase
10341040
while ((high.alpha() - low.alpha() > opt.min_alpha)
10351041
&& high.alpha() > opt.min_alpha) {
10361042
num_backtracks++;
@@ -1040,12 +1046,9 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10401046
const bool use_cubic = have_sign_change && high_armijo_ok;
10411047

10421048
// Choose trial alpha: cubic when bracket is good, else bisection.
1043-
double alpha_mid{0};
1044-
if (use_cubic) {
1045-
alpha_mid = cubic_interpolation(low, high, opt);
1046-
} else {
1047-
alpha_mid = 0.5 * (low.alpha() + high.alpha());
1048-
}
1049+
double alpha_mid = use_cubic ? cubic_or_bisect_max(low, high, opt)
1050+
: 0.5 * (low.alpha() + high.alpha());
1051+
10491052
if (alpha_mid <= opt.min_alpha) {
10501053
break;
10511054
}
@@ -1060,7 +1063,6 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10601063
}
10611064
if (check_armijo(mid, prev, opt)) {
10621065
if (check_wolfe(mid, prev, opt)) {
1063-
// [1]
10641066
curr.update(scratch, mid);
10651067
return WolfeStatus{WolfeReturn::Wolfe, total_updates, num_backtracks,
10661068
true};
@@ -1069,17 +1071,17 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10691071
if (mid.obj() > best.obj()) {
10701072
best = mid;
10711073
}
1072-
if (mid.obj() >= low.obj()) {
1073-
// [2]
1074-
high = mid;
1075-
} else if (mid.dir() >= 0) {
1076-
// [3]
1077-
high = low;
1078-
low = mid;
1079-
}
1080-
// [4]
1074+
}
1075+
1076+
// Update bracket based on derivative sign
1077+
if (mid.dir() * low.dir() < 0) {
1078+
// sign change between low and mid -> [low, mid]
1079+
high = mid;
1080+
} else {
1081+
// otherwise shift left endpoint -> [mid, high]
10811082
low = mid;
10821083
}
1084+
10831085
// Convergence/guard-rail checks (uses prev/grad_tol/obj_tol etc.)
10841086
auto bounds_check = check_bounds(mid);
10851087
if (bounds_check.stop_ != WolfeReturn::Continue) {
@@ -1088,8 +1090,6 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
10881090
}
10891091
return bounds_check;
10901092
}
1091-
// [5]
1092-
high = mid;
10931093
}
10941094
// On failure, use the best point we have found so far that at least satisfies
10951095
// armijo

0 commit comments

Comments
 (0)