Skip to content

Commit c9714b5

Browse files
committed
FIX: Changed str vs float invariance test
1 parent 86b7051 commit c9714b5

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

sklearn/metrics/tests/test_metrics.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,9 +1498,22 @@ def test_invariance_string_vs_numbers_labels():
14981498
err_msg="{0} failed string vs number "
14991499
"invariance test".format(name))
15001500

1501-
# TODO Currently not supported
1502-
for name, metrics in THRESHOLDED_METRICS.items():
1503-
assert_raises(ValueError, metrics, y1_str, y2_str)
1501+
for name, metric in THRESHOLDED_METRICS.items():
1502+
if name in ("log_loss", "hinge_loss"):
1503+
measure_with_number = metric(y1, y2)
1504+
measure_with_str = metric(y1_str, y2)
1505+
assert_array_equal(measure_with_number, measure_with_str,
1506+
err_msg="{0} failed string vs number invariance "
1507+
"test".format(name))
1508+
1509+
measure_with_strobj = metric(y1_str.astype('O'), y2)
1510+
assert_array_equal(measure_with_number, measure_with_strobj,
1511+
err_msg="{0} failed string object vs number "
1512+
"invariance test".format(name))
1513+
else:
1514+
# TODO those metrics doesn't support string label yet
1515+
assert_raises(ValueError, metric, y1_str, y2)
1516+
assert_raises(ValueError, metric, y1_str.astype('O'), y2)
15041517

15051518

15061519
@ignore_warnings

0 commit comments

Comments
 (0)