Skip to content

Commit 6c8f65e

Browse files
committed
added more penalties
1 parent af969ae commit 6c8f65e

2 files changed

Lines changed: 36 additions & 17 deletions

File tree

ya_glm/backends/cvxpy/glm_solver.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from time import time
44

55
from ya_glm.utils import clip_zero
6-
from ya_glm.cvxpy.penalty import lasso_penalty, ridge_penalty, tikhonov_penalty
6+
from ya_glm.cvxpy.penalty import lasso_penalty, ridge_penalty,\
7+
tikhonov_penalty, multi_task_lasso_penalty, group_lasso_penalty
78
from ya_glm.cvxpy.loss_functions import lin_reg_loss, log_reg_loss,\
89
quantile_reg_loss
910
from ya_glm.backends.fista.glm_solver import process_param_path
@@ -148,11 +149,6 @@ def setup_problem(X, y,
148149
"""
149150
glm_loss = get_glm_loss(loss_func)
150151

151-
# TODO: add these
152-
if groups is not None:
153-
raise NotImplementedError
154-
if L1to2:
155-
raise NotImplementedError
156152
if nuc:
157153
raise NotImplementedError
158154

@@ -185,13 +181,23 @@ def setup_problem(X, y,
185181

186182
# Add lasso
187183
if lasso_pen is not None:
188-
objective += lasso_pen * lasso_penalty(coef, weights=lasso_weights)
184+
185+
if groups is not None:
186+
objective += lasso_pen * \
187+
group_lasso_penalty(coef, groups=groups, weights=lasso_weights)
188+
189+
elif L1to2:
190+
objective += lasso_pen * \
191+
multi_task_lasso_penalty(coef, weights=lasso_weights)
192+
193+
else:
194+
objective += lasso_pen * lasso_penalty(coef, weights=lasso_weights)
189195

190196
# Add ridge
191197
if ridge_pen is not None:
192198
if tikhonov:
193199
objective += ridge_pen * \
194-
tikhonov_penalty(coef, tikT_tik=tikhonov.T @ tikhonov)
200+
tikhonov_penalty(coef, tikhonov=tikhonov)
195201

196202
else:
197203
objective += ridge_pen * ridge_penalty(coef, weights=ridge_weights)

ya_glm/cvxpy/penalty.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,25 @@ def ridge_penalty(coef, weights=None):
1515
return 0.5 * cp.sum_squares(coef)
1616

1717

18-
def tikhonov_penalty(coef, tikT_tik):
19-
"""
20-
Parameters
21-
----------
22-
tik_tik: array-like, shape (n_features, n_features)
23-
The squared tikhonov matrix tikT_tik = tikhonov.T @ tikhonov
24-
25-
"""
26-
return 0.5 * cp.quad_form(coef, tikT_tik)
18+
def tikhonov_penalty(coef, tikhonov):
19+
return 0.5 * cp.quad_form(coef, tikhonov.T @ tikhonov)
20+
21+
22+
def multi_task_lasso_penalty(coef, weights=None):
23+
24+
row_norms = cp.norm(coef, p='fro', axis=1)
25+
26+
if weights:
27+
return weights.T @ row_norms
28+
else:
29+
return cp.sum(row_norms)
30+
31+
32+
def group_lasso_penalty(coef, groups, weights=None):
33+
34+
group_norms = [cp.norm(coef[grp_idxs], p='fro') for grp_idxs in groups]
35+
36+
if weights is None:
37+
return cp.sum(group_norms)
38+
else:
39+
return weights.T @ cp.hstack(group_norms)

0 commit comments

Comments
 (0)