Skip to content

Commit 873561c

Browse files
committed
began adding docs
1 parent 253a936 commit 873561c

4 files changed

Lines changed: 119 additions & 15 deletions

File tree

todos.txt

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,21 @@
11
Need to do
22
----------
33

4-
- multinomial
5-
- huber
6-
- Poisson
7-
- Gamma
8-
- max pen val for Tikhonov
9-
- total variation CV and max pen val
10-
- extend LLA algorithm to group lasso and nuclear norm
11-
124
- testing testing testing
135
- documentation documentation documentation
146
- speed comparison of opt module
157
- compare fits to sklearn baseline
168

17-
- 1se rule for ENet
9+
- Gamma
10+
- max pen val for Tikhonov
11+
- total variation CV and max pen val
1812

13+
- 1se rule for ENet
1914

2015
Eventually
2116
---------
22-
- sample weights
2317
- cox
2418
- cv over other parameters with path algorithms
25-
- quantile regression (this will require a different default solver -- perhaps cvxpy?)
2619
- constraints: positive, simplex
2720
- build coordinate descent framework (e.g. based on https://arxiv.org/abs/1410.1386)
2821
- for cv_scorer figure out how to have custom fit_metrics (e.g. for n_nonzero) instead of the ugly train/test

ya_glm/pen_glms/GlmLasso.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from textwrap import dedent
2+
13
from ya_glm.base.Glm import Glm
24
from ya_glm.base.GlmCV import GlmCVSinglePen, GlmCVENet
35
from ya_glm.cv.CVPath import CVPathMixin
@@ -11,12 +13,41 @@
1113
from ya_glm.processing import check_estimator_type
1214

1315

16+
_glm_lasso_params = dedent("""
17+
pen_val: float
18+
The penalty value.
19+
20+
lasso_weights: None, array-like
21+
Optional weights to put on each term in the penalty.
22+
23+
groups: None, list of ints
24+
Optional groups of variables. If groups is provided then each element in the list should be a list of feature indices. Variables not in a group are not penalized.
25+
26+
ridge_pen_val: None, float
27+
Penalty strength for an optional ridge penalty.
28+
29+
ridge_weights: None, array-like shape (n_features, )
30+
Optional feature weights for the ridge penalty.
31+
32+
tikhonov: None, array-like (K, n_features)
33+
Optional Tikhonov matrix for the ridge penalty. tikhonov and ridge_weights cannot both be provided at the same time.
34+
""")
35+
36+
1437
class GlmLasso(Glm):
1538

39+
descr = dedent("""
40+
Lasso or group lasso penalty with an optional ridge penalty.
41+
""")
42+
43+
descr_mr = dedent("""
44+
Lasso, group lasso, multi-task lasso or nuclear norm penalty with an optional ridge penalty.
45+
""")
46+
1647
@add_from_classes(Glm)
17-
def __init__(self, pen_val=1, lasso_weights=None,
18-
ridge_pen_val=None, ridge_weights=None, tikhonov=None,
19-
groups=None): pass
48+
def __init__(self, pen_val=1, lasso_weights=None, groups=None,
49+
ridge_pen_val=None, ridge_weights=None, tikhonov=None
50+
): pass
2051

2152
def _get_solve_kws(self):
2253
"""
@@ -89,6 +120,10 @@ def _get_pen_val_max_from_pro(self, X, y, sample_weight=None):
89120

90121
class GlmLassoCVPath(CVPathMixin, GlmCVSinglePen):
91122

123+
descr = dedent("""
124+
Tunes the lasso penalty parameter via cross-validation using a path algorithm.
125+
""")
126+
92127
def _get_solve_path_kws(self):
93128
if not hasattr(self, 'pen_val_seq_'):
94129
raise RuntimeError("pen_val_seq_ has not yet been set")
@@ -103,12 +138,53 @@ def _check_base_estimator(self, estimator):
103138

104139

105140
class GlmLassoCVGridSearch(CVGridSearchMixin, GlmCVSinglePen):
141+
descr = dedent("""
142+
Tunes the lasso penalty parameter via cross-validation.
143+
""")
144+
106145
def _check_base_estimator(self, estimator):
107146
check_estimator_type(estimator, GlmLasso)
108147

109148

149+
_glm_lasso_params = dedent("""
150+
pen_val: float
151+
The penalty strength (corresponds to lambda in glmnet)
152+
153+
l1_ratio: float
154+
The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For
155+
``l1_ratio = 0`` the penalty is an L2 penalty. ``For l1_ratio = 1`` it
156+
is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a
157+
combination of L1 and L2.
158+
159+
lasso_weights: None, array-like
160+
Optional weights to put on each term in the penalty.
161+
162+
groups: None, list of ints
163+
Optional groups of variables. If groups is provided then each element in the list should be a list of feature indices. Variables not in a group are not penalized.
164+
165+
tikhonov: None, array-like (K, n_features)
166+
Optional tikhonov matrix for the ridge penalty.
167+
""")
168+
169+
110170
class GlmENet(Glm):
111171

172+
descr = dedent("""
173+
Elastic net penalty
174+
175+
pen_val * (l1_ratio) Lasso(coef) + pen_val * (1 - l1_ratio) * Ridge(coef)
176+
177+
where Lasso(coef) is either the Lasso or group Lasso penalty.
178+
""")
179+
180+
descr_mr = dedent("""
181+
Elastic net penalty
182+
183+
pen_val * (l1_ratio) Lasso(coef) + pen_val * (1 - l1_ratio) * Ridge(coef)
184+
185+
where Lasso(coef) is either the Lasso, group Lasso, multi-task Lasso or nuclear norm.
186+
""")
187+
112188
@add_from_classes(Glm)
113189
def __init__(self, pen_val=1, l1_ratio=0.5,
114190
lasso_weights=None, ridge_weights=None, tikhonov=None,
@@ -192,6 +268,10 @@ def _get_pen_val_max_from_pro(self, X, y, sample_weight=None):
192268
class GlmENetCVPath(ENetCVPathMixin, GlmCVENet):
193269
solve_glm_path = None
194270

271+
descr = dedent("""
272+
Tunes the ElasticNet penalty parameter and/or the l1_ratio via cross-validation. Makes use of a path algorithm for computing the penalty value tuning path.
273+
""")
274+
195275
def _get_solve_path_enet_base_kws(self):
196276
kws = self.estimator._get_solve_kws()
197277
del kws['lasso_pen']
@@ -203,5 +283,10 @@ def _check_base_estimator(self, estimator):
203283

204284

205285
class GlmENetCVGridSearch(CVGridSearchMixin, GlmCVSinglePen):
286+
287+
descr = dedent("""
288+
Tunes the ElasticNet penalty parameter and/or the l1_ratio via cross-validation.
289+
""")
290+
206291
def _check_base_estimator(self, estimator):
207292
check_estimator_type(estimator, GlmENet)

ya_glm/pen_glms/GlmRidge.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from textwrap import dedent
2+
13
from ya_glm.base.Glm import Glm
24
from ya_glm.base.GlmCV import GlmCVSinglePen
35
from ya_glm.cv.CVPath import CVPathMixin
@@ -9,8 +11,24 @@
911
from ya_glm.processing import check_estimator_type
1012

1113

14+
_glm_ridge_params = dedent("""
15+
pen_val: float
16+
The penalty value.
17+
18+
weights: None, array-like shape (n_features, )
19+
Optional feature weights for the ridge penalty.
20+
21+
tikhonov: None, array-like (K, n_features)
22+
Optional tikhonov matrix for the ridge penalty. Both tikhonov and weights cannot be provided at the same time.
23+
""")
24+
25+
1226
class GlmRidge(Glm):
1327

28+
descr = dedent("""
29+
Ridge penalty.
30+
""")
31+
1432
@add_from_classes(Glm)
1533
def __init__(self, pen_val=1, weights=None, tikhonov=None): pass
1634

@@ -59,6 +77,10 @@ def _get_pen_val_max_from_pro(self, X, y, sample_weight=None):
5977

6078
class GlmRidgeCVPath(CVPathMixin, GlmCVSinglePen):
6179

80+
descr = dedent("""
81+
Tunes the ridge penalty parameter via cross-validation using a path algorithm.
82+
""")
83+
6284
def _get_solve_path_kws(self):
6385
if not hasattr(self, 'pen_val_seq_'):
6486
raise RuntimeError("pen_val_seq_ has not yet been set")
@@ -74,5 +96,9 @@ def _check_base_estimator(self, estimator):
7496

7597
class GlmRidgeCVGridSearch(CVGridSearchMixin, GlmCVSinglePen):
7698

99+
descr = dedent("""
100+
Tunes the ridge penalty parameter via cross-validation.
101+
""")
102+
77103
def _check_base_estimator(self, estimator):
78104
check_estimator_type(estimator, GlmRidge)

ya_glm/pen_glms/GlmVanilla.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ def _get_solve_kws(self):
1010
'loss_kws': loss_kws,
1111

1212
'fit_intercept': self.fit_intercept,
13-
**self.opt_kws,
13+
**self.opt_kws
1414
}

0 commit comments

Comments
 (0)