Skip to content

Commit 4c10c82

Browse files
author
dengemann
committed
ENH: address discussion
1 parent 163e659 commit 4c10c82

File tree

4 files changed

+97
-30
lines changed

4 files changed

+97
-30
lines changed

sklearn/feature_extraction/tests/test_text.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from numpy.testing import assert_array_equal
2929
from numpy.testing import assert_raises
3030
from sklearn.utils.testing import (assert_in, assert_less, assert_greater,
31-
assert_warns)
31+
assert_warns_message)
3232

3333
from collections import defaultdict, Mapping
3434
from functools import partial
@@ -180,8 +180,11 @@ def test_unicode_decode_error():
180180
assert_raises(UnicodeDecodeError, ca, text_bytes)
181181

182182
# Check the old interface
183-
ca = assert_warns(DeprecationWarning, CountVectorizer, analyzer='char',
184-
ngram_range=(3, 6), charset='ascii').build_analyzer()
183+
in_warning_message = 'charset'
184+
ca = assert_warns_message(DeprecationWarning, in_warning_message,
185+
CountVectorizer, analyzer='char',
186+
ngram_range=(3, 6),
187+
charset='ascii').build_analyzer()
185188
assert_raises(UnicodeDecodeError, ca, text_bytes)
186189

187190

@@ -349,7 +352,9 @@ def test_tfidf_no_smoothing():
349352
1. / np.array([0.])
350353
numpy_provides_div0_warning = len(w) == 1
351354

352-
tfidf = assert_warns(RuntimeWarning,tr.fit_transform, X).toarray()
355+
in_warning_message = 'divide by zero'
356+
tfidf = assert_warns_message(RuntimeWarning, in_warning_message,
357+
tr.fit_transform, X).toarray()
353358
if not numpy_provides_div0_warning:
354359
raise SkipTest("Numpy does not provide div 0 warnings.")
355360

sklearn/linear_model/tests/test_least_angle.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sklearn.utils.testing import assert_less
99
from sklearn.utils.testing import assert_greater
1010
from sklearn.utils.testing import assert_raises
11-
from sklearn.utils.testing import ignore_warnings
11+
from sklearn.utils.testing import ignore_warnings, assert_warns_message
1212
from sklearn import linear_model, datasets
1313

1414
diabetes = datasets.load_diabetes()
@@ -182,7 +182,10 @@ def test_singular_matrix():
182182
# to give a good answer
183183
X1 = np.array([[1, 1.], [1., 1.]])
184184
y1 = np.array([1, 1])
185-
alphas, active, coef_path = ignore_warnings(linear_model.lars_path)(X1, y1)
185+
in_warn_message = 'Dropping a regressor'
186+
f = assert_warns_message
187+
alphas, active, coef_path = f(UserWarning, in_warn_message,
188+
linear_model.lars_path, X1, y1)
186189
assert_array_almost_equal(coef_path.T, [[0, 0], [1, 0]])
187190

188191

@@ -315,22 +318,27 @@ def test_lasso_lars_vs_lasso_cd_ill_conditioned():
315318
y += sigma * rng.rand(*y.shape)
316319
y = y.squeeze()
317320

318-
f = ignore_warnings
319-
lars_alphas, _, lars_coef = f(linear_model.lars_path)(X, y,
320-
method='lasso')
321-
322-
_, lasso_coef2, _ = f(linear_model.lasso_path)(X, y,
323-
alphas=lars_alphas,
324-
tol=1e-6,
325-
fit_intercept=False)
326-
327-
lasso_coef = np.zeros((w.shape[0], len(lars_alphas)))
328-
for i, model in enumerate(f(linear_model.lasso_path)(X, y,
321+
f = assert_warns_message
322+
def in_warn_message(msg):
323+
return 'Early stopping' in msg or 'Dropping regressor' in msg
324+
lars_alphas, _, lars_coef = f(UserWarning,
325+
in_warn_message,
326+
linear_model.lars_path, X, y, method='lasso')
327+
328+
with ignore_warnings():
329+
_, lasso_coef2, _ = linear_model.lasso_path(X, y,
330+
alphas=lars_alphas,
331+
tol=1e-6,
332+
fit_intercept=False)
333+
334+
lasso_coef = np.zeros((w.shape[0], len(lars_alphas)))
335+
iter_models = enumerate(linear_model.lasso_path(X, y,
329336
alphas=lars_alphas,
330337
tol=1e-6,
331338
return_models=True,
332-
fit_intercept=False)):
333-
lasso_coef[:, i] = model.coef_
339+
fit_intercept=False))
340+
for i, model in iter_models:
341+
lasso_coef[:, i] = model.coef_
334342

335343
np.testing.assert_array_almost_equal(lars_coef, lasso_coef, decimal=1)
336344
np.testing.assert_array_almost_equal(lars_coef, lasso_coef2, decimal=1)

sklearn/metrics/tests/test_metrics.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1826,42 +1826,42 @@ def test_prf_warnings():
18261826
for average in [None, 'weighted', 'macro']:
18271827
msg = ('Precision and F-score are ill-defined and '
18281828
'being set to 0.0 in labels with no predicted samples.')
1829-
my_assert(w, f, msg, [0, 1, 2], [1, 1, 2], average=average)
1829+
my_assert(w, msg, f, [0, 1, 2], [1, 1, 2], average=average)
18301830

