Skip to content

Commit a8d6361

Browse files
shivamgargsyaShivamglemaitre
authored
TST replace assert_warns* by pytest.warns in module svm/tests (scikit-learn#19424)
Co-authored-by: Shivam <[email protected]> Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent 7f5ee92 commit a8d6361

File tree

2 files changed

+34
-14
lines changed

2 files changed

+34
-14
lines changed

sklearn/svm/tests/test_sparse.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99
from sklearn.svm.tests import test_svm
1010
from sklearn.exceptions import ConvergenceWarning
1111
from sklearn.utils.extmath import safe_sparse_dot
12-
from sklearn.utils._testing import (assert_warns,
13-
assert_raise_message, ignore_warnings,
14-
skip_if_32bit)
12+
from sklearn.utils._testing import (assert_raise_message, ignore_warnings,
13+
skip_if_32bit)
1514

1615

1716
# test sample 1
@@ -348,8 +347,12 @@ def test_sparse_svc_clone_with_callable_kernel():
348347
def test_timeout():
349348
sp = svm.SVC(C=1, kernel=lambda x, y: x * y.T,
350349
probability=True, random_state=0, max_iter=1)
351-
352-
assert_warns(ConvergenceWarning, sp.fit, X_sp, Y)
350+
warning_msg = (
351+
r'Solver terminated early \(max_iter=1\). Consider pre-processing '
352+
r'your data with StandardScaler or MinMaxScaler.'
353+
)
354+
with pytest.warns(ConvergenceWarning, match=warning_msg):
355+
sp.fit(X_sp, Y)
353356

354357

355358
def test_consistent_proba():

sklearn/svm/tests/test_svm.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@
1919
from sklearn.metrics import f1_score
2020
from sklearn.metrics.pairwise import rbf_kernel
2121
from sklearn.utils import check_random_state
22-
from sklearn.utils._testing import assert_warns
2322
from sklearn.utils._testing import assert_raise_message
2423
from sklearn.utils._testing import ignore_warnings
25-
from sklearn.utils._testing import assert_no_warnings
2624
from sklearn.utils.validation import _num_samples
2725
from sklearn.utils import shuffle
2826
from sklearn.exceptions import ConvergenceWarning
@@ -979,7 +977,12 @@ def test_svc_bad_kernel():
979977
def test_timeout():
980978
a = svm.SVC(kernel=lambda x, y: np.dot(x, y.T), probability=True,
981979
random_state=0, max_iter=1)
982-
assert_warns(ConvergenceWarning, a.fit, np.array(X), Y)
980+
warning_msg = (
981+
r'Solver terminated early \(max_iter=1\). Consider pre-processing '
982+
r'your data with StandardScaler or MinMaxScaler.'
983+
)
984+
with pytest.warns(ConvergenceWarning, match=warning_msg):
985+
a.fit(np.array(X), Y)
983986

984987

985988
def test_unfitted():
@@ -1008,11 +1011,16 @@ def test_linear_svm_convergence_warnings():
10081011
# Test that warnings are raised if model does not converge
10091012

10101013
lsvc = svm.LinearSVC(random_state=0, max_iter=2)
1011-
assert_warns(ConvergenceWarning, lsvc.fit, X, Y)
1014+
warning_msg = (
1015+
"Liblinear failed to converge, increase the number of iterations."
1016+
)
1017+
with pytest.warns(ConvergenceWarning, match=warning_msg):
1018+
lsvc.fit(X, Y)
10121019
assert lsvc.n_iter_ == 2
10131020

10141021
lsvr = svm.LinearSVR(random_state=0, max_iter=2)
1015-
assert_warns(ConvergenceWarning, lsvr.fit, iris.data, iris.target)
1022+
with pytest.warns(ConvergenceWarning, match=warning_msg):
1023+
lsvr.fit(iris.data, iris.target)
10161024
assert lsvr.n_iter_ == 2
10171025

10181026

@@ -1160,21 +1168,30 @@ def test_svc_ovr_tie_breaking(SVCClass):
11601168
def test_gamma_auto():
11611169
X, y = [[0.0, 1.2], [1.0, 1.3]], [0, 1]
11621170

1163-
assert_no_warnings(svm.SVC(kernel='linear').fit, X, y)
1164-
assert_no_warnings(svm.SVC(kernel='precomputed').fit, X, y)
1171+
with pytest.warns(None) as record:
1172+
svm.SVC(kernel='linear').fit(X, y)
1173+
assert not len(record)
1174+
1175+
with pytest.warns(None) as record:
1176+
svm.SVC(kernel='precomputed').fit(X, y)
1177+
assert not len(record)
11651178

11661179

11671180
def test_gamma_scale():
11681181
X, y = [[0.], [1.]], [0, 1]
11691182

11701183
clf = svm.SVC()
1171-
assert_no_warnings(clf.fit, X, y)
1184+
with pytest.warns(None) as record:
1185+
clf.fit(X, y)
1186+
assert not len(record)
11721187
assert_almost_equal(clf._gamma, 4)
11731188

11741189
# X_var ~= 1 shouldn't raise warning, for when
11751190
# gamma is not explicitly set.
11761191
X, y = [[1, 2], [3, 2 * np.sqrt(6) / 3 + 2]], [0, 1]
1177-
assert_no_warnings(clf.fit, X, y)
1192+
with pytest.warns(None) as record:
1193+
clf.fit(X, y)
1194+
assert not len(record)
11781195

11791196

11801197
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)