Skip to content

Commit 32b2f8e

Browse files
committed
Merge pull request scikit-learn#4838 from trevorstephens/ridge_sw
[MRG+1] Add sample_weight support to RidgeClassifier
2 parents 8b44fa1 + 2f69574 commit 32b2f8e

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

doc/whats_new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ Enhancements
7171
It is now possible to ignore one or more labels, such as where
7272
a multiclass problem has a majority class to ignore. By `Joel Nothman`_.
7373

74+
- Add ``sample_weight`` support to :class:`linear_model.RidgeClassifier`.
75+
By `Trevor Stephens`_.
7476

7577
Bug fixes
7678
.........

sklearn/linear_model/ridge.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
572572
copy_X=copy_X, max_iter=max_iter, tol=tol, solver=solver)
573573
self.class_weight = class_weight
574574

575-
def fit(self, X, y):
575+
def fit(self, X, y, sample_weight=None):
576576
"""Fit Ridge regression model.
577577
578578
Parameters
@@ -583,20 +583,24 @@ def fit(self, X, y):
583583
y : array-like, shape = [n_samples]
584584
Target values
585585
586+
sample_weight : float or numpy array of shape (n_samples,)
587+
Sample weight.
588+
586589
Returns
587590
-------
588591
self : returns an instance of self.
589592
"""
593+
if sample_weight is None:
594+
sample_weight = 1.
595+
590596
self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1)
591597
Y = self._label_binarizer.fit_transform(y)
592598
if not self._label_binarizer.y_type_.startswith('multilabel'):
593599
y = column_or_1d(y, warn=True)
594600

595-
if self.class_weight:
596-
# get the class weight corresponding to each sample
597-
sample_weight = compute_sample_weight(self.class_weight, y)
598-
else:
599-
sample_weight = None
601+
# modify the sample weights with the corresponding class weight
602+
sample_weight = (sample_weight *
603+
compute_sample_weight(self.class_weight, y))
600604

601605
super(RidgeClassifier, self).fit(X, Y, sample_weight=sample_weight)
602606
return self

sklearn/linear_model/tests/test_ridge.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,35 @@ def test_class_weights():
489489
assert_array_almost_equal(clf.intercept_, clfa.intercept_)
490490

491491

492+
def test_class_weight_vs_sample_weight():
493+
"""Check class_weights resemble sample_weights behavior."""
494+
for clf in (RidgeClassifier, RidgeClassifierCV):
495+
496+
# Iris is balanced, so no effect expected for using 'balanced' weights
497+
clf1 = clf()
498+
clf1.fit(iris.data, iris.target)
499+
clf2 = clf(class_weight='balanced')
500+
clf2.fit(iris.data, iris.target)
501+
assert_almost_equal(clf1.coef_, clf2.coef_)
502+
503+
# Inflate importance of class 1, check against user-defined weights
504+
sample_weight = np.ones(iris.target.shape)
505+
sample_weight[iris.target == 1] *= 100
506+
class_weight = {0: 1., 1: 100., 2: 1.}
507+
clf1 = clf()
508+
clf1.fit(iris.data, iris.target, sample_weight)
509+
clf2 = clf(class_weight=class_weight)
510+
clf2.fit(iris.data, iris.target)
511+
assert_almost_equal(clf1.coef_, clf2.coef_)
512+
513+
# Check that sample_weight and class_weight are multiplicative
514+
clf1 = clf()
515+
clf1.fit(iris.data, iris.target, sample_weight ** 2)
516+
clf2 = clf(class_weight=class_weight)
517+
clf2.fit(iris.data, iris.target, sample_weight)
518+
assert_almost_equal(clf1.coef_, clf2.coef_)
519+
520+
492521
def test_class_weights_cv():
493522
# Test class weights for cross validated ridge classifier.
494523
X = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0],

0 commit comments

Comments
 (0)