Skip to content

Commit e719d2a

Browse files
committed
sample weights
1 parent 49f7579 commit e719d2a

17 files changed

Lines changed: 311 additions & 137 deletions

ya_glm/base/Glm.py

Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from sklearn.base import BaseEstimator
22
from sklearn.utils.validation import check_is_fitted
33
from sklearn.utils.extmath import safe_sparse_dot
4-
from sklearn.utils.validation import check_array, FLOAT_DTYPES
4+
from sklearn.utils.validation import check_array, _check_sample_weight, \
5+
FLOAT_DTYPES
56
from scipy.linalg import svd
67

78
import numpy as np
89
from textwrap import dedent
910

1011
from ya_glm.autoassign import autoassign
1112
from ya_glm.processing import process_X, deprocess_fit
12-
from ya_glm.opt.GroupLasso import euclid_norm
13+
from ya_glm.opt.utils import euclid_norm
1314

1415

1516
_glm_base_params = dedent("""
@@ -33,7 +34,7 @@ class Glm(BaseEstimator):
3334
def __init__(self, fit_intercept=True, standardize=False, opt_kws={}):
3435
pass
3536

36-
def fit(self, X, y):
37+
def fit(self, X, y, sample_weight=None):
3738
"""
3839
Fits the GLM.
3940
@@ -44,22 +45,32 @@ def fit(self, X, y):
4445
4546
y: array-like, shape (n_samples, )
4647
The training response data.
48+
49+
sample_weight: None or array-like, shape (n_samples,)
50+
Individual weights for each sample.
4751
"""
4852

49-
X, y = self._validate_data(X, y)
53+
X, y, sample_weight = self._validate_data(X, y,
54+
sample_weight=sample_weight)
5055

5156
# TODO: do we want to give the user the option to not copy?
52-
X, y, pre_pro_out = self.preprocess(X, y, copy=True)
57+
X, y, pre_pro_out = self.preprocess(X=X, y=y,
58+
sample_weight=sample_weight,
59+
copy=True)
60+
61+
kws = self._get_solve_kws()
62+
if sample_weight is not None:
63+
kws['sample_weight'] = sample_weight
5364

5465
coef, intercept, out_data = self.solve_glm(X=X, y=y,
55-
**self._get_solve_kws())
66+
**kws)
5667

5768
self._set_fit(fit_out={'coef': coef, 'intercept': intercept,
5869
'opt_data': out_data},
5970
pre_pro_out=pre_pro_out)
6071
return self
6172

62-
def _validate_data(self, X, y, accept_sparse=False):
73+
def _validate_data(self, X, y, sample_weight=None, accept_sparse=True):
6374
"""
6475
Validates the X/y data. This should not change the raw input data, but may reformat the data (e.g. convert pandas to numpy).
6576
@@ -76,18 +87,22 @@ def _validate_data(self, X, y, accept_sparse=False):
7687
X = check_array(X, accept_sparse=accept_sparse,
7788
dtype=FLOAT_DTYPES)
7889

90+
if sample_weight is not None:
91+
sample_weight = _check_sample_weight(sample_weight, X,
92+
dtype=X.dtype)
93+
7994
# make sure y is numpy and of same dtype as X
8095
y = np.asarray(y, dtype=X.dtype)
8196

8297
# make sure X, y have same number of samples
8398
if y.shape[0] != X.shape[0]:
8499
raise ValueError("X and y must have the same number of rows!")
85100

86-
return X, y
101+
return X, y, sample_weight
87102

88-
def preprocess(self, X, y, copy=True):
103+
def preprocess(self, X, y, sample_weight=None, copy=True):
89104
"""
90-
Preprocesses the data for fitting. This method may transform the data e.g. centering and scaling X.
105+
Preprocesses the data for fitting. This method may transform the data e.g. centering and scaling X. If sample weights are provided then these are used for computing weighted means / standard deviations for standardization.
91106
92107
Parameters
93108
----------
@@ -97,6 +112,9 @@ def preprocess(self, X, y, copy=True):
97112
y: array-like, shape (n_samples, ) or (n_samples, n_responses)
98113
The response data.
99114
115+
sample_weight: None or array-like, shape (n_samples,)
116+
Individual weights for each sample.
117+
100118
copy: bool
101119
Whether or not to copy the X/y arrays or modify them in place.
102120
@@ -113,38 +131,22 @@ def preprocess(self, X, y, copy=True):
113131
pre_pro_out: dict
114132
Data from preprocessing e.g. X_center, X_scale.
115133
"""
134+
groups = self.groups if hasattr(self, 'groups') else None
116135

117136
X, out = process_X(X,
118137
standardize=self.standardize,
119-
groups=self._get_groups(),
138+
groups=groups,
139+
sample_weight=sample_weight,
120140
copy=copy,
121141
check_input=False,
122142
accept_sparse=False, # TODO!
123143
allow_const_cols=not self.fit_intercept)
124144

125-
y, y_out = self._process_y(y, copy=copy)
145+
y, y_out = self._process_y(y, sample_weight=sample_weight, copy=copy)
126146
out.update(y_out)
127147

128148
return X, y, out
129149

130-
# TODO: do we want this?
131-
# def _maybe_get(self, param):
132-
# """
133-
# Safely gets an attribute that may not exist (e.g. like self.param). Returns None if the object does not have the attribute.
134-
# """
135-
# if hasattr(self, param):
136-
# return self.__dict__[param]
137-
# else:
138-
# return None
139-
def _get_groups(self):
140-
"""
141-
Safely gets an attribute that may not exist (e.g. like self.param). Returns None if the object does not have the attribute.
142-
"""
143-
if hasattr(self, 'groups'):
144-
return self.groups
145-
else:
146-
return None
147-
148150
def _set_fit(self, fit_out, pre_pro_out):
149151
"""
150152
Sets the fit.
@@ -221,7 +223,7 @@ def decision_function(self, X):
221223
def _more_tags(self):
222224
return {'requires_y': True}
223225

224-
def get_pen_val_max(self, X, y):
226+
def get_pen_val_max(self, X, y, sample_weight=None):
225227
"""
226228
Returns the largest reasonable penalty parameter for the processed data.
227229
@@ -233,13 +235,20 @@ def get_pen_val_max(self, X, y):
233235
y: array-like, shape (n_samples, )
234236
The training response data.
235237
238+
sample_weight: None or array-like, shape (n_samples,)
239+
Individual weights for each sample.
240+
236241
Output
237242
------
238243
pen_val_max: float
239244
Largest reasonable tuning parameter value.
240245
"""
241-
X_pro, y_pro, _ = self.preprocess(X, y, copy=True)
242-
return self._get_pen_val_max_from_pro(X_pro, y_pro)
246+
X_pro, y_pro, _ = self.preprocess(X, y,
247+
sample_weight=sample_weight,
248+
copy=True)
249+
250+
return self._get_pen_val_max_from_pro(X_pro, y_pro,
251+
sample_weight=sample_weight)
243252

244253
def _get_penalty_kind(self):
245254
"""
@@ -295,13 +304,19 @@ def transform(x):
295304

296305
return transform
297306

298-
def _process_y(self, y, copy=True):
307+
def _process_y(self, y, sample_weight=None, copy=True):
299308
"""
300309
Parameters
301310
----------
302311
y: array-like, shape (n_samples, ) or (n_samples, n_responses)
303312
The response data.
304313
314+
sample_weight: None or array-like, shape (n_samples,)
315+
Individual weights for each sample.
316+
317+
copy: bool
318+
Whether or not to copy the X/y arrays or modify them in place.
319+
305320
Output
306321
------
307322
y: array-like
@@ -315,7 +330,7 @@ def _get_solve_kws(self):
315330
"""
316331
raise NotImplementedError
317332

318-
def _get_pen_val_max_from_pro(self, X, y):
333+
def _get_pen_val_max_from_pro(self, X, y, sample_weight=None):
319334
"""
320335
Computes the largest reasonable tuning parameter value.
321336
"""

ya_glm/base/GlmCV.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(self,
5555
cv_pre_dispatch='2*n_jobs'):
5656
pass
5757

58-
def fit(self, X, y):
58+
def fit(self, X, y, sample_weight=None):
5959
"""
6060
Runs cross-validation then refits the GLM with the selected tuning parameter.
6161
@@ -66,19 +66,32 @@ def fit(self, X, y):
6666
6767
y: array-like, shape (n_samples, )
6868
The training response data.
69+
70+
sample_weight: None or array-like, shape (n_samples,)
71+
Individual weights for each sample.
6972
"""
7073

7174
# check the input data
7275
self._check_base_estimator(self.estimator)
7376
est = clone(self.estimator)
74-
X, y = est._validate_data(X, y)
77+
X, y, sample_weight = est._validate_data(X, y,
78+
sample_weight=sample_weight)
7579

7680
# set up the tuning parameter values using the processed data
77-
self._set_tuning_values(X=X, y=y)
81+
self._set_tuning_values(X=X, y=y, sample_weight=sample_weight)
82+
83+
# maybe add sample weight to fit params
84+
if sample_weight is not None:
85+
fit_params = {'sample_weight': sample_weight}
86+
else:
87+
fit_params = None
7888

7989
# run cross-validation on the raw data
8090
start_time = time()
81-
self.cv_results_ = self._run_cv(estimator=est, X=X, y=y, cv=self.cv)
91+
self.cv_results_ = \
92+
self._run_cv(estimator=est, X=X, y=y, cv=self.cv,
93+
fit_params=fit_params)
94+
8295
self.cv_data_ = {'cv_runtime': time() - start_time}
8396

8497
# select best tuning parameter values
@@ -90,7 +103,7 @@ def fit(self, X, y):
90103

91104
# refit on the raw data
92105
start_time = time()
93-
self.best_estimator_ = est.fit(X, y)
106+
self.best_estimator_ = est.fit(X, y, sample_weight=sample_weight)
94107
self.cv_data_['refit_runtime'] = time() - start_time
95108

96109
return self
@@ -127,7 +140,7 @@ def check_base_estimator(self, estimator):
127140
"""
128141
raise NotImplementedError
129142

130-
def _set_tuning_values(self, X, y):
143+
def _set_tuning_values(self, X, y, **kws):
131144
"""
132145
Sets the tuning parameter sequence from the transformed data.
133146
@@ -138,6 +151,9 @@ def _set_tuning_values(self, X, y):
138151
139152
y: array-like, shape (n_samples, )
140153
The processed training response data.
154+
155+
**kws:
156+
Additional keyword arguments.
141157
"""
142158
# subclass should overwrite
143159
raise NotImplementedError
@@ -182,11 +198,13 @@ def __init__(self,
182198
):
183199
pass
184200

185-
def _set_tuning_values(self, X, y):
201+
def _set_tuning_values(self, X, y, sample_weight=None):
186202
if self.pen_vals is None:
187-
pen_val_max = self.estimator.get_pen_val_max(X, y)
203+
pen_val_max = self.estimator.\
204+
get_pen_val_max(X=X, y=y, sample_weight=sample_weight)
188205
else:
189206
pen_val_max = None
207+
190208
self._set_tune_from_pen_max(pen_val_max=pen_val_max)
191209

192210
def _set_tune_from_pen_max(self, pen_val_max=None):
@@ -283,9 +301,10 @@ def _tune_pen_val(self):
283301
else:
284302
return True
285303

286-
def _set_tuning_values(self, X, y):
304+
def _set_tuning_values(self, X, y, sample_weight=None):
287305
if self.pen_vals is None:
288-
enet_pen_max = self.estimator.get_pen_val_max(X, y)
306+
enet_pen_max = self.estimator.\
307+
get_pen_val_max(X, y, sample_weight=sample_weight)
289308
lasso_pen_max = enet_pen_max * self.estimator.l1_ratio
290309
else:
291310
lasso_pen_max = None

0 commit comments

Comments
 (0)