18311831
msg = ('Recall and F-score are ill-defined and '
18321832
'being set to 0.0 in labels with no true samples.')
1833-
my_assert(w, f, msg, [1, 1, 2], [0, 1, 2], average=average)
1833+
my_assert(w, msg, f, [1, 1, 2], [0, 1, 2], average=average)
18341834

18351835
# average of per-sample scores
18361836
msg = ('Precision and F-score are ill-defined and '
18371837
'being set to 0.0 in samples with no predicted labels.')
1838-
my_assert(w, f, msg, np.array([[1, 0], [1, 0]]),
1838+
my_assert(w, msg, f, np.array([[1, 0], [1, 0]]),
18391839
np.array([[1, 0], [0, 0]]), average='samples')
18401840

18411841
msg = ('Recall and F-score are ill-defined and '
18421842
'being set to 0.0 in samples with no true labels.')
1843-
my_assert(w, f, msg, np.array([[1, 0], [0, 0]]), np.array([[1, 0], [1, 0]]),
1843+
my_assert(w, msg, f, np.array([[1, 0], [0, 0]]), np.array([[1, 0], [1, 0]]),
18441844
average='samples')
18451845

18461846
# single score: micro-average
18471847
msg = ('Precision and F-score are ill-defined and '
18481848
'being set to 0.0 due to no predicted samples.')
1849-
my_assert(w, f, msg, np.array([[1, 1], [1, 1]]),
1849+
my_assert(w, msg, f, np.array([[1, 1], [1, 1]]),
18501850
np.array([[0, 0], [0, 0]]), average='micro')
18511851

18521852
msg =('Recall and F-score are ill-defined and '
18531853
'being set to 0.0 due to no true samples.')
1854-
my_assert(w, f, msg, np.array([[0, 0], [0, 0]]),
1854+
my_assert(w, msg, f, np.array([[0, 0], [0, 0]]),
18551855
np.array([[1, 1], [1, 1]]), average='micro')
18561856

18571857
# single positive label
18581858
msg = ('Precision and F-score are ill-defined and '
18591859
'being set to 0.0 due to no predicted samples.')
1860-
my_assert(w, f, msg, [1, 1], [-1, -1], average='macro')
1860+
my_assert(w, msg, f, [1, 1], [-1, -1], average='macro')
18611861

18621862
msg = ('Recall and F-score are ill-defined and '
18631863
'being set to 0.0 due to no true samples.')
1864-
my_assert(w, f, msg, [-1, -1], [1, 1], average='macro')
1864+
my_assert(w, msg, f, [-1, -1], [1, 1], average='macro')
18651865

18661866

18671867
def test__check_clf_targets():

sklearn/utils/testing.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,27 @@ def _assert_greater(a, b, msg=None):
8383

8484
# To remove when we support numpy 1.7
8585
def assert_warns(warning_class, func, *args, **kw):
86+
"""Test that a certain warning occurs.
87+
88+
Parameters
89+
----------
90+
warning_class : the warning class
91+
The class to test for, e.g. UserWarning.
92+
93+
func : callable
94+
Callable object to trigger warnings.
95+
96+
*args : the positional arguments to `func`.
97+
98+
**kw : the keyword arguments to `func`
99+
100+
Returns
101+
-------
102+
103+
result : the return value of `func`
104+
105+
"""
106+
86107
# very important to avoid uncontrolled state propagation
87108
clean_warning_registry()
88109
with warnings.catch_warnings(record=True) as w:
@@ -102,8 +123,34 @@ def assert_warns(warning_class, func, *args, **kw):
102123

103124
return result
104125

105-
def assert_warns_message(warning_class, func, message, *args, **kw):
126+
127+
def assert_warns_message(warning_class, message, func, *args, **kw):
106128
# very important to avoid uncontrolled state propagation
129+
"""Test that a certain warning occurs and with a certain message.
130+
131+
Parameters
132+
----------
133+
warning_class : the warning class
134+
The class to test for, e.g. UserWarning.
135+
136+
message : str | callable
137+
The entire message or a substring to test for. If callable,
138+
it takes a string as argument and will trigger an assertion error
139+
if it returns `False`.
140+
141+
func : callable
142+
Callable object to trigger warnings.
143+
144+
*args : the positional arguments to `func`.
145+
146+
**kw : the keyword arguments to `func`.
147+
148+
Returns
149+
-------
150+
151+
result : the return value of `func`
152+
153+
"""
107154
clean_warning_registry()
108155
with warnings.catch_warnings(record=True) as w:
109156
# Cause all warnings to always be triggered.
@@ -119,8 +166,15 @@ def assert_warns_message(warning_class, func, message, *args, **kw):
119166
raise AssertionError("First warning for %s is not a "
120167
"%s( is %s)"
121168
% (func.__name__, warning_class, w[0]))
122-
msg = str(w[0].message)
123-
if msg != message:
169+
170+
# substring will match, the entire message with typo won't
171+
msg = w[0].message # For Python 3 compatibility
172+
msg = str(msg.args[0] if hasattr(msg, 'args') else msg)
173+
if callable(message): # add support for certain tests
174+
check_in_message = message
175+
else:
176+
check_in_message = lambda msg : message in msg
177+
if not check_in_message(msg):
124178
raise AssertionError("The message received ('%s') for <%s> is "
125179
"not the one you expected ('%s')"
126180
% (msg, func.__name__, message

0 commit comments

Comments
 (0)