Skip to content

Commit c84f170

Browse files
committed
up
1 parent 6c8f65e commit c84f170

2 files changed

Lines changed: 7 additions & 3 deletions

File tree

ya_glm/backends/cvxpy/glm_solver.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from time import time
44

55
from ya_glm.utils import clip_zero
6+
from ya_glm.processing import process_weights_group_lasso
67
from ya_glm.cvxpy.penalty import lasso_penalty, ridge_penalty,\
78
tikhonov_penalty, multi_task_lasso_penalty, group_lasso_penalty
89
from ya_glm.cvxpy.loss_functions import lin_reg_loss, log_reg_loss,\
@@ -183,6 +184,9 @@ def setup_problem(X, y,
183184
if lasso_pen is not None:
184185

185186
if groups is not None:
187+
lasso_weights = process_weights_group_lasso(groups=groups,
188+
weights=lasso_weights)
189+
186190
objective += lasso_pen * \
187191
group_lasso_penalty(coef, groups=groups, weights=lasso_weights)
188192

@@ -195,7 +199,7 @@ def setup_problem(X, y,
195199

196200
# Add ridge
197201
if ridge_pen is not None:
198-
if tikhonov:
202+
if tikhonov is not None:
199203
objective += ridge_pen * \
200204
tikhonov_penalty(coef, tikhonov=tikhonov)
201205

ya_glm/cvxpy/penalty.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def tikhonov_penalty(coef, tikhonov):
2121

2222
def multi_task_lasso_penalty(coef, weights=None):
2323

24-
row_norms = cp.norm(coef, p='fro', axis=1)
24+
row_norms = cp.norm(coef, axis=1)
2525

2626
if weights:
2727
return weights.T @ row_norms
@@ -31,7 +31,7 @@ def multi_task_lasso_penalty(coef, weights=None):
3131

3232
def group_lasso_penalty(coef, groups, weights=None):
3333

34-
group_norms = [cp.norm(coef[grp_idxs], p='fro') for grp_idxs in groups]
34+
group_norms = [cp.norm(coef[grp_idxs]) for grp_idxs in groups]
3535

3636
if weights is None:
3737
return cp.sum(group_norms)

0 commit comments

Comments
 (0)