Skip to content

Commit c50dc0e

Browse files
markroth8jnothman
authored andcommitted
MAINT Fix scikit-learn#9350: Enable has_fit_parameter() and fit_score_takes_y() to work with @deprecated in Python 2 (scikit-learn#11277)
1 parent 768ff4d commit c50dc0e

File tree

3 files changed

+28
-0
lines changed

3 files changed

+28
-0
lines changed

sklearn/utils/deprecation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ def wrapped(*args, **kwargs):
7878
return fun(*args, **kwargs)
7979

8080
wrapped.__doc__ = self._update_doc(wrapped.__doc__)
81+
# Add a reference to the wrapped function so that we can introspect
82+
# on function arguments in Python 2 (already works in Python 3)
83+
wrapped.__wrapped__ = fun
8184

8285
return wrapped
8386

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
from sklearn.externals import joblib
1010

1111
from sklearn.base import BaseEstimator, ClassifierMixin
12+
from sklearn.utils import deprecated
1213
from sklearn.utils.testing import (assert_raises_regex, assert_true,
1314
assert_equal, ignore_warnings)
1415
from sklearn.utils.estimator_checks import check_estimator
1516
from sklearn.utils.estimator_checks import set_random_state
1617
from sklearn.utils.estimator_checks import set_checking_parameters
1718
from sklearn.utils.estimator_checks import check_estimators_unfitted
19+
from sklearn.utils.estimator_checks import check_fit_score_takes_y
1820
from sklearn.utils.estimator_checks import check_no_attributes_set_in_init
1921
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
2022
from sklearn.linear_model import LinearRegression, SGDClassifier
@@ -176,6 +178,19 @@ def transform(self, X):
176178
return sp.csr_matrix(X)
177179

178180

181+
def test_check_fit_score_takes_y_works_on_deprecated_fit():
182+
# Tests that check_fit_score_takes_y works on a class with
183+
# a deprecated fit method
184+
185+
class TestEstimatorWithDeprecatedFitMethod(BaseEstimator):
186+
@deprecated("Deprecated for the purpose of testing "
187+
"check_fit_score_takes_y")
188+
def fit(self, X, y):
189+
return self
190+
191+
check_fit_score_takes_y("test", TestEstimatorWithDeprecatedFitMethod())
192+
193+
179194
def test_check_estimator():
180195
# tests that the estimator actually fails on "bad" estimators.
181196
# not a complete test of all checks, which are very extensive.

sklearn/utils/tests/test_validation.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sklearn.utils.testing import assert_allclose_dense_sparse
2323
from sklearn.utils import as_float_array, check_array, check_symmetric
2424
from sklearn.utils import check_X_y
25+
from sklearn.utils import deprecated
2526
from sklearn.utils.mocking import MockDataFrame
2627
from sklearn.utils.estimator_checks import NotAnArray
2728
from sklearn.random_projection import sparse_random_matrix
@@ -563,6 +564,15 @@ def test_has_fit_parameter():
563564
assert_true(has_fit_parameter(SVR, "sample_weight"))
564565
assert_true(has_fit_parameter(SVR(), "sample_weight"))
565566

567+
class TestClassWithDeprecatedFitMethod:
568+
@deprecated("Deprecated for the purpose of testing has_fit_parameter")
569+
def fit(self, X, y, sample_weight=None):
570+
pass
571+
572+
assert has_fit_parameter(TestClassWithDeprecatedFitMethod,
573+
"sample_weight"), \
574+
"has_fit_parameter fails for class with deprecated fit method."
575+
566576

567577
def test_check_symmetric():
568578
arr_sym = np.array([[0, 1], [1, 2]])

0 commit comments

Comments
 (0)