|
| 1 | +import cvxpy as cp |
| 2 | +from time import time |
| 3 | + |
| 4 | +from ya_glm.backends.fista.glm_solver import process_param_path |
| 5 | +from ya_glm.backends.quantile_lp.utils import get_lin_prog_data, \ |
| 6 | + get_coef_inter, get_quad_mat |
| 7 | + |
| 8 | + |
def solve(X, y, fit_intercept=True, quantile=0.5, sample_weights=None,
          lasso_pen=1, ridge_pen=None,
          lasso_weights=None, ridge_weights=None, tikhonov=None,
          coef_init=None, intercept_init=None,
          solver=None,
          cp_kws=None):
    """
    Solves the L1 + L2 penalized quantile regression problem by formulating
    it as a linear quadratic program then appealing to cvxpy.

    Parameters
    ----------
    X, y:
        Problem data, forwarded unchanged to ``setup_problem``. ``X`` must
        expose ``shape[1]``, which is used to count the parameters.

    fit_intercept: bool
        Whether an intercept is part of the model. Forwarded to
        ``setup_problem`` and ``get_coef_inter``.

    quantile, sample_weights:
        Forwarded unchanged to ``setup_problem``.

    lasso_pen, ridge_pen: None, float
        Penalty values. If ``lasso_weights`` is given while ``lasso_pen`` is
        None, ``lasso_pen`` is set to 1. If ``ridge_weights`` or ``tikhonov``
        is given while ``ridge_pen`` is None, ``ridge_pen`` is set to 1.

    lasso_weights, ridge_weights, tikhonov:
        Forwarded unchanged to ``setup_problem``.

    coef_init, intercept_init:
        Forwarded to ``setup_problem``, which rejects any non-None value.

    solver:
        Passed as the ``solver`` keyword of ``problem.solve``.

    cp_kws: None, dict
        Additional keyword arguments for ``problem.solve``. None means no
        additional arguments.

    Output
    ------
    coef, intercept, opt_data

    coef, intercept:
        The values returned by ``get_coef_inter`` for the solution.

    opt_data: dict
        The attributes of ``problem.solver_stats`` plus 'status' (the cvxpy
        problem status) and 'runtime' (seconds elapsed, problem setup
        included).

    Raises
    ------
    RuntimeError
        If the solver leaves the variable without a value.
    """
    # avoid a mutable default argument
    cp_kws = {} if cp_kws is None else cp_kws

    # supplying weights implies the corresponding penalty is switched on
    if lasso_weights is not None and lasso_pen is None:
        lasso_pen = 1

    if (ridge_weights is not None or tikhonov is not None) \
            and ridge_pen is None:
        ridge_pen = 1

    start_time = time()

    problem, var, lasso_pen, ridge_pen = \
        setup_problem(X=X, y=y,
                      fit_intercept=fit_intercept,
                      quantile=quantile,
                      sample_weights=sample_weights,
                      lasso_pen=lasso_pen,
                      ridge_pen=ridge_pen,
                      lasso_weights=lasso_weights,
                      ridge_weights=ridge_weights,
                      tikhonov=tikhonov,
                      coef_init=coef_init,
                      intercept_init=intercept_init)

    problem.solve(solver=solver, **cp_kws)

    # fail loudly (as solve_path does) rather than handing None to
    # get_coef_inter
    if var.value is None:
        raise RuntimeError("cvxpy solvers failed")

    opt_data = {**problem.solver_stats.__dict__,
                'status': problem.status,
                'runtime': time() - start_time}

    # one parameter per column of X, plus one for the intercept if present
    n_params = X.shape[1] + 1 if fit_intercept else X.shape[1]

    coef, intercept = get_coef_inter(solution=var.value,
                                     n_params=n_params,
                                     fit_intercept=fit_intercept)

    return coef, intercept, opt_data
