Skip to content

Commit 6e15bb8

Browse files
committed
use solve_glm more directly
1 parent 6e305f0 commit 6e15bb8

10 files changed

Lines changed: 205 additions & 67 deletions

File tree

ya_glm/Glm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def fit(self, X, y):
4949
# TODO: do we want to give the user the option to not copy?
5050
X, y, pre_pro_out = self.preprocess(X, y, copy=True)
5151

52-
coef, intercept, out_data = self.solve(X=X, y=y,
53-
**self._get_solve_kws())
52+
coef, intercept, out_data = self.solve_glm(X=X, y=y,
53+
**self._get_solve_kws())
5454

5555
self._set_fit(fit_out={'coef': coef, 'intercept': intercept,
5656
'opt_data': out_data},

ya_glm/backends/fista/LinearRegression.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
from ya_glm.lla.lla import solve_lla
1515
from .glm_solver import solve_glm, solve_glm_path
16-
from .fcp_lla_solver import WL1SolverGlm
1716

1817

1918
##############
@@ -22,32 +21,32 @@
2221

2322

2423
class Vanilla(LinRegMixin, GlmVanilla):
25-
solve = staticmethod(solve_glm)
24+
solve_glm = staticmethod(solve_glm)
2625

2726

2827
class Lasso(LinRegMixin, GlmLasso):
29-
solve = staticmethod(solve_glm)
28+
solve_glm = staticmethod(solve_glm)
3029

3130

3231
class LassoENet(LinRegMixin, GlmLassoENet):
33-
solve = staticmethod(solve_glm)
32+
solve_glm = staticmethod(solve_glm)
3433

3534

3635
class GroupLasso(LinRegMixin, GlmGroupLasso):
37-
solve = staticmethod(solve_glm)
36+
solve_glm = staticmethod(solve_glm)
3837

3938

4039
class GroupLassoENet(LinRegMixin, GlmGroupLassoENet):
41-
solve = staticmethod(solve_glm)
40+
solve_glm = staticmethod(solve_glm)
4241

4342

4443
class Ridge(LinRegMixin, GlmRidge):
45-
solve = staticmethod(solve_glm)
44+
solve_glm = staticmethod(solve_glm)
4645

4746

4847
class FcpLLA(LinRegMixin, GlmFcpFitLLA):
4948
solve_lla = staticmethod(solve_lla)
50-
base_wl1_solver = WL1SolverGlm
49+
solve_glm = staticmethod(solve_glm)
5150

5251
def _get_defualt_init(self):
5352
# return LassoCV()
@@ -60,7 +59,7 @@ def _get_defualt_init(self):
6059

6160
class GroupFcpLLA(LinRegMixin, GlmGroupFcpFitLLA):
6261
solve_lla = staticmethod(solve_lla)
63-
base_wl1_solver = WL1SolverGlm
62+
solve_glm = staticmethod(solve_glm)
6463

6564
def _get_defualt_init(self):
6665
est = GroupLasso(groups=self.groups,
@@ -76,21 +75,21 @@ def _get_defualt_init(self):
7675

7776

7877
class LassoCV(GlmLassoCVPath):
79-
solve_path = staticmethod(solve_glm_path)
78+
solve_glm_path = staticmethod(solve_glm_path)
8079

8180
@add_init_params(GlmLassoCVPath)
8281
def __init__(self, estimator=Lasso()): pass
8382

8483

8584
class LassoENetCV(GlmLassoENetCVPath):
86-
solve_path = staticmethod(solve_glm_path)
85+
solve_glm_path = staticmethod(solve_glm_path)
8786

8887
@add_init_params(GlmLassoENetCVPath)
8988
def __init__(self, estimator=LassoENet()): pass
9089

9190

9291
class GroupLassoCV(GlmGroupLassoCVPath):
93-
solve_path = staticmethod(solve_glm_path)
92+
solve_glm_path = staticmethod(solve_glm_path)
9493

9594
@add_init_params(GlmLassoCVPath)
9695
# gruops=[] is hack to get around required positional arugment
@@ -99,14 +98,14 @@ def __init__(self, estimator=GroupLasso(groups=[])): pass
9998

10099

101100
class GroupLassoENetCV(GlmGroupLassoENetCVPath):
102-
solve_path = staticmethod(solve_glm_path)
101+
solve_glm_path = staticmethod(solve_glm_path)
103102

104103
@add_init_params(GlmLassoENetCVPath)
105104
def __init__(self, estimator=GroupLassoENet(groups=[])): pass
106105

107106

108107
class RidgeCV(GlmRidgeCVPath):
109-
solve_path = staticmethod(solve_glm_path)
108+
solve_glm_path = staticmethod(solve_glm_path)
110109

111110
@add_init_params(GlmRidgeCVPath)
112111
def __init__(self, estimator=Ridge()): pass

ya_glm/backends/fista/LinearRegressionMultiResp.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
from ya_glm.lla.lla import solve_lla
1515
from .glm_solver import solve_glm, solve_glm_path
16-
from .fcp_lla_solver import WL1SolverGlm
1716

1817

1918
##############
@@ -22,24 +21,24 @@
2221

2322

2423
class Vanilla(LinRegMultiResponseMixin, GlmVanilla):
25-
solve = staticmethod(solve_glm)
24+
solve_glm = staticmethod(solve_glm)
2625

2726

2827
class MultiTaskLasso(LinRegMultiResponseMixin, GlmMultiTaskLasso):
29-
solve = staticmethod(solve_glm)
28+
solve_glm = staticmethod(solve_glm)
3029

3130

3231
class MultiTaskLassoENet(LinRegMultiResponseMixin, GlmMultiTaskLassoENet):
33-
solve = staticmethod(solve_glm)
32+
solve_glm = staticmethod(solve_glm)
3433

3534

3635
class NuclearNorm(LinRegMultiResponseMixin, GlmNuclearNorm):
37-
solve = staticmethod(solve_glm)
36+
solve_glm = staticmethod(solve_glm)
3837

3938

4039
class MultiTaskFcpLLA(LinRegMultiResponseMixin, GlmMultiTaskFcpFitLLA):
4140
solve_lla = staticmethod(solve_lla)
42-
base_wl1_solver = WL1SolverGlm
41+
solve_glm = staticmethod(solve_glm)
4342

4443
def _get_defualt_init(self):
4544
est = MultiTaskLasso(fit_intercept=self.fit_intercept,
@@ -51,7 +50,7 @@ def _get_defualt_init(self):
5150

5251
class NuclearNormFcpLLA(LinRegMultiResponseMixin, GlmNuclearNormFcpFitLLA):
5352
solve_lla = staticmethod(solve_lla)
54-
base_wl1_solver = WL1SolverGlm
53+
solve_glm = staticmethod(solve_glm)
5554

5655
def _get_defualt_init(self):
5756
est = NuclearNorm(fit_intercept=self.fit_intercept,
@@ -67,21 +66,21 @@ def _get_defualt_init(self):
6766

6867

6968
class MultiTaskLassoCV(GlmMultiTaskLassoCVPath):
70-
solve_path = staticmethod(solve_glm_path)
69+
solve_glm_path = staticmethod(solve_glm_path)
7170

7271
@add_init_params(GlmMultiTaskLassoCVPath)
7372
def __init__(self, estimator=MultiTaskLasso()): pass
7473

7574

7675
class MultiTaskLassoENetCV(GlmMultiTaskLassoENetCVPath):
77-
solve_path = staticmethod(solve_glm_path)
76+
solve_glm_path = staticmethod(solve_glm_path)
7877

7978
@add_init_params(GlmMultiTaskLassoENetCVPath)
8079
def __init__(self, estimator=MultiTaskLassoENet()): pass
8180

8281

8382
class NuclearNormCV(GlmNuclearNormCVPath):
84-
solve_path = staticmethod(solve_glm_path)
83+
solve_glm_path = staticmethod(solve_glm_path)
8584

8685
@add_init_params(GlmNuclearNormCVPath)
8786
def __init__(self, estimator=NuclearNorm()): pass

ya_glm/backends/fista/glm_solver.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ya_glm.opt.fista import solve_fista
2222
from ya_glm.opt.base import Func, Sum
2323

24+
2425
_solve_glm_params = dedent("""
2526
X: array-like, shape (n_samples, n_features)
2627
The training covariate data.
@@ -472,10 +473,11 @@ def process_init(X, y, loss_func, fit_intercept=True, coef_init=None,
472473

473474
# format
474475
coef_init = np.array(coef_init)
475-
if coef_init.ndim > 1:
476-
intercept_init = np.array(intercept_init)
477-
else:
478-
intercept_init = float(intercept_init)
476+
if fit_intercept:
477+
if coef_init.ndim > 1:
478+
intercept_init = np.array(intercept_init)
479+
else:
480+
intercept_init = float(intercept_init)
479481

480482
# maybe concatenate
481483
if fit_intercept:

ya_glm/cv/CVPath.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
class CVPathMixin:
1010
"""
11-
solve_path
11+
solve_glm_path
1212
"""
1313

1414
def _fit_and_score_path_getter(self):
@@ -25,7 +25,7 @@ def est_from_fit(fit_out, pre_pro_out):
2525
preprocess = None
2626

2727
fit_and_score_path = partial(score_from_fit_path,
28-
solve_path=self.solve_path,
28+
solve_path=self.solve_glm_path,
2929
est_from_fit=est_from_fit,
3030
scorer=self.cv_scorer,
3131
preprocess=preprocess)

ya_glm/estimator_getter.py

Lines changed: 58 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,19 @@
3232
from ya_glm.backends.fista.glm_solver import solve_glm as solve_glm_fista
3333
from ya_glm.backends.fista.glm_solver import solve_glm_path \
3434
as solve_glm_path_fista
35-
from ya_glm.backends.fista.fcp_lla_solver import WL1SolverGlm \
36-
as WL1SolverGlm_fista
35+
3736

3837
# andersoncd solvers
3938
from ya_glm.backends.andersoncd.glm_solver import solve_glm as \
4039
solve_glm_andersoncd
4140
from ya_glm.backends.andersoncd.glm_solver import solve_glm_path \
4241
as solve_glm_path_andersoncd
43-
from ya_glm.backends.andersoncd.fcp_lla_solver import WL1SolverGlm \
44-
as WL1SolverGlm_andersoncd
42+
43+
# cvxpy solvers
44+
from ya_glm.backends.cvxpy.glm_solver import solve_glm as solve_glm_cvxpy
45+
from ya_glm.backends.cvxpy.glm_solver import solve_glm_path \
46+
as solve_glm_path_cvxpy
47+
4548

4649
# other
4750
from ya_glm.add_init_params import add_init_params
@@ -137,35 +140,65 @@ def get_fcp_penalty(penalty='lasso'):
137140
raise ValueError("Bad input for penalty: {}".format(penalty))
138141

139142

143+
def get_solver(backend='fista'):
144+
"""
145+
Parameters
146+
----------
147+
backend: str, dict
148+
149+
"""
150+
151+
# get solver
152+
if type(backend) == str:
153+
if backend == 'fista':
154+
solve_glm = solve_glm_fista
155+
solve_glm_path = solve_glm_path_fista
156+
157+
158+
elif backend == 'andersoncd':
159+
solve_glm = solve_glm_andersoncd
160+
solve_glm_path = solve_glm_path_andersoncd
161+
162+
elif backend == 'cvxpy':
163+
solve_glm = solve_glm_cvxpy
164+
solve_glm_path = solve_glm_path_cvxpy
165+
166+
else:
167+
solve_glm = backend.get('solve_glm', None)
168+
solve_glm_path = backend.get('solve_glm_path', None)
169+
170+
solve_glm = staticmethod(solve_glm)
171+
if solve_glm_path is not None:
172+
solve_glm_path = staticmethod(solve_glm_path)
173+
174+
return solve_glm, solve_glm_path
175+
176+
# TODO: handle static method + None
177+
178+
140179
def get_pen_glm(loss_func='linear_regression',
141180
penalty='lasso',
142181
backend='fista'):
143-
182+
144183
if penalty in _MULTI_RESP_PENS:
145184
assert loss_func in _MULTI_RESP_LOSSES
146185

147186
MODEL_MIXIN = get_model_mixin(loss_func=loss_func)
148187
GLM, GLM_CV = get_penalty(penalty=penalty)
188+
solve_glm, solve_glm_path = get_solver(backend=backend)
149189

150-
# get solver
151-
if type(backend) == str and backend == 'fista':
152-
solve_glm_impl = solve_glm_fista
153-
solve_glm_path_impl = solve_glm_path_fista
154-
155-
elif type(backend) == str and backend == 'andersoncd':
156-
solve_glm_impl = solve_glm_andersoncd
157-
solve_glm_path_impl = solve_glm_path_andersoncd
190+
# TODO-HACK: for reasons I do not understand I needed to
191+
# do this to get Estimator() to work below
192+
temp = {}
193+
temp['solve_glm'] = solve_glm
194+
temp['solve_glm_path'] = solve_glm_path
158195

159-
else:
160-
solve_glm_impl = backend.get('solve_glm', None)
161-
solve_glm_path_impl = backend.get('solve_glm_path', None)
162-
163-
###################
164196
# setup estimator #
165197
###################
166198

167199
class Estimator(MODEL_MIXIN, GLM):
168-
solve = staticmethod(solve_glm_impl)
200+
# solve_glm = solve_glm
201+
solve_glm = temp['solve_glm']
169202

170203
@add_init_params(GLM, MODEL_MIXIN)
171204
def __init__(self): pass
@@ -183,7 +216,8 @@ def __init__(self): pass
183216
estimator = Estimator()
184217

185218
class EstimatorCV(GLM_CV):
186-
solve_path = staticmethod(solve_glm_path_impl)
219+
# solve_glm_path = solve_glm_path
220+
solve_glm_path = temp['solve_glm_path']
187221

188222
@add_init_params(GLM_CV)
189223
def __init__(self, estimator=estimator): pass
@@ -201,16 +235,9 @@ def get_fcp_glm(loss_func='linear_regression', penalty='lasso',
201235
# get base model class
202236
MODEL_MIXIN = get_model_mixin(loss_func=loss_func)
203237
GLM_FCP, GLM_FCP_CV = get_fcp_penalty(penalty=penalty)
238+
solve_glm = get_solver(backend=backend)[0]
204239

205-
# get wl1 solver
206-
if type(backend) == str and backend == 'fista':
207-
WL1_impl = WL1SolverGlm_fista
208-
209-
elif type(backend) == str and backend == 'andersoncd':
210-
WL1_impl = WL1SolverGlm_andersoncd
211-
212-
else:
213-
WL1_impl = backend.get('wl1', None)
240+
temp = {'solve_glm': solve_glm} # TODO-HACK: see above
214241

215242
# get default initializer
216243
Default, DefaultCV = get_pen_glm(loss_func=loss_func,
@@ -223,7 +250,8 @@ def get_fcp_glm(loss_func='linear_regression', penalty='lasso',
223250

224251
class Estimator(MODEL_MIXIN, GLM_FCP):
225252
solve_lla = staticmethod(solve_lla)
226-
base_wl1_solver = WL1_impl
253+
# solve_glm = solve_glm
254+
solve_glm = temp['solve_glm']
227255

228256
@add_init_params(GLM_FCP, MODEL_MIXIN)
229257
def __init__(self): pass

ya_glm/fcp/GlmFcp.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ya_glm.opt.concave_penalty import get_penalty_func
1414
from ya_glm.processing import process_weights_group_lasso
1515
from ya_glm.opt.GroupLasso import euclid_norm
16+
from ya_glm.lla.WeightedLassoSolver import WL1SolverGlm
1617

1718

1819
class InitMixin:
@@ -183,7 +184,7 @@ def compute_fit(self, X, y, init_data):
183184

184185
class GlmFcpFitLLA(GlmFcp):
185186

186-
base_wl1_solver = None
187+
solve_glm = None
187188
solve_lla = None
188189

189190
@add_init_params(GlmFcp)
@@ -211,7 +212,9 @@ def compute_fit(self, X, y, init_data):
211212
**self._extra_wl1_kws()
212213
}
213214

214-
wl1_solver = self.base_wl1_solver(X=X, y=y, **kws)
215+
# Setup weighted L1 solver
216+
wl1_solver = WL1SolverGlm(X=X, y=y, **kws)
217+
wl1_solver.solve_glm = self.solve_glm
215218

216219
# solve!
217220
coef, intercept, opt_data = \

0 commit comments

Comments
 (0)