Skip to content

Commit 1c42e79

Browse files
J-A16 authored and jnothman committed
FIX pass sample weights to final estimator (scikit-learn#15773)
1 parent 29932e6 commit 1c42e79

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

doc/whats_new/v0.23.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,14 @@ Changelog
5858
:func:`datasets.make_moons` now accept two-element tuple.
5959
:pr:`15707` by :user:`Maciej J Mikulski <mjmikulski>`
6060

61+
:mod:`sklearn.linear_model`
62+
...........................
63+
64+
- |Fix| Fixed a bug where if a `sample_weight` parameter was passed to the fit
65+
method of :class:`linear_model.RANSACRegressor`, it would not be passed to
66+
the wrapped `base_estimator` during the fitting of the final model.
67+
:pr:`15773` by :user:`Jeremy Alexandre <J-A16>`.
68+
6169
:mod:`sklearn.preprocessing`
6270
............................
6371

sklearn/linear_model/_ransac.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ def fit(self, X, y, sample_weight=None):
328328
inlier_mask_best = None
329329
X_inlier_best = None
330330
y_inlier_best = None
331+
inlier_best_idxs_subset = None
331332
self.n_skips_no_inliers_ = 0
332333
self.n_skips_invalid_data_ = 0
333334
self.n_skips_invalid_model_ = 0
@@ -404,6 +405,7 @@ def fit(self, X, y, sample_weight=None):
404405
inlier_mask_best = inlier_mask_subset
405406
X_inlier_best = X_inlier_subset
406407
y_inlier_best = y_inlier_subset
408+
inlier_best_idxs_subset = inlier_idxs_subset
407409

408410
max_trials = min(
409411
max_trials,
@@ -441,7 +443,13 @@ def fit(self, X, y, sample_weight=None):
441443
ConvergenceWarning)
442444

443445
# estimate final model using all inliers
444-
base_estimator.fit(X_inlier_best, y_inlier_best)
446+
if sample_weight is None:
447+
base_estimator.fit(X_inlier_best, y_inlier_best)
448+
else:
449+
base_estimator.fit(
450+
X_inlier_best,
451+
y_inlier_best,
452+
sample_weight=sample_weight[inlier_best_idxs_subset])
445453

446454
self.estimator_ = base_estimator
447455
self.inlier_mask_ = inlier_mask_best

sklearn/linear_model/tests/test_ransac.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from sklearn.utils._testing import assert_almost_equal
1111
from sklearn.utils._testing import assert_raises_regexp
1212
from sklearn.utils._testing import assert_raises
13+
from sklearn.utils._testing import assert_allclose
14+
from sklearn.datasets import make_regression
1315
from sklearn.linear_model import LinearRegression, RANSACRegressor, Lasso
1416
from sklearn.linear_model._ransac import _dynamic_max_trials
1517
from sklearn.exceptions import ConvergenceWarning
@@ -494,3 +496,21 @@ def test_ransac_fit_sample_weight():
494496
base_estimator = Lasso()
495497
ransac_estimator = RANSACRegressor(base_estimator)
496498
assert_raises(ValueError, ransac_estimator.fit, X, y, weights)
499+
500+
501+
def test_ransac_final_model_fit_sample_weight():
502+
X, y = make_regression(n_samples=1000, random_state=10)
503+
rng = check_random_state(42)
504+
sample_weight = rng.randint(1, 4, size=y.shape[0])
505+
sample_weight = sample_weight / sample_weight.sum()
506+
ransac = RANSACRegressor(base_estimator=LinearRegression(), random_state=0)
507+
ransac.fit(X, y, sample_weight=sample_weight)
508+
509+
final_model = LinearRegression()
510+
mask_samples = ransac.inlier_mask_
511+
final_model.fit(
512+
X[mask_samples], y[mask_samples],
513+
sample_weight=sample_weight[mask_samples]
514+
)
515+
516+
assert_allclose(ransac.estimator_.coef_, final_model.coef_)

0 commit comments

Comments
 (0)