|
4 | 4 |
|
5 | 5 | import numpy as np |
6 | 6 | from scipy.sparse import coo_matrix |
| 7 | +from scipy.sparse import csr_matrix |
7 | 8 | from scipy import stats |
8 | 9 |
|
9 | 10 | from sklearn.utils.testing import assert_true |
|
25 | 26 | from sklearn.datasets import load_boston |
26 | 27 | from sklearn.datasets import load_digits |
27 | 28 | from sklearn.datasets import load_iris |
| 29 | +from sklearn.datasets import make_multilabel_classification |
28 | 30 | from sklearn.metrics import explained_variance_score |
29 | 31 | from sklearn.metrics import make_scorer |
30 | 32 | from sklearn.metrics import precision_score |
31 | | - |
32 | 33 | from sklearn.externals import six |
33 | 34 | from sklearn.externals.six.moves import zip |
34 | 35 |
|
35 | 36 | from sklearn.linear_model import Ridge |
| 37 | +from sklearn.multiclass import OneVsRestClassifier |
36 | 38 | from sklearn.neighbors import KNeighborsClassifier |
37 | 39 | from sklearn.svm import SVC |
38 | 40 | from sklearn.cluster import KMeans |
@@ -1094,3 +1096,20 @@ def test_check_is_partition(): |
1094 | 1096 |
|
1095 | 1097 | p[0] = 23 |
1096 | 1098 | assert_false(cval._check_is_partition(p, 100)) |
| 1099 | + |
| 1100 | +def test_cross_val_predict_sparse_prediction(): |
| 1101 | + """Check that cross_val_predict gives the same result for sparse and dense inputs""" |
| 1102 | + X, Y = make_multilabel_classification(n_classes=2, n_labels=1, |
| 1103 | + allow_unlabeled=False, |
| 1104 | + return_indicator=True, |
| 1105 | + random_state=1) |
| 1106 | + X_sparse = csr_matrix(X) |
| 1107 | + Y_sparse = csr_matrix(Y) |
| 1108 | + classif = OneVsRestClassifier(SVC(kernel='linear')) |
| 1109 | + preds = cval.cross_val_predict(classif, X, |
| 1110 | + Y, cv=10) |
| 1111 | + preds_sparse = cval.cross_val_predict(classif, X_sparse, |
| 1112 | + Y_sparse, cv=10) |
| 1113 | + preds_sparse = preds_sparse.toarray() |
| 1114 | + assert_array_almost_equal(preds_sparse, preds) |
| 1115 | + |
0 commit comments