Commit 5cc8032

ENH add normalize option to accuracy_score + FIX bug with 1d array
1 parent 58b96aa commit 5cc8032

File tree

4 files changed: +309 -46 lines changed


doc/modules/model_evaluation.rst

Lines changed: 3 additions & 1 deletion
@@ -72,7 +72,7 @@ Accuracy score
 ---------------
 The :func:`accuracy_score` function computes the
 `accuracy <http://en.wikipedia.org/wiki/Accuracy_and_precision>`_, the fraction
-of correct predictions. In multilabel classification,
+(default) or the number of correct predictions. In multilabel classification,
 the function returns the subset accuracy:
 the entire set of labels for a sample must be entirely correct
 or the sample has an accuracy of zero.
@@ -96,6 +96,8 @@ where :math:`1(x)` is the `indicator function
   >>> y_true = [0, 1, 2, 3]
   >>> accuracy_score(y_true, y_pred)
   0.5
+  >>> accuracy_score(y_true, y_pred, normalize=False)
+  2
 
 In the multilabel case with binary indicator format:
 
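A quick way to sanity-check the documented behaviour (a minimal sketch, assuming only the ``normalize`` option added in this commit): the count returned with ``normalize=False`` is the fraction multiplied by the number of samples.

from sklearn.metrics import accuracy_score

y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]

frac = accuracy_score(y_true, y_pred)                     # 0.5
count = accuracy_score(y_true, y_pred, normalize=False)   # 2
assert count == frac * len(y_true)
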
doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
@@ -65,6 +65,10 @@ Changelog
 - Performance improvements in :class:`isotonic.IsotonicRegression` by
   Nelle Varoquaux.
 
+- :func:`metrics.accuracy_score` has an option ``normalize`` to return
+  the fraction or the number of correctly classified samples
+  by `Arnaud Joly`_.
+
 
 API changes summary
 -------------------

sklearn/metrics/metrics.py

Lines changed: 172 additions & 31 deletions
@@ -24,7 +24,8 @@
 
 from ..externals.six.moves import zip
 from ..preprocessing import LabelBinarizer
-from ..utils import check_arrays, deprecated
+from ..utils import check_arrays
+from ..utils import deprecated
 from ..utils.multiclass import is_label_indicator_matrix
 from ..utils.multiclass import is_multilabel
 from ..utils.multiclass import unique_labels
@@ -33,6 +34,121 @@
 ###############################################################################
 # General utilities
 ###############################################################################
+def _is_1d(x):
+    """Return True if x can be considered as a 1d vector.
+
+    This function allows to distinguish between a 1d vector, e.g. :
+        - ``np.array([1, 2])``
+        - ``np.array([[1, 2]])``
+        - ``np.array([[1], [2]])``
+
+    and 2d matrix, e.g.:
+        - ``np.array([[1, 2], [3, 4]])``
+
+
+    Parameters
+    ----------
+    x : numpy array.
+
+    Return
+    ------
+    is_1d : boolean,
+        Return True if x can be considered as a 1d vector.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.metrics.metrics import _is_1d
+    >>> _is_1d([1, 2, 3])
+    True
+    >>> _is_1d(np.array([1, 2, 3]))
+    True
+    >>> _is_1d([[1, 2, 3]])
+    True
+    >>> _is_1d(np.array([[1, 2, 3]]))
+    True
+    >>> _is_1d([[1], [2], [3]])
+    True
+    >>> _is_1d(np.array([[1], [2], [3]]))
+    True
+    >>> _is_1d([[1, 2], [3, 4]])
+    False
+    >>> _is_1d(np.array([[1, 2], [3, 4]]))
+    False
+
+    See also
+    --------
+    _check_1d_array
+
+    """
+    return np.size(x) == np.max(np.shape(x))
+
+
+def _check_1d_array(y1, y2, ravel=False):
+    """Check that y1 and y2 are vectors of the same shape.
+
+    It converts 1d arrays (y1 and y2) of various shapes to a common shape
+    representation. Note that ``y1`` and ``y2`` should have the same number of
+    elements.
+
+    Parameters
+    ----------
+    y1 : array-like,
+        y1 must be a "vector".
+
+    y2 : array-like
+        y2 must be a "vector".
+
+    ravel : boolean, optional (default=False),
+        If ``ravel`` is set to ``True``, then ``y1`` and ``y2`` are raveled.
+
+    Returns
+    -------
+    y1 : numpy array,
+        If ``ravel`` is set to ``True``, return np.ravel(y1), else
+        return y1.
+
+    y2 : numpy array,
+        Return y2 reshaped to have the shape of y1.
+
+    Examples
+    --------
+    >>> from numpy import array
+    >>> from sklearn.metrics.metrics import _check_1d_array
+    >>> _check_1d_array([1, 2], [[3, 4]])
+    (array([1, 2]), array([3, 4]))
+    >>> _check_1d_array([[1, 2]], [[3], [4]])
+    (array([[1, 2]]), array([[3, 4]]))
+    >>> _check_1d_array([[1], [2]], [[3, 4]])
+    (array([[1],
+           [2]]), array([[3],
+           [4]]))
+    >>> _check_1d_array([[1], [2]], [[3, 4]], ravel=True)
+    (array([1, 2]), array([3, 4]))
+
+    See also
+    --------
+    _is_1d
+
+    """
+    y1 = np.asarray(y1)
+    y2 = np.asarray(y2)
+
+    if not _is_1d(y1):
+        raise ValueError("y1 can't be considered as a vector")
+
+    if not _is_1d(y2):
+        raise ValueError("y2 can't be considered as a vector")
+
+    if ravel:
+        return np.ravel(y1), np.ravel(y2)
+    else:
+        if np.shape(y1) != np.shape(y2):
+            y2 = np.reshape(y2, np.shape(y1))
+
+        return y1, y2
+
+
 def auc(x, y, reorder=False):
     """Compute Area Under the Curve (AUC) using the trapezoidal rule
 
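To make the shape test concrete, here is a standalone sketch using only NumPy (the name ``looks_like_vector`` is illustrative and not part of the commit): an array passes exactly when its total size equals its largest dimension, so flat arrays, row vectors and column vectors are all accepted while genuine 2d matrices are rejected.

import numpy as np

def looks_like_vector(x):
    # Same test as _is_1d above: every element lies along a single axis,
    # so the total number of elements equals the largest dimension.
    return np.size(x) == np.max(np.shape(x))

print(looks_like_vector(np.array([1, 2, 3])))         # True,  shape (3,)
print(looks_like_vector(np.array([[1, 2, 3]])))       # True,  shape (1, 3)
print(looks_like_vector(np.array([[1], [2], [3]])))   # True,  shape (3, 1)
print(looks_like_vector(np.array([[1, 2], [3, 4]])))  # False, shape (2, 2)
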
@@ -47,7 +163,7 @@ def auc(x, y, reorder=False):
     y : array, shape = [n]
         y coordinates.
 
-    reorder : boolean, optional
+    reorder : boolean, optional (default=False)
         If True, assume that the curve is ascending in the case of ties, as for
         an ROC curve. If the curve is non-ascending, the result will be wrong.
 
@@ -299,6 +415,9 @@ def matthews_corrcoef(y_true, y_pred):
     -0.33...
 
     """
+    y_true, y_pred = check_arrays(y_true, y_pred)
+    y_true, y_pred = _check_1d_array(y_true, y_pred, ravel=True)
+
     mcc = np.corrcoef(y_true, y_pred)[0, 1]
     if np.isnan(mcc):
         return 0.
@@ -655,8 +774,8 @@ def zero_one_loss(y_true, y_pred, normalize=True):
     y_pred : array-like or list of labels or label indicator matrix
         Predicted labels, as returned by a classifier.
 
-    normalize : bool, optional
-        If ``False`` (default), return the number of misclassifications.
+    normalize : bool, optional (default=True)
+        If ``False``, return the number of misclassifications.
         Otherwise, return the fraction of misclassifications.
 
     Returns
@@ -696,34 +815,19 @@ def zero_one_loss(y_true, y_pred, normalize=True):
 
     """
     y_true, y_pred = check_arrays(y_true, y_pred, allow_lists=True)
+    score = accuracy_score(y_true, y_pred, normalize=normalize)
 
-    if is_multilabel(y_true):
-        # Handle mix representation
-        if type(y_true) != type(y_pred):
-            labels = unique_labels(y_true, y_pred)
-            lb = LabelBinarizer()
-            lb.fit([labels.tolist()])
-            y_true = lb.transform(y_true)
-            y_pred = lb.transform(y_pred)
+    if normalize:
+        return 1 - score
+    else:
+        if hasattr(y_true, "shape"):
+            n_samples = (np.max(y_true.shape) if _is_1d(y_true)
+                         else y_true.shape[0])
 
-        if is_label_indicator_matrix(y_true):
-            loss = (y_pred != y_true).sum(axis=1) > 0
         else:
-            # numpy 1.3 : it is required to perform a unique before setxor1d
-            # to get unique label in numpy 1.3.
-            # This is needed in order to handle redundant labels.
-            # FIXME : check if this can be simplified when 1.3 is removed
-            loss = np.array([np.size(np.setxor1d(np.unique(pred),
-                                                 np.unique(true))) > 0
-                             for pred, true in zip(y_pred, y_true)])
-    else:
-        y_true, y_pred = check_arrays(y_true, y_pred)
-        loss = y_true != y_pred
+            n_samples = len(y_true)
 
-    if normalize:
-        return np.mean(loss)
-    else:
-        return np.sum(loss)
+        return n_samples - score
 
 
 @deprecated("Function 'zero_one' has been renamed to "
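The rewrite leans on a simple complement identity, sketched below with plain NumPy (an illustration, not the scikit-learn implementation): the zero-one loss is one minus the accuracy when normalized, and the number of samples minus the number of correct predictions otherwise.

import numpy as np

y_true = np.array([0, 1, 2, 3])
y_pred = np.array([0, 2, 1, 3])

n_samples = y_true.shape[0]                 # 4
n_correct = np.sum(y_true == y_pred)        # 2

loss_fraction = 1 - n_correct / n_samples   # 0.5, what zero_one_loss returns
loss_count = n_samples - n_correct          # 2, what zero_one_loss(..., normalize=False) returns

assert loss_count == loss_fraction * n_samples
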
@@ -743,7 +847,7 @@ def zero_one(y_true, y_pred, normalize=False):
 
     y_pred : array-like
 
-    normalize : bool, optional
+    normalize : bool, optional (default=False)
         If ``False`` (default), return the number of misclassifications.
         Otherwise, return the fraction of misclassifications.
 
@@ -771,7 +875,7 @@ def zero_one(y_true, y_pred, normalize=False):
 ###############################################################################
 # Multiclass score functions
 ###############################################################################
-def accuracy_score(y_true, y_pred):
+def accuracy_score(y_true, y_pred, normalize=True):
     """Accuracy classification score.
 
     Parameters
@@ -782,6 +886,10 @@ def accuracy_score(y_true, y_pred):
     y_pred : array-like or list of labels or label indicator matrix
         Predicted labels, as returned by a classifier.
 
+    normalize : bool, optional (default=True)
+        If ``False``, return the number of correctly classified samples.
+        Otherwise, return the fraction of correctly classified samples.
+
     Returns
     -------
     score : float
@@ -806,6 +914,8 @@ def accuracy_score(y_true, y_pred):
     >>> y_true = [0, 1, 2, 3]
     >>> accuracy_score(y_true, y_pred)
     0.5
+    >>> accuracy_score(y_true, y_pred, normalize=False)
+    2
 
     In the multilabel case with binary indicator format:
 
@@ -841,9 +951,15 @@ def accuracy_score(y_true, y_pred):
                              for pred, true in zip(y_pred, y_true)])
     else:
         y_true, y_pred = check_arrays(y_true, y_pred)
+
+        # Handle mix shape
+        y_true, y_pred = _check_1d_array(y_true, y_pred, ravel=True)
         score = y_true == y_pred
 
-    return np.mean(score)
+    if normalize:
+        return np.mean(score)
+    else:
+        return np.sum(score)
 
 
 def f1_score(y_true, y_pred, labels=None, pos_label=1, average='weighted'):
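Stripped of the multilabel branch, the single-label logic above reduces to a comparison vector that is either averaged or summed. A simplified sketch (not the full scikit-learn function; it skips input validation) also shows why the ravel step matters for column-vector inputs:

import numpy as np

def accuracy_sketch(y_true, y_pred, normalize=True):
    # Flatten both inputs so flat lists, row vectors and column vectors
    # compare element by element, then average or count the matches.
    y_true = np.ravel(np.asarray(y_true))
    y_pred = np.ravel(np.asarray(y_pred))
    matches = (y_true == y_pred)
    return np.mean(matches) if normalize else np.sum(matches)

print(accuracy_sketch([0, 1, 2, 3], [[0], [2], [1], [3]]))                   # 0.5
print(accuracy_sketch([0, 1, 2, 3], [[0], [2], [1], [3]], normalize=False))  # 2
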
@@ -1146,6 +1262,8 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         raise ValueError("beta should be >0 in the F-beta score")
 
     y_true, y_pred = check_arrays(y_true, y_pred)
+    y_true, y_pred = _check_1d_array(y_true, y_pred)
+
     if labels is None:
         labels = unique_labels(y_true, y_pred)
     else:
@@ -1589,6 +1707,9 @@ def hamming_loss(y_true, y_pred, classes=None):
         return np.mean(loss) / np.size(classes)
 
     else:
+        y_true, y_pred = check_arrays(y_true, y_pred)
+        y_true, y_pred = _check_1d_array(y_true, y_pred)
+
         return sp_hamming(y_true, y_pred)
 
 
@@ -1625,6 +1746,11 @@ def mean_absolute_error(y_true, y_pred):
 
     """
    y_true, y_pred = check_arrays(y_true, y_pred)
+
+    # Handle mix 1d representation
+    if _is_1d(y_true):
+        y_true, y_pred = _check_1d_array(y_true, y_pred)
+
     return np.mean(np.abs(y_pred - y_true))
 
 
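The guard added here (and in the hunks that follow) prevents a silent broadcasting error. A pure NumPy sketch with made-up values shows what goes wrong when a column vector meets a flat array:

import numpy as np

y_true = np.array([1., 2., 3., 4.])                  # shape (4,)
y_pred_column = np.array([[2.], [2.], [3.], [5.]])   # shape (4, 1)

# Without aligning shapes, the subtraction broadcasts to a (4, 4) matrix,
# so the mean runs over 16 differences instead of 4.
broken = np.mean(np.abs(y_pred_column - y_true))            # 1.375 (wrong)

# Flattening the prediction to match y_true restores the expected MAE.
fixed = np.mean(np.abs(np.ravel(y_pred_column) - y_true))   # 0.5

print(broken, fixed)
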
@@ -1658,6 +1784,11 @@ def mean_squared_error(y_true, y_pred):
 
     """
     y_true, y_pred = check_arrays(y_true, y_pred)
+
+    # Handle mix 1d representation
+    if _is_1d(y_true):
+        y_true, y_pred = _check_1d_array(y_true, y_pred)
+
     return np.mean((y_pred - y_true) ** 2)
 
 
@@ -1696,6 +1827,11 @@ def explained_variance_score(y_true, y_pred):
 
     """
     y_true, y_pred = check_arrays(y_true, y_pred)
+
+    # Handle mix 1d representation
+    if _is_1d(y_true):
+        y_true, y_pred = _check_1d_array(y_true, y_pred)
+
     numerator = np.var(y_true - y_pred)
     denominator = np.var(y_true)
     if denominator == 0.0:
@@ -1752,6 +1888,11 @@ def r2_score(y_true, y_pred):
 
     """
     y_true, y_pred = check_arrays(y_true, y_pred)
+
+    # Handle mix 1d representation
+    if _is_1d(y_true):
+        y_true, y_pred = _check_1d_array(y_true, y_pred, ravel=True)
+
     if len(y_true) == 1:
         raise ValueError("r2_score can only be computed given more than one"
                          " sample.")

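As a closing illustration of why ``r2_score`` (like ``matthews_corrcoef``) ravels rather than merely reshapes, here is a pure NumPy sketch with illustrative values: once a row vector and a column vector are both flattened to shape ``(n_samples,)``, the usual R^2 arithmetic goes through unchanged.

import numpy as np

row = np.array([[3., -0.5, 2., 7.]])         # y_true as a (1, 4) row vector
col = np.array([[2.5], [0.0], [2.], [8.]])   # y_pred as a (4, 1) column vector

# Flatten both to shape (4,) before doing any arithmetic on them.
y_true, y_pred = np.ravel(row), np.ravel(col)

ss_res = np.sum((y_true - y_pred) ** 2)
ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
print(1 - ss_res / ss_tot)   # ~0.9486, the usual R^2 for this toy example
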