| 64 | + |
| 65 | + |
def solve_path(fit_intercept=True, cp_kws=None, zero_tol=1e-8,
               lasso_pen_seq=None, ridge_pen_seq=None,
               check_decr=True, **kws):
    """
    Solves the penalized quantile regression problem along a path of penalty
    values, building the cvxpy problem once and only updating the penalty
    parameters between solves.

    This is a generator: nothing is computed until it is iterated.

    Parameters
    ----------
    fit_intercept: bool
        Whether an intercept is part of the model. Forwarded to
        ``setup_problem`` and ``get_coef_inter``.

    cp_kws: None, dict
        Keyword arguments for ``problem.solve``. None means no arguments.

    zero_tol: float
        Currently unused; kept so existing callers keep working.

    lasso_pen_seq, ridge_pen_seq, check_decr:
        Forwarded to ``process_param_path``, which produces the sequence of
        penalty settings to solve for.

    **kws:
        Remaining keyword arguments for ``setup_problem``. Must include
        ``X``, whose ``shape[1]`` is used to count the parameters.

    Yields
    ------
    fit_out, params

    fit_out: dict
        Has keys 'coef', 'intercept' and 'opt_data'. 'opt_data' holds the
        attributes of ``problem.solver_stats`` plus 'status', 'runtime'
        (seconds for this solve) and 'pre_setup_runtime' (seconds spent
        building the problem).

    params: dict
        The entry of the parameter path that was solved.

    Raises
    ------
    RuntimeError
        If the solver leaves the variable without a value.
    """
    # avoid a mutable default argument
    cp_kws = {} if cp_kws is None else cp_kws

    param_path = process_param_path(lasso_pen_seq=lasso_pen_seq,
                                    ridge_pen_seq=ridge_pen_seq,
                                    check_decr=check_decr)

    # make sure we setup the right penalty
    first_params = param_path[0]
    if 'lasso_pen' in first_params:
        kws['lasso_pen'] = first_params['lasso_pen']
    if 'ridge_pen' in first_params:
        kws['ridge_pen'] = first_params['ridge_pen']

    start_time = time()
    # fit_intercept is a named argument of this function, so it is not part
    # of kws and has to be forwarded explicitly; otherwise the problem is
    # always built with setup_problem's default while the solution below is
    # unpacked with the caller's value
    problem, var, lasso_pen, ridge_pen = \
        setup_problem(fit_intercept=fit_intercept, **kws)
    pre_setup_runtime = time() - start_time

    # one parameter per column of X, plus one for the intercept if present;
    # this does not change along the path
    n_features = kws['X'].shape[1]
    n_params = n_features + 1 if fit_intercept else n_features

    for params in param_path:
        start_time = time()

        # update the penalty values in place, the problem itself is reused
        if 'lasso_pen' in params:
            lasso_pen.value = params['lasso_pen']

        if 'ridge_pen' in params:
            ridge_pen.value = params['ridge_pen']

        problem.solve(**cp_kws)

        if var.value is None:
            raise RuntimeError("cvxpy solvers failed")

        opt_data = {**problem.solver_stats.__dict__,
                    'status': problem.status,
                    'runtime': time() - start_time,
                    'pre_setup_runtime': pre_setup_runtime}

        coef, intercept = get_coef_inter(solution=var.value,
                                         n_params=n_params,
                                         fit_intercept=fit_intercept)

        fit_out = {'coef': coef, 'intercept': intercept, 'opt_data': opt_data}
        yield fit_out, params
| 121 | + |
| 122 | + |
def setup_problem(X, y, fit_intercept=True, quantile=0.5, sample_weights=None,
                  lasso_pen=1, ridge_pen=None,
                  lasso_weights=None, ridge_weights=None, tikhonov=None,
                  coef_init=None, intercept_init=None):
    """
    Builds the cvxpy problem for penalized quantile regression.

    The problem minimizes ``var @ lin_coef`` (plus
    ``0.5 * ridge_pen * quad_form(var, quad_mat)`` when a ridge penalty is
    present) subject to ``var >= 0`` and ``A_eq @ var == b_eq``, where
    ``A_eq``, ``b_eq`` and ``lin_coef`` come from ``get_lin_prog_data`` and
    ``quad_mat`` comes from ``get_quad_mat``.

    Parameters
    ----------
    X, y, fit_intercept, quantile, sample_weights, lasso_weights:
        Forwarded to ``get_lin_prog_data``.

    lasso_pen, ridge_pen: None, float
        Penalty values. Each non-None value is wrapped in a cvxpy Parameter
        declared with ``pos=True`` so it can be changed later without
        rebuilding the problem.

    ridge_weights, tikhonov:
        Forwarded to ``get_quad_mat``; only used when ``ridge_pen`` is not
        None.

    coef_init, intercept_init:
        Must both be None.

    Output
    ------
    problem, var, lasso_pen, ridge_pen

    problem: cp.Problem
        The problem to solve.

    var: cp.Variable
        The optimization variable, with one entry per column of ``A_eq``.

    lasso_pen, ridge_pen: None, cp.Parameter
        The penalty parameters, or None for a penalty that was not given.

    Raises
    ------
    NotImplementedError
        If ``coef_init`` or ``intercept_init`` is given.
    """

    # penalties become cvxpy Parameters so their values can be updated
    if lasso_pen is not None:
        lasso_pen = cp.Parameter(pos=True, value=lasso_pen)
    if ridge_pen is not None:
        ridge_pen = cp.Parameter(pos=True, value=ridge_pen)

    if not (coef_init is None and intercept_init is None):
        raise NotImplementedError("I do not think initialization works for these solvers")

    # ----- problem data -----
    A_eq, b_eq, lin_coef, _n_params = get_lin_prog_data(
        X, y,
        fit_intercept=fit_intercept,
        quantile=quantile,
        lasso_pen=lasso_pen,
        sample_weights=sample_weights,
        lasso_weights=lasso_weights,
    )
    linear_cost = cp.hstack(lin_coef)

    has_ridge = ridge_pen is not None
    if has_ridge:
        quad_mat = get_quad_mat(X=X,
                                fit_intercept=fit_intercept,
                                weights=ridge_weights,
                                tikhonov=tikhonov)

    var = cp.Variable(shape=A_eq.shape[1])

    # ----- cvxpy problem -----
    cost = var.T @ linear_cost
    if has_ridge:
        cost = cost + 0.5 * ridge_pen * cp.quad_form(var, quad_mat)
    objective = cp.Minimize(cost)

    nonnegativity = var >= 0
    equality = A_eq @ var == b_eq
    problem = cp.Problem(objective, [nonnegativity, equality])

    return problem, var, lasso_pen, ridge_pen
0 commit comments