|
35 | 35 | from sklearn.linear_model import Ridge |
36 | 36 | from sklearn.svm import SVC |
37 | 37 |
|
| 38 | +from sklearn.preprocessing import Imputer |
| 39 | +from sklearn.pipeline import Pipeline |
| 40 | + |
38 | 41 |
|
39 | 42 | class MockListClassifier(BaseEstimator): |
40 | 43 | """Dummy classifier to test the cross-validation. |
@@ -852,3 +855,35 @@ def test_safe_split_with_precomputed_kernel(): |
852 | 855 | X_te, y_te = cval._safe_split(clf, X, y, te, tr) |
853 | 856 | K_te, y_te2 = cval._safe_split(clfp, K, y, te, tr) |
854 | 857 | assert_array_almost_equal(K_te, np.dot(X_te, X_tr.T)) |
| 858 | + |
| 859 | + |
def test_cross_val_score_allow_nans():
    # Check that cross_val_score allows input data with NaNs
    # 10 samples x 20 features, with every feature of sample 2 missing.
    X = np.arange(200, dtype=np.float64).reshape(10, -1)
    X[2, :] = np.nan
    # Floor division keeps the repeat count an integer: under Python 3
    # true division, ``X.shape[0] / 2`` is a float, which np.repeat rejects.
    y = np.repeat([0, 1], X.shape[0] // 2)
    # The imputer runs inside the pipeline, so the NaNs must survive
    # cross_val_score's own input handling to reach it.
    p = Pipeline([
        ('imputer', Imputer(strategy='mean', missing_values='NaN')),
        ('classifier', MockClassifier()),
    ])
    # The test passes as long as this call does not raise.
    cval.cross_val_score(p, X, y, cv=5)
| 870 | + |
| 871 | + |
def test_train_test_split_allow_nans():
    # Check that train_test_split allows input data with NaNs
    # 10 samples x 20 features, with every feature of sample 2 missing.
    X = np.arange(200, dtype=np.float64).reshape(10, -1)
    X[2, :] = np.nan
    # Floor division keeps the repeat count an integer: under Python 3
    # true division, ``X.shape[0] / 2`` is a float, which np.repeat rejects.
    y = np.repeat([0, 1], X.shape[0] // 2)
    # Only the absence of an exception is checked, so the result is not kept.
    cval.train_test_split(X, y, test_size=0.2, random_state=42)
| 878 | + |
| 879 | + |
def test_permutation_test_score_allow_nans():
    # Check that permutation_test_score allows input data with NaNs
    # 10 samples x 20 features, with every feature of sample 2 missing.
    X = np.arange(200, dtype=np.float64).reshape(10, -1)
    X[2, :] = np.nan
    # Floor division keeps the repeat count an integer: under Python 3
    # true division, ``X.shape[0] / 2`` is a float, which np.repeat rejects.
    y = np.repeat([0, 1], X.shape[0] // 2)
    # The imputer runs inside the pipeline, so the NaNs must survive
    # permutation_test_score's own input handling to reach it.
    p = Pipeline([
        ('imputer', Imputer(strategy='mean', missing_values='NaN')),
        ('classifier', MockClassifier()),
    ])
    # The test passes as long as this call does not raise.
    cval.permutation_test_score(p, X, y, cv=5)
0 commit comments