Skip to content

Commit 3f9dff9

Browse files
YS-L authored and larsmans committed
FIX: Add allow_nans option to check_arrays
Grid search and cross validation should not panic when seeing NaNs in the input arrays, because that breaks Imputer. Fixes scikit-learn#2774 and scikit-learn#3044.
1 parent b6eedcc commit 3f9dff9

File tree

6 files changed

+69
-5
lines changed

6 files changed

+69
-5
lines changed

doc/whats_new.rst

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -182,6 +182,11 @@ Changelog
182182
:class:`cluster.WardAgglomeration` when no samples are given,
183183
rather than returning meaningless clustering.
184184

185+
- Grid search and cross validation allow NaNs in the input arrays so that
186+
preprocessors such as :class:`preprocessing.Imputer
187+
<preprocessing.Imputer>` can be trained within the cross validation loop,
188+
avoiding potentially skewed results.
189+
185190

186191
API changes summary
187192
-------------------

sklearn/cross_validation.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1097,7 +1097,8 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
10971097
scores : array of float, shape=(len(list(cv)),)
10981098
Array of scores of the estimator for each run of the cross validation.
10991099
"""
1100-
X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True)
1100+
X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True,
1101+
allow_nans=True)
11011102
if y is not None:
11021103
y = np.asarray(y)
11031104

@@ -1408,7 +1409,7 @@ def permutation_test_score(estimator, X, y, score_func=None, cv=None,
14081409
vol. 11
14091410
14101411
"""
1411-
X, y = check_arrays(X, y, sparse_format='csr')
1412+
X, y = check_arrays(X, y, sparse_format='csr', allow_nans=True)
14121413
cv = _check_cv(cv, X, y, classifier=is_classifier(estimator))
14131414
scorer = check_scoring(estimator, scoring=scoring, score_func=score_func)
14141415
random_state = check_random_state(random_state)
@@ -1505,6 +1506,7 @@ def train_test_split(*arrays, **options):
15051506
train_size = options.pop('train_size', None)
15061507
random_state = options.pop('random_state', None)
15071508
options['sparse_format'] = 'csr'
1509+
options['allow_nans'] = True
15081510

15091511
if test_size is None and train_size is None:
15101512
test_size = 0.25

sklearn/grid_search.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -349,7 +349,8 @@ def _fit(self, X, y, parameter_iterable):
349349
score_func=self.score_func)
350350

351351
n_samples = _num_samples(X)
352-
X, y = check_arrays(X, y, allow_lists=True, sparse_format='csr')
352+
X, y = check_arrays(X, y, allow_lists=True, sparse_format='csr',
353+
allow_nans=True)
353354

354355
if y is not None:
355356
if len(y) != n_samples:

sklearn/tests/test_cross_validation.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,9 @@
3535
from sklearn.linear_model import Ridge
3636
from sklearn.svm import SVC
3737

38+
from sklearn.preprocessing import Imputer
39+
from sklearn.pipeline import Pipeline
40+
3841

