@@ -1498,9 +1498,22 @@ def test_invariance_string_vs_numbers_labels():
14981498 err_msg = "{0} failed string vs number "
14991499 "invariance test" .format (name ))
15001500
1501- # TODO Currently not supported
1502- for name , metrics in THRESHOLDED_METRICS .items ():
1503- assert_raises (ValueError , metrics , y1_str , y2_str )
1501+ for name , metric in THRESHOLDED_METRICS .items ():
1502+ if name in ("log_loss" , "hinge_loss" ):
1503+ measure_with_number = metric (y1 , y2 )
1504+ measure_with_str = metric (y1_str , y2 )
1505+ assert_array_equal (measure_with_number , measure_with_str ,
1506+ err_msg = "{0} failed string vs number invariance "
1507+ "test" .format (name ))
1508+
1509+ measure_with_strobj = metric (y1_str .astype ('O' ), y2 )
1510+ assert_array_equal (measure_with_number , measure_with_strobj ,
1511+ err_msg = "{0} failed string object vs number "
1512+ "invariance test" .format (name ))
1513+ else :
1514+ # TODO those metrics doesn't support string label yet
1515+ assert_raises (ValueError , metric , y1_str , y2 )
1516+ assert_raises (ValueError , metric , y1_str .astype ('O' ), y2 )
15041517
15051518
15061519@ignore_warnings
0 commit comments