Skip to content

Commit 3fc4ff3

Browse files
committed
Merge pull request scikit-learn#5098 from olologin/OneClassSvm_sparse_test
[MRG + 1] OneClassSVM sparsity regression test added
2 parents 0706636 + 617baa1 commit 3fc4ff3

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

sklearn/svm/tests/test_sparse.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,12 @@ def check_svm_model_equal(dense_svm, sparse_svm, X_train, y_train, X_test):
5959
sparse_svm.decision_function(X_test))
6060
assert_array_almost_equal(dense_svm.decision_function(X_test_dense),
6161
sparse_svm.decision_function(X_test_dense))
62-
assert_array_almost_equal(dense_svm.predict_proba(X_test_dense),
63-
sparse_svm.predict_proba(X_test), 4)
64-
msg = "cannot use sparse input in 'SVC' trained on dense data"
62+
if isinstance(dense_svm, svm.OneClassSVM):
63+
msg = "cannot use sparse input in 'OneClassSVM' trained on dense data"
64+
else:
65+
assert_array_almost_equal(dense_svm.predict_proba(X_test_dense),
66+
sparse_svm.predict_proba(X_test), 4)
67+
msg = "cannot use sparse input in 'SVC' trained on dense data"
6568
if sparse.isspmatrix(X_test):
6669
assert_raise_message(ValueError, msg, dense_svm.predict, X_test)
6770

@@ -255,6 +258,23 @@ def test_sparse_liblinear_intercept_handling():
255258
test_svm.test_dense_liblinear_intercept_handling(svm.LinearSVC)
256259

257260

261+
def test_sparse_oneclasssvm():
262+
"""Check that sparse OneClassSVM gives the same result as dense OneClassSVM"""
263+
# many class dataset:
264+
X_blobs, _ = make_blobs(n_samples=100, centers=10, random_state=0)
265+
X_blobs = sparse.csr_matrix(X_blobs)
266+
267+
datasets = [[X_sp, None, T], [X2_sp, None, T2],
268+
[X_blobs[:80], None, X_blobs[80:]],
269+
[iris.data, None, iris.data]]
270+
kernels = ["linear", "poly", "rbf", "sigmoid"]
271+
for dataset in datasets:
272+
for kernel in kernels:
273+
clf = svm.OneClassSVM(kernel=kernel, random_state=0)
274+
sp_clf = svm.OneClassSVM(kernel=kernel, random_state=0)
275+
check_svm_model_equal(clf, sp_clf, *dataset)
276+
277+
258278
def test_sparse_realdata():
259279
# Test on a subset from the 20newsgroups dataset.
260280
# This catchs some bugs if input is not correctly converted into

0 commit comments

Comments
 (0)