11from sklearn .base import BaseEstimator
22from sklearn .utils .validation import check_is_fitted
33from sklearn .utils .extmath import safe_sparse_dot
4- from sklearn .utils .validation import check_array , FLOAT_DTYPES
4+ from sklearn .utils .validation import check_array , _check_sample_weight , \
5+ FLOAT_DTYPES
56from scipy .linalg import svd
67
78import numpy as np
89from textwrap import dedent
910
1011from ya_glm .autoassign import autoassign
1112from ya_glm .processing import process_X , deprocess_fit
12- from ya_glm .opt .GroupLasso import euclid_norm
13+ from ya_glm .opt .utils import euclid_norm
1314
1415
1516_glm_base_params = dedent ("""
@@ -33,7 +34,7 @@ class Glm(BaseEstimator):
3334 def __init__ (self , fit_intercept = True , standardize = False , opt_kws = {}):
3435 pass
3536
def fit(self, X, y, sample_weight=None):
    """
    Fits the GLM.

    The data are validated, preprocessed (on a copy), handed to the
    solver, and the solution is stored on the estimator.

    Parameters
    ----------
    X: array-like with n_samples rows
        The training covariate data.

    y: array-like, shape (n_samples, )
        The training response data.

    sample_weight: None or array-like, shape (n_samples,)
        Individual weights for each sample.

    Output
    ------
    self
        The estimator itself, after the fit has been set.
    """
    X, y, sample_weight = self._validate_data(X, y,
                                              sample_weight=sample_weight)

    # TODO: do we want to give the user the option to not copy?
    X, y, pre_pro_out = self.preprocess(X=X, y=y,
                                        sample_weight=sample_weight,
                                        copy=True)

    # Work on a copy so that adding sample_weight below can never mutate
    # a dict that _get_solve_kws() may keep a reference to.
    kws = dict(self._get_solve_kws())

    # Only forward sample_weight when it was actually supplied, so the
    # solver is called exactly as before in the unweighted case.
    if sample_weight is not None:
        kws['sample_weight'] = sample_weight

    coef, intercept, out_data = self.solve_glm(X=X, y=y, **kws)

    self._set_fit(fit_out={'coef': coef,
                           'intercept': intercept,
                           'opt_data': out_data},
                  pre_pro_out=pre_pro_out)
    return self
6172
62- def _validate_data (self , X , y , accept_sparse = False ):
73+ def _validate_data (self , X , y , sample_weight = None , accept_sparse = True ):
6374 """
6475 Validates the X/y data. This should not change the raw input data, but may reformat the data (e.g. convert pandas to numpy).
6576
@@ -76,18 +87,22 @@ def _validate_data(self, X, y, accept_sparse=False):
7687 X = check_array (X , accept_sparse = accept_sparse ,
7788 dtype = FLOAT_DTYPES )
7889
90+ if sample_weight is not None :
91+ sample_weight = _check_sample_weight (sample_weight , X ,
92+ dtype = X .dtype )
93+
7994 # make sure y is numpy and of same dtype as X
8095 y = np .asarray (y , dtype = X .dtype )
8196
8297 # make sure X, y have same number of samples
8398 if y .shape [0 ] != X .shape [0 ]:
8499 raise ValueError ("X and y must have the same number of rows!" )
85100
86- return X , y
101+ return X , y , sample_weight
87102
def preprocess(self, X, y, sample_weight=None, copy=True):
    """
    Preprocesses the data for fitting. This method may transform the data e.g. centering and scaling X. If sample weights are provided then these are used for computing weighted means / standard deviations for standardization.

    Parameters
    ----------
    X: array-like with n_samples rows
        The covariate data.

    y: array-like, shape (n_samples, ) or (n_samples, n_responses)
        The response data.

    sample_weight: None or array-like, shape (n_samples,)
        Individual weights for each sample.

    copy: bool
        Whether or not to copy the X/y arrays or modify them in place.

    Output
    ------
    X, y, pre_pro_out

    X:
        The covariate data as returned by process_X.

    y:
        The response data as returned by self._process_y.

    pre_pro_out: dict
        Data from preprocessing e.g. X_center, X_scale. This is the
        dict returned by process_X updated with the one returned by
        self._process_y.
    """
    # Not every estimator defines a `groups` attribute; fall back to None.
    groups = getattr(self, 'groups', None)

    X, out = process_X(X,
                       standardize=self.standardize,
                       groups=groups,
                       sample_weight=sample_weight,
                       copy=copy,
                       check_input=False,
                       accept_sparse=False,  # TODO!
                       allow_const_cols=not self.fit_intercept)

    y, y_out = self._process_y(y, sample_weight=sample_weight, copy=copy)
    out.update(y_out)

    return X, y, out
129149
130- # TODO: do we want this?
131- # def _maybe_get(self, param):
132- # """
133- # Safely gets an attribute that may not exist (e.g. like self.param). Returns None if the object does not have the attribute.
134- # """
135- # if hasattr(self, param):
136- # return self.__dict__[param]
137- # else:
138- # return None
139- def _get_groups (self ):
140- """
141- Safely gets an attribute that may not exist (e.g. like self.param). Returns None if the object does not have the attribute.
142- """
143- if hasattr (self , 'groups' ):
144- return self .groups
145- else :
146- return None
147-
148150 def _set_fit (self , fit_out , pre_pro_out ):
149151 """
150152 Sets the fit.
@@ -221,7 +223,7 @@ def decision_function(self, X):
221223 def _more_tags (self ):
222224 return {'requires_y' : True }
223225
def get_pen_val_max(self, X, y, sample_weight=None):
    """
    Returns the largest reasonable penalty parameter for the processed data.

    The data are first run through self.preprocess (on a copy) and the
    processed data are then handed to self._get_pen_val_max_from_pro.

    Parameters
    ----------
    X: array-like with n_samples rows
        The training covariate data.

    y: array-like, shape (n_samples, )
        The training response data.

    sample_weight: None or array-like, shape (n_samples,)
        Individual weights for each sample.

    Output
    ------
    pen_val_max: float
        Largest reasonable tuning parameter value.
    """
    X_pro, y_pro, _ = self.preprocess(X, y,
                                      sample_weight=sample_weight,
                                      copy=True)

    # The same weights are used for preprocessing and for the penalty bound.
    return self._get_pen_val_max_from_pro(X_pro, y_pro,
                                          sample_weight=sample_weight)
243252
244253 def _get_penalty_kind (self ):
245254 """
@@ -295,13 +304,19 @@ def transform(x):
295304
296305 return transform
297306
298- def _process_y (self , y , copy = True ):
307+ def _process_y (self , y , sample_weight = None , copy = True ):
299308 """
300309 Parameters
301310 ---------
302311 y: array-like, shape (n_samples, ) or (n_samples, n_responses)
303312 The response data.
304313
314+ sample_weight: None or array-like, shape (n_samples,)
315+ Individual weights for each sample
316+
317+ copy: bool
318+ Whether or not to copy the X/y arrays or modify them in place.
319+
305320 Output
306321 ------
307322 y: array-like
@@ -315,7 +330,7 @@ def _get_solve_kws(self):
315330 """
316331 raise NotImplementedError
317332
318- def _get_pen_val_max_from_pro (self , X , y ):
333+ def _get_pen_val_max_from_pro (self , X , y , sample_weight = None ):
319334 """
320335 Computes the largest reasonable tuning parameter value.
321336 """
0 commit comments