3942
class MockListClassifier(BaseEstimator):
4043
"""Dummy classifier to test the cross-validation.
@@ -852,3 +855,35 @@ def test_safe_split_with_precomputed_kernel():
852855
X_te, y_te = cval._safe_split(clf, X, y, te, tr)
853856
K_te, y_te2 = cval._safe_split(clfp, K, y, te, tr)
854857
assert_array_almost_equal(K_te, np.dot(X_te, X_tr.T))
858+
859+
860+
def test_cross_val_score_allow_nans():
861+
# Check that cross_val_score allows input data with NaNs
862+
X = np.arange(200, dtype=np.float64).reshape(10, -1)
863+
X[2, :] = np.nan
864+
y = np.repeat([0, 1], X.shape[0]/2)
865+
p = Pipeline([
866+
('imputer', Imputer(strategy='mean', missing_values='NaN')),
867+
('classifier', MockClassifier()),
868+
])
869+
cval.cross_val_score(p, X, y, cv=5)
870+
871+
872+
def test_train_test_split_allow_nans():
873+
# Check that train_test_split allows input data with NaNs
874+
X = np.arange(200, dtype=np.float64).reshape(10, -1)
875+
X[2, :] = np.nan
876+
y = np.repeat([0, 1], X.shape[0]/2)
877+
split = cval.train_test_split(X, y, test_size=0.2, random_state=42)
878+
879+
880+
def test_permutation_test_score_allow_nans():
881+
# Check that permutation_test_score allows input data with NaNs
882+
X = np.arange(200, dtype=np.float64).reshape(10, -1)
883+
X[2, :] = np.nan
884+
y = np.repeat([0, 1], X.shape[0]/2)
885+
p = Pipeline([
886+
('imputer', Imputer(strategy='mean', missing_values='NaN')),
887+
('classifier', MockClassifier()),
888+
])
889+
cval.permutation_test_score(p, X, y, cv=5)

sklearn/tests/test_grid_search.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,8 @@
3939
from sklearn.metrics import make_scorer
4040
from sklearn.metrics import roc_auc_score
4141
from sklearn.cross_validation import KFold, StratifiedKFold
42+
from sklearn.preprocessing import Imputer
43+
from sklearn.pipeline import Pipeline
4244

4345

4446
# Neither of the following two estimators inherit from BaseEstimator,
@@ -654,3 +656,15 @@ def test_predict_proba_disabled():
654656
clf = SVC(probability=False)
655657
gs = GridSearchCV(clf, {}, cv=2).fit(X, y)
656658
assert_false(hasattr(gs, "predict_proba"))
659+
660+
661+
def test_grid_search_allows_nans():
662+
""" Test GridSearchCV with Imputer """
663+
X = np.arange(20, dtype=np.float64).reshape(5, -1)
664+
X[2, :] = np.nan
665+
y = [0, 0, 1, 1, 1]
666+
p = Pipeline([
667+
('imputer', Imputer(strategy='mean', missing_values='NaN')),
668+
('classifier', MockClassifier()),
669+
])
670+
gs = GridSearchCV(p, {'classifier__foo_param': [1, 2, 3]}, cv=2).fit(X, y)

sklearn/utils/validation.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -211,6 +211,9 @@ def check_arrays(*arrays, **options):
211211
allow_lists : bool
212212
Allow lists of arbitrary objects as input, just check their length.
213213
Disables
214+
215+
allow_nans : boolean, False by default
216+
Allows nans in the arrays
214217
"""
215218
sparse_format = options.pop('sparse_format', None)
216219
if sparse_format not in (None, 'csr', 'csc', 'dense'):
@@ -219,6 +222,8 @@ def check_arrays(*arrays, **options):
219222
check_ccontiguous = options.pop('check_ccontiguous', False)
220223
dtype = options.pop('dtype', None)
221224
allow_lists = options.pop('allow_lists', False)
225+
allow_nans = options.pop('allow_nans', False)
226+
222227
if options:
223228
raise TypeError("Unexpected keyword arguments: %r" % options.keys())
224229

@@ -254,13 +259,15 @@ def check_arrays(*arrays, **options):
254259
array.data = np.ascontiguousarray(array.data, dtype=dtype)
255260
else:
256261
array.data = np.asarray(array.data, dtype=dtype)
257-
_assert_all_finite(array.data)
262+
if not allow_nans:
263+
_assert_all_finite(array.data)
258264
else:
259265
if check_ccontiguous:
260266
array = np.ascontiguousarray(array, dtype=dtype)
261267
else:
262268
array = np.asarray(array, dtype=dtype)
263-
_assert_all_finite(array)
269+
if not allow_nans:
270+
_assert_all_finite(array)
264271

265272
if array.ndim >= 3:
266273
raise ValueError("Found array with dim %d. Expected <= 2" %

0 commit comments

Comments
 (0)