|
26 | 26 | from sklearn.datasets import make_classification |
27 | 27 |
|
28 | 28 | from sklearn.cross_validation import train_test_split |
| 29 | +from sklearn.linear_model.base import LinearClassifierMixin |
29 | 30 | from sklearn.utils.estimator_checks import ( |
30 | 31 | check_parameters_default_constructible, |
31 | 32 | check_regressors_classifiers_sparse_data, |
|
44 | 45 | check_classifiers_pickle, |
45 | 46 | check_class_weight_classifiers, |
46 | 47 | check_class_weight_auto_classifiers, |
| 48 | + check_class_weight_auto_linear_classifier, |
47 | 49 | check_estimators_overwrite_params, |
48 | 50 | check_cluster_overwrite_params, |
49 | 51 | check_sparsify_binary_classifier, |
@@ -214,7 +216,7 @@ def test_class_weight_classifiers(): |
214 | 216 | yield check_class_weight_classifiers, name, Classifier |
215 | 217 |
|
216 | 218 |
|
217 | | -def test_class_weight_auto_classifies(): |
| 219 | +def test_class_weight_auto_classifiers(): |
218 | 220 | """Test that class_weight="auto" improves f1-score""" |
219 | 221 |
|
220 | 222 | # This test is broken; its success depends on: |
@@ -251,6 +253,26 @@ def test_class_weight_auto_classifies(): |
251 | 253 | X_train, y_train, X_test, y_test, weights) |
252 | 254 |
|
253 | 255 |
|
| 256 | +def test_class_weight_auto_linear_classifiers(): |
| 257 | + classifiers = all_estimators(type_filter='classifier') |
| 258 | + |
| 259 | + with warnings.catch_warnings(record=True): |
| 260 | + linear_classifiers = [ |
| 261 | + (name, clazz) |
| 262 | + for name, clazz in classifiers |
| 263 | + if 'class_weight' in clazz().get_params().keys() |
| 264 | + and issubclass(clazz, LinearClassifierMixin)] |
| 265 | + |
| 266 | + for name, Classifier in linear_classifiers: |
| 267 | + if name == "LogisticRegressionCV": |
| 268 | + # Contrary to RidgeClassifierCV, LogisticRegressionCV use actual |
| 269 | + # CV folds and fit a model for each CV iteration before averaging |
| 270 | + # the coef. Therefore it is expected to not behave exactly as the |
| 271 | + # other linear model. |
| 272 | + continue |
| 273 | + yield check_class_weight_auto_linear_classifier, name, Classifier |
| 274 | + |
| 275 | + |
254 | 276 | def test_estimators_overwrite_params(): |
255 | 277 | # test whether any classifier overwrites his init parameters during fit |
256 | 278 | for est_type in ["classifier", "regressor", "transformer"]: |
|
0 commit comments