Skip to content

Commit 4c61e8b

Browse files
sam-samueller
authored and committed
[MRG+1] avoid integer overflow by using floats for matthews_corrcoef (scikit-learn#9693)
* Fix bug #9622: avoid integer overflow by using floats for matthews_corrcoef
* matthews_corrcoef: cosmetic change requested by jnothman
* Add test_matthews_corrcoef_overflow for bug #9622
* test_matthews_corrcoef_overflow: clean-up and make deterministic
* matthews_corrcoef: pass dtype=np.float64 to sum & trace instead of using astype
* test_matthews_corrcoef_overflow: add simple deterministic tests
1 parent ea41a78 commit 4c61e8b

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

sklearn/metrics/classification.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
167167
2
168168
169169
In the multilabel case with binary label indicators:
170-
170+
171171
>>> accuracy_score(np.array([[0, 1], [1, 1]]), np.ones((2, 2)))
172172
0.5
173173
"""
@@ -528,9 +528,9 @@ def matthews_corrcoef(y_true, y_pred, sample_weight=None):
528528
y_pred = lb.transform(y_pred)
529529

530530
C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
531-
t_sum = C.sum(axis=1)
532-
p_sum = C.sum(axis=0)
533-
n_correct = np.trace(C)
531+
t_sum = C.sum(axis=1, dtype=np.float64)
532+
p_sum = C.sum(axis=0, dtype=np.float64)
533+
n_correct = np.trace(C, dtype=np.float64)
534534
n_samples = p_sum.sum()
535535
cov_ytyp = n_correct * n_samples - np.dot(t_sum, p_sum)
536536
cov_ypyp = n_samples ** 2 - np.dot(p_sum, p_sum)

sklearn/metrics/tests/test_classification.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,41 @@ def test_matthews_corrcoef_multiclass():
483483
assert_almost_equal(mcc, 0.)
484484

485485

486+
def test_matthews_corrcoef_overflow():
    """Check matthews_corrcoef against an overflow-free reference.

    Regression test for
    https://github.com/scikit-learn/scikit-learn/issues/9622

    For several sample sizes the test verifies that perfectly correlated
    binary and multiclass labels score 1.0, and that noisy binary
    predictions match a reference value computed from rates (fractions of
    the sample count) rather than from products of raw counts.
    """
    # Fixed seed keeps the random binary labels deterministic.
    rng = np.random.RandomState(20170906)

    def mcc_safe(y_true, y_pred):
        # Reference binary MCC built only from ratios, so no large
        # integer products are ever formed.
        conf_matrix = confusion_matrix(y_true, y_pred)
        # Rows of the confusion matrix are true labels, columns are
        # predicted labels: [0, 1] is a false positive, [1, 0] a false
        # negative.
        true_pos = conf_matrix[1, 1]
        false_pos = conf_matrix[0, 1]
        false_neg = conf_matrix[1, 0]
        # float so that "/" is true division regardless of whether
        # ``from __future__ import division`` is active in this module.
        n_points = float(len(y_true))
        pos_rate = (true_pos + false_neg) / n_points
        activity = (true_pos + false_pos) / n_points
        mcc_numerator = true_pos / n_points - pos_rate * activity
        mcc_denominator = activity * pos_rate * (1 - activity) * (1 - pos_rate)
        return mcc_numerator / np.sqrt(mcc_denominator)

    def random_ys(n_points):
        # Binary labels: threshold a uniform sample, and threshold a
        # perturbed copy of it to obtain correlated predictions.
        x_true = rng.random_sample(n_points)
        x_pred = x_true + 0.2 * (rng.random_sample(n_points) - 0.5)
        y_true = (x_true > 0.5)
        y_pred = (x_pred > 0.5)
        return y_true, y_pred

    for n_points in [100, 10000, 1000000]:
        arr = np.repeat([0., 1.], n_points)  # binary
        assert_almost_equal(matthews_corrcoef(arr, arr), 1.0)
        arr = np.repeat([0., 1., 2.], n_points)  # multiclass
        assert_almost_equal(matthews_corrcoef(arr, arr), 1.0)

        y_true, y_pred = random_ys(n_points)
        assert_almost_equal(matthews_corrcoef(y_true, y_true), 1.0)
        assert_almost_equal(matthews_corrcoef(y_true, y_pred),
                            mcc_safe(y_true, y_pred))
519+
520+
486521
def test_precision_recall_f1_score_multiclass():
487522
# Test Precision Recall and F1 Score for multiclass classification task
488523
y_true, y_pred, _ = make_prediction(binary=False)

0 commit comments

Comments
 (0)