3232from ya_glm .backends .fista .glm_solver import solve_glm as solve_glm_fista
3333from ya_glm .backends .fista .glm_solver import solve_glm_path \
3434 as solve_glm_path_fista
35- from ya_glm .backends .fista .fcp_lla_solver import WL1SolverGlm \
36- as WL1SolverGlm_fista
35+
3736
3837# andersoncd solvers
3938from ya_glm .backends .andersoncd .glm_solver import solve_glm as \
4039 solve_glm_andersoncd
4140from ya_glm .backends .andersoncd .glm_solver import solve_glm_path \
4241 as solve_glm_path_andersoncd
43- from ya_glm .backends .andersoncd .fcp_lla_solver import WL1SolverGlm \
44- as WL1SolverGlm_andersoncd
42+
43+ # cvxpy solvers
44+ from ya_glm .backends .cvxpy .glm_solver import solve_glm as solve_glm_cvxpy
45+ from ya_glm .backends .cvxpy .glm_solver import solve_glm_path \
46+ as solve_glm_path_cvxpy
47+
4548
4649# other
4750from ya_glm .add_init_params import add_init_params
@@ -137,35 +140,65 @@ def get_fcp_penalty(penalty='lasso'):
137140 raise ValueError ("Bad input for penalty: {}" .format (penalty ))
138141
139142
143+ def get_solver (backend = 'fista' ):
144+ """
145+ Parameters
146+ ----------
147+ backend: str, dict
148+
149+ """
150+
151+ # get solver
152+ if type (backend ) == str :
153+ if backend == 'fista' :
154+ solve_glm = solve_glm_fista
155+ solve_glm_path = solve_glm_path_fista
156+
157+
158+ elif backend == 'andersoncd' :
159+ solve_glm = solve_glm_andersoncd
160+ solve_glm_path = solve_glm_path_andersoncd
161+
162+ elif backend == 'cvxpy' :
163+ solve_glm = solve_glm_cvxpy
164+ solve_glm_path = solve_glm_path_cvxpy
165+
166+ else :
167+ solve_glm = backend .get ('solve_glm' , None )
168+ solve_glm_path = backend .get ('solve_glm_path' , None )
169+
170+ solve_glm = staticmethod (solve_glm )
171+ if solve_glm_path is not None :
172+ solve_glm_path = staticmethod (solve_glm_path )
173+
174+ return solve_glm , solve_glm_path
175+
176+ # TODO: handle static method + None
177+
178+
140179def get_pen_glm (loss_func = 'linear_regression' ,
141180 penalty = 'lasso' ,
142181 backend = 'fista' ):
143-
182+
144183 if penalty in _MULTI_RESP_PENS :
145184 assert loss_func in _MULTI_RESP_LOSSES
146185
147186 MODEL_MIXIN = get_model_mixin (loss_func = loss_func )
148187 GLM , GLM_CV = get_penalty (penalty = penalty )
188+ solve_glm , solve_glm_path = get_solver (backend = backend )
149189
150- # get solver
151- if type (backend ) == str and backend == 'fista' :
152- solve_glm_impl = solve_glm_fista
153- solve_glm_path_impl = solve_glm_path_fista
154-
155- elif type (backend ) == str and backend == 'andersoncd' :
156- solve_glm_impl = solve_glm_andersoncd
157- solve_glm_path_impl = solve_glm_path_andersoncd
190+ # TODO-HACK: for reasons I do not understand I needed to
191+ # do this to get Estimator() to work below
192+ temp = {}
193+ temp ['solve_glm' ] = solve_glm
194+ temp ['solve_glm_path' ] = solve_glm_path
158195
159- else :
160- solve_glm_impl = backend .get ('solve_glm' , None )
161- solve_glm_path_impl = backend .get ('solve_glm_path' , None )
162-
163- ###################
164196 # setup estimator #
165197 ###################
166198
167199 class Estimator (MODEL_MIXIN , GLM ):
168- solve = staticmethod (solve_glm_impl )
200+ # solve_glm = solve_glm
201+ solve_glm = temp ['solve_glm' ]
169202
170203 @add_init_params (GLM , MODEL_MIXIN )
171204 def __init__ (self ): pass
@@ -183,7 +216,8 @@ def __init__(self): pass
183216 estimator = Estimator ()
184217
185218 class EstimatorCV (GLM_CV ):
186- solve_path = staticmethod (solve_glm_path_impl )
219+ # solve_glm_path = solve_glm_path
220+ solve_glm_path = temp ['solve_glm_path' ]
187221
188222 @add_init_params (GLM_CV )
189223 def __init__ (self , estimator = estimator ): pass
@@ -201,16 +235,9 @@ def get_fcp_glm(loss_func='linear_regression', penalty='lasso',
201235 # get base model class
202236 MODEL_MIXIN = get_model_mixin (loss_func = loss_func )
203237 GLM_FCP , GLM_FCP_CV = get_fcp_penalty (penalty = penalty )
238+ solve_glm = get_solver (backend = backend )[0 ]
204239
205- # get wl1 solver
206- if type (backend ) == str and backend == 'fista' :
207- WL1_impl = WL1SolverGlm_fista
208-
209- elif type (backend ) == str and backend == 'andersoncd' :
210- WL1_impl = WL1SolverGlm_andersoncd
211-
212- else :
213- WL1_impl = backend .get ('wl1' , None )
240+ temp = {'solve_glm' : solve_glm } # TODO-HACK: see above
214241
215242 # get default initializer
216243 Default , DefaultCV = get_pen_glm (loss_func = loss_func ,
@@ -223,7 +250,8 @@ def get_fcp_glm(loss_func='linear_regression', penalty='lasso',
223250
224251 class Estimator (MODEL_MIXIN , GLM_FCP ):
225252 solve_lla = staticmethod (solve_lla )
226- base_wl1_solver = WL1_impl
253+ # solve_glm = solve_glm
254+ solve_glm = temp ['solve_glm' ]
227255
228256 @add_init_params (GLM_FCP , MODEL_MIXIN )
229257 def __init__ (self ): pass
0 commit comments