Skip to content

Commit 203298e

Browse files
committed
Merge pull request scikit-learn#4541 from amueller/robust_input_dtype_check
[MRG + 1] FIX be robust to columns name dtype, robust dtype checking
2 parents f99830c + 89c1dc5 commit 203298e

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

sklearn/utils/tests/test_validation.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sklearn.utils.testing import assert_raises_regexp
1414
from sklearn.utils import as_float_array, check_array, check_symmetric
1515
from sklearn.utils import check_X_y
16+
from sklearn.utils.mocking import MockDataFrame
1617
from sklearn.utils.estimator_checks import NotAnArray
1718
from sklearn.random_projection import sparse_random_matrix
1819
from sklearn.linear_model import ARDRegression
@@ -218,6 +219,25 @@ def test_check_array():
218219
assert_true(isinstance(result, np.ndarray))
219220

220221

222+
def test_check_array_pandas_dtype_object_conversion():
    # Data-frame-like inputs whose values have dtype object must come out
    # of check_array as a float array.
    # The builtin ``object`` is used as the dtype: the ``np.object`` alias
    # was deprecated in NumPy 1.20 and removed in NumPy 1.24.
    X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=object)
    X_df = MockDataFrame(X)
    assert_equal(check_array(X_df).dtype.kind, "f")
    assert_equal(check_array(X_df, ensure_2d=False).dtype.kind, "f")
    # Smoke test: a data frame with a column named "dtype" exposes a
    # ``dtype`` attribute that is not a numpy dtype (here a plain string);
    # check_array must still convert the input to float.
    X_df.dtype = "Hans"
    assert_equal(check_array(X_df, ensure_2d=False).dtype.kind, "f")
232+
233+
234+
def test_check_array_dtype_stability():
    # An all-integer nested list has to keep an integer dtype after
    # validation, both with the default arguments and with ensure_2d=False.
    int_rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    for extra_kwargs in ({}, {"ensure_2d": False}):
        checked = check_array(int_rows, **extra_kwargs)
        assert_equal(checked.dtype.kind, "i")
239+
240+
221241
def test_check_array_min_samples_and_features_messages():
222242
# empty list is considered 2D by default:
223243
msg = "0 feature(s) (shape=(1, 0)) while a minimum of 1 is required."

sklearn/utils/validation.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,21 +324,27 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
324324
if isinstance(accept_sparse, str):
325325
accept_sparse = [accept_sparse]
326326

327+
# store whether originally we wanted numeric dtype
328+
dtype_numeric = dtype == "numeric"
329+
327330
if sp.issparse(array):
328-
if dtype == "numeric":
331+
if dtype_numeric:
329332
dtype = None
330333
array = _ensure_sparse_format(array, accept_sparse, dtype, order,
331334
copy, force_all_finite)
332335
else:
333336
if ensure_2d:
334337
array = np.atleast_2d(array)
335-
if dtype == "numeric":
336-
if hasattr(array, "dtype") and array.dtype.kind == "O":
338+
if dtype_numeric:
339+
if hasattr(array, "dtype") and getattr(array.dtype, "kind", None) == "O":
337340
# if input is object, convert to float.
338341
dtype = np.float64
339342
else:
340343
dtype = None
341344
array = np.array(array, dtype=dtype, order=order, copy=copy)
345+
# make sure we actually converted to numeric:
346+
if dtype_numeric and array.dtype.kind == "O":
347+
array = array.astype(np.float64)
342348
if not allow_nd and array.ndim >= 3:
343349
raise ValueError("Found array with dim %d. Expected <= 2" %
344350
array.ndim)
@@ -353,7 +359,6 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
353359
" minimum of %d is required."
354360
% (n_samples, shape_repr, ensure_min_samples))
355361

356-
357362
if ensure_min_features > 0 and array.ndim == 2:
358363
n_features = array.shape[1]
359364
if n_features < ensure_min_features:

0 commit comments

Comments
 (0)