2424from sklearn .utils .testing import assert_less
2525from sklearn .utils .testing import assert_array_almost_equal
2626from sklearn .utils .testing import assert_array_equal
27+ from sklearn .utils .testing import assert_allclose
2728from sklearn .utils .mocking import CheckingClassifier , MockDataFrame
2829
2930from sklearn .model_selection import cross_val_score , ShuffleSplit
@@ -1333,8 +1334,8 @@ def check_cross_val_predict_binary(est, X, y, method):
13331334
13341335 # Check actual outputs for several representations of y
13351336 for tg in [y , y + 1 , y - 2 , y .astype ('str' )]:
1336- assert_array_equal (cross_val_predict (est , X , tg , method = method , cv = cv ),
1337- expected_predictions )
1337+ assert_allclose (cross_val_predict (est , X , tg , method = method , cv = cv ),
1338+ expected_predictions )
13381339
13391340
13401341def check_cross_val_predict_multiclass (est , X , y , method ):
@@ -1358,8 +1359,8 @@ def check_cross_val_predict_multiclass(est, X, y, method):
13581359
13591360 # Check actual outputs for several representations of y
13601361 for tg in [y , y + 1 , y - 2 , y .astype ('str' )]:
1361- assert_array_equal (cross_val_predict (est , X , tg , method = method , cv = cv ),
1362- expected_predictions )
1362+ assert_allclose (cross_val_predict (est , X , tg , method = method , cv = cv ),
1363+ expected_predictions )
13631364
13641365
13651366def check_cross_val_predict_multilabel (est , X , y , method ):
@@ -1406,7 +1407,7 @@ def check_cross_val_predict_multilabel(est, X, y, method):
14061407 cv_predict_output = cross_val_predict (est , X , tg , method = method , cv = cv )
14071408 assert_equal (len (cv_predict_output ), len (expected_preds ))
14081409 for i in range (len (cv_predict_output )):
1409- assert_array_equal (cv_predict_output [i ], expected_preds [i ])
1410+ assert_allclose (cv_predict_output [i ], expected_preds [i ])
14101411
14111412
14121413def check_cross_val_predict_with_method_binary (est ):
0 commit comments