|
3 | 3 | from time import time |
4 | 4 |
|
5 | 5 | from ya_glm.utils import clip_zero |
6 | | -from ya_glm.cvxpy.penalty import lasso_penalty, ridge_penalty, tikhonov_penalty |
| 6 | +from ya_glm.cvxpy.penalty import lasso_penalty, ridge_penalty,\ |
| 7 | + tikhonov_penalty, multi_task_lasso_penalty, group_lasso_penalty |
7 | 8 | from ya_glm.cvxpy.loss_functions import lin_reg_loss, log_reg_loss,\ |
8 | 9 | quantile_reg_loss |
9 | 10 | from ya_glm.backends.fista.glm_solver import process_param_path |
@@ -148,11 +149,6 @@ def setup_problem(X, y, |
148 | 149 | """ |
149 | 150 | glm_loss = get_glm_loss(loss_func) |
150 | 151 |
|
151 | | - # TODO: add these |
152 | | - if groups is not None: |
153 | | - raise NotImplementedError |
154 | | - if L1to2: |
155 | | - raise NotImplementedError |
156 | 152 | if nuc: |
157 | 153 | raise NotImplementedError |
158 | 154 |
|
@@ -185,13 +181,23 @@ def setup_problem(X, y, |
185 | 181 |
|
186 | 182 | # Add lasso |
187 | 183 | if lasso_pen is not None: |
188 | | - objective += lasso_pen * lasso_penalty(coef, weights=lasso_weights) |
| 184 | + |
| 185 | + if groups is not None: |
| 186 | + objective += lasso_pen * \ |
| 187 | + group_lasso_penalty(coef, groups=groups, weights=lasso_weights) |
| 188 | + |
| 189 | + elif L1to2: |
| 190 | + objective += lasso_pen * \ |
| 191 | + multi_task_lasso_penalty(coef, weights=lasso_weights) |
| 192 | + |
| 193 | + else: |
| 194 | + objective += lasso_pen * lasso_penalty(coef, weights=lasso_weights) |
189 | 195 |
|
190 | 196 | # Add ridge |
191 | 197 | if ridge_pen is not None: |
192 | 198 | if tikhonov: |
193 | 199 | objective += ridge_pen * \ |
194 | | - tikhonov_penalty(coef, tikT_tik=tikhonov.T @ tikhonov) |
| 200 | + tikhonov_penalty(coef, tikhonov=tikhonov) |
195 | 201 |
|
196 | 202 | else: |
197 | 203 | objective += ridge_pen * ridge_penalty(coef, weights=ridge_weights) |
|
0 commit comments