|
9 | 9 | from sklearn.externals import joblib |
10 | 10 |
|
11 | 11 | from sklearn.base import BaseEstimator, ClassifierMixin |
| 12 | +from sklearn.utils import deprecated |
12 | 13 | from sklearn.utils.testing import (assert_raises_regex, assert_true, |
13 | 14 | assert_equal, ignore_warnings) |
14 | 15 | from sklearn.utils.estimator_checks import check_estimator |
15 | 16 | from sklearn.utils.estimator_checks import set_random_state |
16 | 17 | from sklearn.utils.estimator_checks import set_checking_parameters |
17 | 18 | from sklearn.utils.estimator_checks import check_estimators_unfitted |
| 19 | +from sklearn.utils.estimator_checks import check_fit_score_takes_y |
18 | 20 | from sklearn.utils.estimator_checks import check_no_attributes_set_in_init |
19 | 21 | from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier |
20 | 22 | from sklearn.linear_model import LinearRegression, SGDClassifier |
@@ -176,6 +178,19 @@ def transform(self, X): |
176 | 178 | return sp.csr_matrix(X) |
177 | 179 |
|
178 | 180 |
|
| 181 | +def test_check_fit_score_takes_y_works_on_deprecated_fit(): |
| 182 | + # Tests that check_fit_score_takes_y works on a class with |
| 183 | + # a deprecated fit method |
| 184 | + |
| 185 | + class TestEstimatorWithDeprecatedFitMethod(BaseEstimator): |
| 186 | + @deprecated("Deprecated for the purpose of testing " |
| 187 | + "check_fit_score_takes_y") |
| 188 | + def fit(self, X, y): |
| 189 | + return self |
| 190 | + |
| 191 | + check_fit_score_takes_y("test", TestEstimatorWithDeprecatedFitMethod()) |
| 192 | + |
| 193 | + |
179 | 194 | def test_check_estimator(): |
180 | 195 | # tests that the estimator actually fails on "bad" estimators. |
181 | 196 | # not a complete test of all checks, which are very extensive. |
|
0 commit comments