Skip to content

Commit 0c6b173

Browse files
committed
Add test to check sparse predictions in cross_val_predict
1 parent 26d3323 commit 0c6b173

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

sklearn/tests/test_cross_validation.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66
from scipy.sparse import coo_matrix
7+
from scipy.sparse import csr_matrix
78
from scipy import stats
89

910
from sklearn.utils.testing import assert_true
@@ -25,14 +26,15 @@
2526
from sklearn.datasets import load_boston
2627
from sklearn.datasets import load_digits
2728
from sklearn.datasets import load_iris
29+
from sklearn.datasets import make_multilabel_classification
2830
from sklearn.metrics import explained_variance_score
2931
from sklearn.metrics import make_scorer
3032
from sklearn.metrics import precision_score
31-
3233
from sklearn.externals import six
3334
from sklearn.externals.six.moves import zip
3435

3536
from sklearn.linear_model import Ridge
37+
from sklearn.multiclass import OneVsRestClassifier
3638
from sklearn.neighbors import KNeighborsClassifier
3739
from sklearn.svm import SVC
3840
from sklearn.cluster import KMeans
@@ -1094,3 +1096,20 @@ def test_check_is_partition():
10941096

10951097
p[0] = 23
10961098
assert_false(cval._check_is_partition(p, 100))
1099+
1100+
def test_cross_val_predict_sparse_prediction():
1101+
"""Check that cross_val_predict gives the same result for sparse and dense inputs"""
1102+
X, Y = make_multilabel_classification(n_classes=2, n_labels=1,
1103+
allow_unlabeled=False,
1104+
return_indicator=True,
1105+
random_state=1)
1106+
X_sparse = csr_matrix(X)
1107+
Y_sparse = csr_matrix(Y)
1108+
classif = OneVsRestClassifier(SVC(kernel='linear'))
1109+
preds = cval.cross_val_predict(classif, X,
1110+
Y, cv=10)
1111+
preds_sparse = cval.cross_val_predict(classif, X_sparse,
1112+
Y_sparse, cv=10)
1113+
preds_sparse = preds_sparse.toarray()
1114+
assert_array_almost_equal(preds_sparse, preds)
1115+

0 commit comments

Comments
 (0)