Skip to content

Commit bdd3cb7

Browse files
committed
Merge pull request scikit-learn#4763 from sonnyhu/RidgeCV_slice_sampleweight
[MRG + 1] Fix scikit-learn#4755 (RidgeCV ignores sample_weights if cv != None)
2 parents a7de146 + 92f9c9e commit bdd3cb7

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

sklearn/linear_model/ridge.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -871,10 +871,7 @@ def fit(self, X, y, sample_weight=None):
871871
raise ValueError("cv!=None and store_cv_values=True "
872872
" are incompatible")
873873
parameters = {'alpha': self.alphas}
874-
# FIXME: sample_weight must be split into training/validation data
875-
# too!
876-
#fit_params = {'sample_weight' : sample_weight}
877-
fit_params = {}
874+
fit_params = {'sample_weight' : sample_weight}
878875
gs = GridSearchCV(Ridge(fit_intercept=self.fit_intercept),
879876
parameters, fit_params=fit_params, cv=self.cv)
880877
gs.fit(X, y)

sklearn/linear_model/tests/test_ridge.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from sklearn.linear_model.ridge import _solve_cholesky
2828
from sklearn.linear_model.ridge import _solve_cholesky_kernel
2929

30+
from sklearn.grid_search import GridSearchCV
31+
3032
from sklearn.cross_validation import KFold
3133

3234

@@ -527,6 +529,32 @@ def test_ridgecv_store_cv_values():
527529
assert_equal(r.cv_values_.shape, (n_samples, n_responses, n_alphas))
528530

529531

532+
def test_ridgecv_sample_weight():
533+
rng = np.random.RandomState(0)
534+
alphas = (0.1, 1.0, 10.0)
535+
536+
# There are different algorithms for n_samples > n_features
537+
# and the opposite, so test them both.
538+
for n_samples, n_features in ((6, 5), (5, 10)):
539+
y = rng.randn(n_samples)
540+
X = rng.randn(n_samples, n_features)
541+
sample_weight = 1 + rng.rand(n_samples)
542+
543+
cv = KFold(n_samples, 5)
544+
ridgecv = RidgeCV(alphas=alphas, cv=cv)
545+
ridgecv.fit(X, y, sample_weight=sample_weight)
546+
547+
# Check using GridSearchCV directly
548+
parameters = {'alpha': alphas}
549+
fit_params = {'sample_weight': sample_weight}
550+
gs = GridSearchCV(Ridge(), parameters, fit_params=fit_params,
551+
cv=cv)
552+
gs.fit(X, y)
553+
554+
assert_equal(ridgecv.alpha_, gs.best_estimator_.alpha)
555+
assert_array_almost_equal(ridgecv.coef_, gs.best_estimator_.coef_)
556+
557+
530558
def test_raises_value_error_if_sample_weights_greater_than_1d():
531559
# Sample weights must be either scalar or 1D
532560

0 commit comments

Comments
 (0)