Skip to content

Commit 253a936

Browse files
committed
reorg glms with init
1 parent 7dc5857 commit 253a936

5 files changed

Lines changed: 273 additions & 141 deletions

File tree

ya_glm/base/Glm.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515

1616
_glm_base_params = dedent("""
1717
fit_intercept: bool
18-
Whether or not to fit an intercept.
18+
Whether or not to fit an intercept. The intercept will not be penalized.
1919
2020
standardize: bool
21-
Whether or not to perform internal standardization before fitting the data. Here standardization means mean centering and scaling each column by its standard deviation. Putting each column on the same scale makes sense for fitting penalized models. Note the fitted coefficient/intercept is transformed to be on the original scale of the input data.
21+
Whether or not to perform internal standardization before fitting the data. Standardization means mean centering and scaling each column by its standard deviation. For the group lasso penalty an additional scaling is applied that scales each variable by 1 / sqrt(group size). Putting each variable on the same scale makes sense for fitting penalized models. Note the fitted coefficient/intercept is transformed to be on the original scale of the input data.
2222
2323
opt_kws: dict
24-
Keyword arguments to the glm solver optimization algorithm.
24+
Additional keyword arguments for solve_glm.
2525
""")
2626

2727

@@ -102,7 +102,7 @@ def _validate_data(self, X, y, sample_weight=None, accept_sparse=True):
102102

103103
def preprocess(self, X, y, sample_weight=None, copy=True):
104104
"""
105-
Preprocesses the data for fitting. This method may transform the data e.g. centering and scaling X. If sample weights are provided then these are used for computing weighted means / standard deviations for standardization.
105+
Preprocesses the data for fitting. This method may transform the data e.g. centering and scaling X. If sample weights are provided then these are used for computing weighted means / standard deviations for standardization. For the group lasso penalty an additional scaling is applied that scales each variable by 1 / sqrt(group size).
106106
107107
Parameters
108108
----------
@@ -138,8 +138,8 @@ def preprocess(self, X, y, sample_weight=None, copy=True):
138138
groups=groups,
139139
sample_weight=sample_weight,
140140
copy=copy,
141-
check_input=False,
142-
accept_sparse=False, # TODO!
141+
check_input=True,
142+
accept_sparse=True,
143143
allow_const_cols=not self.fit_intercept)
144144

145145
y, y_out = self._process_y(y, sample_weight=sample_weight, copy=copy)

ya_glm/base/GlmCV.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ya_glm.cv.cv_select import CVSlectMixin # select_best_cv_tune_param
1515

1616

17+
# TODO: move estimator description to subclasses
1718
_cv_params = dedent(
1819
"""
1920
estimator: estimator object
@@ -251,7 +252,6 @@ def get_tuning_param_grid(self):
251252

252253

253254
_enet_cv_params = dedent("""
254-
255255
l1_ratio: float, str, list
256256
The l1_ratio value to use. If a float is provided then this parameter is fixed and not tuned over. If l1_ratio='tune' then the l1_ratio is tuned over using an automatically generated tuning parameter sequence. Alternatively, the user may provide a list of l1_ratio values to tune over.
257257

ya_glm/base/GlmWithInit.py

Lines changed: 2 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,9 @@
33

44
from ya_glm.utils import fit_if_unfitted
55
from ya_glm.utils import get_coef_and_intercept
6-
from ya_glm.processing import process_init_data
76

87

9-
class InitMixin:
10-
"""
11-
init
12-
13-
_get_defualt_init
14-
15-
_get_init_data_from_fit_est
16-
"""
8+
class GlmWithInitMixin:
179

1810
def get_init_data(self, X, y=None, **fit_params):
1911
"""
@@ -60,44 +52,6 @@ def get_init_data(self, X, y=None, **fit_params):
6052
**fit_params)
6153
return self._get_init_data_from_fit_est(est=init_est)
6254

63-
def _get_defualt_init(self):
64-
raise NotImplementedError
65-
66-
def _get_init_data_from_fit_est(self, est, X, y):
67-
raise NotImplementedError
68-
69-
70-
class GlmWithInitMixin(InitMixin):
71-
72-
def fit(self, X, y, sample_weight=None):
73-
74-
# validate the data!
75-
X, y, sample_weight = self._validate_data(X, y,
76-
sample_weight=sample_weight)
77-
78-
# get data for initialization
79-
init_data = self.get_init_data(X, y)
80-
if 'est' in init_data:
81-
self.init_est_ = init_data['est']
82-
del init_data['est']
83-
84-
# pre-process data
85-
X_pro, y_pro, pre_pro_out = self.preprocess(X, y,
86-
sample_weight=sample_weight,
87-
copy=True)
88-
89-
# possibly process the init data e.g. shift/scale
90-
init_data_pro = process_init_data(init_data=init_data,
91-
pre_pro_out=pre_pro_out)
92-
93-
# Fit!
94-
fit_out = self.compute_fit(X=X_pro, y=y_pro,
95-
init_data=init_data_pro,
96-
sample_weight=sample_weight)
97-
98-
self._set_fit(fit_out=fit_out, pre_pro_out=pre_pro_out)
99-
return self
100-
10155
def _get_init_data_from_fit_est(self, est):
10256
out = {}
10357
coef, intercept = get_coef_and_intercept(est, copy=True, error=True)
@@ -112,23 +66,5 @@ def _get_init_data_from_fit_est(self, est):
11266

11367
return out
11468

115-
def get_pen_val_max(self, X, y, init_data=None, sample_weight=None):
116-
if init_data is None:
117-
init_data = self.get_init_data(X, y, sample_weight=sample_weight)
118-
119-
X_pro, y_pro, pre_pro_out = self.preprocess(X, y,
120-
sample_weight=sample_weight,
121-
copy=True)
122-
123-
init_data_pro = process_init_data(init_data=init_data,
124-
pre_pro_out=pre_pro_out)
125-
126-
return self._get_pen_val_max_from_pro(X=X_pro, y=y_pro,
127-
init_data=init_data_pro,
128-
sample_weight=sample_weight)
129-
130-
def _get_pen_val_max_from_pro(self, X, y, init_data, sample_weight=None):
131-
raise NotImplementedError
132-
133-
def compute_fit(self, X, y, init_data, sample_weight=None):
69+
def _get_defualt_init(self):
13470
raise NotImplementedError

0 commit comments

Comments (0)