Skip to content

Commit 652a873

Browse files
committed
ENH: add a new DataConversionWarning
for implicit conversions such as raveling the y
1 parent 0292f14 commit 652a873

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

sklearn/tests/test_common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from sklearn.svm.base import BaseLibSVM
4444

4545
from sklearn.cross_validation import train_test_split
46+
from sklearn.utils.validation import DataConversionWarning
4647

4748
dont_test = ['SparseCoder', 'EllipticEnvelope', 'EllipticEnvelop',
4849
'DictVectorizer', 'LabelBinarizer', 'LabelEncoder',
@@ -495,7 +496,6 @@ def test_clustering():
495496
continue
496497
# catch deprecation and neighbors warnings
497498
with warnings.catch_warnings(record=True):
498-
warnings.simplefilter("always")
499499
alg = Alg()
500500
if hasattr(alg, "n_clusters"):
501501
alg.set_params(n_clusters=3)
@@ -663,9 +663,12 @@ def test_classifiers_input_shapes():
663663

664664
set_random_state(classifier)
665665
classifier.fit(X, y[:, np.newaxis])
666+
# Check that when a 2D y is given, a DataConversionWarning is
667+
# raised
666668
with warnings.catch_warnings(record=True) as w:
669+
warnings.simplefilter("always", DataConversionWarning)
667670
classifier.fit(X, y[:, np.newaxis])
668-
print(w)
671+
assert_equal(len(w), 1)
669672
assert_array_equal(y_pred, classifier.predict(X))
670673

671674

sklearn/utils/validation.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@
1212
from .fixes import safe_copy
1313

1414

15+
class DataConversionWarning(UserWarning):
16+
"A warning on implicit data conversions happening in the code"
17+
pass
18+
19+
20+
warnings.simplefilter("always", DataConversionWarning)
21+
22+
1523
def _assert_all_finite(X):
1624
"""Like assert_all_finite, but only for ndarray."""
1725
if (X.dtype.char in np.typecodes['AllFloat'] and not np.isfinite(X.sum())
@@ -250,7 +258,8 @@ def column_or_1d(y, warn=False):
250258
if warn:
251259
warnings.warn("A column-vector y was passed when a 1d array was"
252260
" expected. Please change the shape of y to "
253-
"(n_samples, ), for example using ravel().")
261+
"(n_samples, ), for example using ravel().",
262+
DataConversionWarning, stacklevel=2)
254263
return np.ravel(y)
255264

256265
raise ValueError("bad input shape {0}".format(shape))

0 commit comments

Comments
 (0)