@@ -59,9 +59,12 @@ def check_svm_model_equal(dense_svm, sparse_svm, X_train, y_train, X_test):
5959 sparse_svm .decision_function (X_test ))
6060 assert_array_almost_equal (dense_svm .decision_function (X_test_dense ),
6161 sparse_svm .decision_function (X_test_dense ))
62- assert_array_almost_equal (dense_svm .predict_proba (X_test_dense ),
63- sparse_svm .predict_proba (X_test ), 4 )
64- msg = "cannot use sparse input in 'SVC' trained on dense data"
62+ if isinstance (dense_svm , svm .OneClassSVM ):
63+ msg = "cannot use sparse input in 'OneClassSVM' trained on dense data"
64+ else :
65+ assert_array_almost_equal (dense_svm .predict_proba (X_test_dense ),
66+ sparse_svm .predict_proba (X_test ), 4 )
67+ msg = "cannot use sparse input in 'SVC' trained on dense data"
6568 if sparse .isspmatrix (X_test ):
6669 assert_raise_message (ValueError , msg , dense_svm .predict , X_test )
6770
@@ -255,6 +258,23 @@ def test_sparse_liblinear_intercept_handling():
255258 test_svm .test_dense_liblinear_intercept_handling (svm .LinearSVC )
256259
257260
261+ def test_sparse_oneclasssvm ():
262+ """Check that sparse OneClassSVM gives the same result as dense OneClassSVM"""
263+ # many class dataset:
264+ X_blobs , _ = make_blobs (n_samples = 100 , centers = 10 , random_state = 0 )
265+ X_blobs = sparse .csr_matrix (X_blobs )
266+
267+ datasets = [[X_sp , None , T ], [X2_sp , None , T2 ],
268+ [X_blobs [:80 ], None , X_blobs [80 :]],
269+ [iris .data , None , iris .data ]]
270+ kernels = ["linear" , "poly" , "rbf" , "sigmoid" ]
271+ for dataset in datasets :
272+ for kernel in kernels :
273+ clf = svm .OneClassSVM (kernel = kernel , random_state = 0 )
274+ sp_clf = svm .OneClassSVM (kernel = kernel , random_state = 0 )
275+ check_svm_model_equal (clf , sp_clf , * dataset )
276+
277+
258278def test_sparse_realdata ():
259279 # Test on a subset from the 20newsgroups dataset.
260280 # This catchs some bugs if input is not correctly converted into
0 commit comments