Commit 019fc9b

ENH labels parameter in P/R/F may extend or reduce label set

1 parent d1d95e8 commit 019fc9b

File tree: 5 files changed, +183 −52 lines changed

    doc/modules/model_evaluation.rst
    doc/whats_new.rst
    sklearn/metrics/classification.py
    sklearn/metrics/tests/test_classification.py
    sklearn/metrics/tests/test_common.py

doc/modules/model_evaluation.rst

Lines changed: 16 additions & 3 deletions
@@ -155,7 +155,7 @@ take several parameters:
   certainties (``needs_threshold=True``). The default value is
   False.
 
-* any additional parameters, such as ``beta`` in an :func:`f1_score`.
+* any additional parameters, such as ``beta`` or ``labels`` in :func:`f1_score`.
 
 Here is an example of building custom scorers, and of using the
 ``greater_is_better`` parameter::

@@ -657,8 +657,9 @@ specified by the ``average`` argument to the
 :func:`fbeta_score`, :func:`precision_recall_fscore_support`,
 :func:`precision_score` and :func:`recall_score` functions, as described
 :ref:`above <average>`. Note that for "micro"-averaging in a multiclass setting
-will produce equal precision, recall and :math:`F`, while "weighted" averaging
-may produce an F-score that is not between precision and recall.
+with all labels included will produce equal precision, recall and :math:`F`,
+while "weighted" averaging may produce an F-score that is not between
+precision and recall.
 
 To make this more explicit, consider the following notation:
 

@@ -709,6 +710,18 @@ Then the metrics are defined as:
   ... # doctest: +ELLIPSIS
   (array([ 0.66..., 0. , 0. ]), array([ 1., 0., 0.]), array([ 0.71..., 0. , 0. ]), array([2, 2, 2]...))
 
+For multiclass classification with a "negative class", it is possible to exclude some labels:
+
+  >>> metrics.recall_score(y_true, y_pred, labels=[1, 2], average='micro')
+  ... # excluding 0, no labels were correctly recalled
+  0.0
+
+Similarly, labels not present in the data sample may be accounted for in macro-averaging.
+
+  >>> metrics.precision_score(y_true, y_pred, labels=[0, 1, 2, 3], average='macro')
+  ... # doctest: +ELLIPSIS
+  0.166...
+
 .. _hinge_loss:
 
 Hinge loss
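A minimal sketch of the two behaviours documented in the hunk above, using invented toy targets (these arrays are not the ones from the doctest, and the exact warnings emitted for labels with no true samples may vary by version):

    from sklearn import metrics

    y_true = [0, 0, 0, 1, 2]   # class 0 acts as a majority "negative" class
    y_pred = [0, 0, 1, 1, 0]

    # Reduce the label set: average only over classes 1 and 2.
    # recall(1) = 1/1, recall(2) = 0/1  ->  macro average 0.5
    print(metrics.recall_score(y_true, y_pred, labels=[1, 2], average='macro'))

    # Extend the label set: class 3 never occurs, so it contributes a
    # zero component -> (2/3 + 1.0 + 0.0 + 0.0) / 4 = 0.4166...
    print(metrics.recall_score(y_true, y_pred, labels=[0, 1, 2, 3], average='macro'))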

doc/whats_new.rst

Lines changed: 8 additions & 0 deletions
@@ -64,6 +64,14 @@ Enhancements
    - Added backlinks from the API reference pages to the user guide. By
      `Andreas Müller`_.
 
+   - The ``labels`` parameter to :func:`sklearn.metrics.f1_score`,
+     :func:`sklearn.metrics.fbeta_score`,
+     :func:`sklearn.metrics.recall_score` and
+     :func:`sklearn.metrics.precision_score` has been extended.
+     It is now possible to ignore one or more labels, such as where
+     a multiclass problem has a majority class to ignore. By `Joel Nothman`_.
+
+
 Bug fixes
 .........

sklearn/metrics/classification.py

Lines changed: 87 additions & 45 deletions
@@ -517,8 +517,14 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
     y_pred : 1d array-like, or label indicator array / sparse matrix
         Estimated targets as returned by a classifier.
 
-    labels : array
-        Integer array of labels.
+    labels : list, optional
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average. For multilabel targets,
+        labels are column indices. By default, all labels in ``y_true`` and
+        ``y_pred`` are used in sorted order.
 
     pos_label : str or int, 1 by default
         The class to report if ``average='binary'``. Until version 0.18 it is
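To make the "order if ``average is None``" clause above concrete, here is a small sketch with invented targets (not part of the diff); per-class scores come back in the order the labels are listed, not in sorted order:

    from sklearn.metrics import f1_score

    y_true = [0, 1, 2, 2]
    y_pred = [0, 2, 2, 2]

    # Class 2 is reported first, then class 0, then class 1:
    # approximately [0.8, 1.0, 0.0]. Class 1 is never predicted, so its
    # F-score is set to 0 (a warning about the undefined value may be emitted).
    print(f1_score(y_true, y_pred, labels=[2, 0, 1], average=None))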
@@ -614,8 +620,14 @@ def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1,
     beta: float
         Weight of precision in harmonic mean.
 
-    labels : array
-        Integer array of labels.
+    labels : list, optional
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average. For multilabel targets,
+        labels are column indices. By default, all labels in ``y_true`` and
+        ``y_pred`` are used in sorted order.
 
     pos_label : str or int, 1 by default
         The class to report if ``average='binary'``. Until version 0.18 it is
@@ -784,8 +796,14 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
     beta : float, 1.0 by default
         The strength of recall versus precision in the F-score.
 
-    labels : array
-        Integer array of labels.
+    labels : list, optional
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average. For multilabel targets,
+        labels are column indices. By default, all labels in ``y_true`` and
+        ``y_pred`` are used in sorted order.
 
     pos_label : str or int, 1 by default
         The class to report if ``average='binary'``. Until version 0.18 it is
@@ -879,6 +897,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         raise ValueError("beta should be >0 in the F-beta score")
 
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    present_labels = unique_labels(y_true, y_pred)
 
     if average == 'binary' and (y_type != 'binary' or pos_label is None):
         warnings.warn('The default `weighted` averaging is deprecated, '
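For readers who have not met the helper used on the new line, ``unique_labels`` (from ``sklearn.utils.multiclass``) pools the labels found in all of its arguments and returns them as a sorted array; a tiny sketch with made-up inputs:

    from sklearn.utils.multiclass import unique_labels

    print(unique_labels([3, 1, 2], [1, 1, 4]))   # array([1, 2, 3, 4])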
@@ -891,17 +910,49 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
                       % str(average_options), DeprecationWarning, stacklevel=2)
         average = 'weighted'
 
-    label_order = labels  # save this for later
+    if y_type == 'binary' and pos_label is not None and average is not None:
+        if average != 'binary':
+            warnings.warn('From version 0.18, binary input will not be '
+                          'handled specially when using averaged '
+                          'precision/recall/F-score. '
+                          'Please use average=\'binary\' to report only the '
+                          'positive class performance.', DeprecationWarning)
+        if labels is None or len(labels) <= 2:
+            if pos_label not in present_labels:
+                if len(present_labels) < 2:
+                    # Only negative labels
+                    return (0., 0., 0., 0)
+                else:
+                    raise ValueError("pos_label=%r is not a valid label: %r" %
+                                     (pos_label, present_labels))
+            labels = [pos_label]
     if labels is None:
-        labels = unique_labels(y_true, y_pred)
+        labels = present_labels
+        n_labels = None
     else:
-        labels = np.sort(labels)
+        n_labels = len(labels)
+        labels = np.hstack([labels, np.setdiff1d(present_labels, labels,
+                                                 assume_unique=True)])
 
     ### Calculate tp_sum, pred_sum, true_sum ###
 
     if y_type.startswith('multilabel'):
         sum_axis = 1 if average == 'samples' else 0
 
+        # All labels are index integers for multilabel.
+        # Select labels:
+        if not np.all(labels == present_labels):
+            if np.max(labels) > np.max(present_labels):
+                raise ValueError('All labels must be in [0, n labels). '
+                                 'Got %d > %d' %
+                                 (np.max(labels), np.max(present_labels)))
+            if np.min(labels) < 0:
+                raise ValueError('All labels must be in [0, n labels). '
+                                 'Got %d < 0' % np.min(labels))
+
+        y_true = y_true[:, labels[:n_labels]]
+        y_pred = y_pred[:, labels[:n_labels]]
+
         # calculate weighted counts
         true_and_pred = y_true.multiply(y_pred)
         tp_sum = count_nonzero(true_and_pred, axis=sum_axis,
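The ``hstack``/``setdiff1d`` step above is the core bookkeeping trick: the requested labels are placed first and every other label that actually occurs in the data is appended, while ``n_labels`` remembers how many leading entries were requested. The full set is kept so that the encoding and counting below can handle every value present in ``y_true``/``y_pred``, even though only the leading slice is reported. A standalone NumPy sketch with invented values (not the function itself):

    import numpy as np

    labels = np.array([1, 3])                # user-requested subset
    present_labels = np.array([1, 2, 3, 4])  # what unique_labels() found in the data

    n_labels = len(labels)
    labels = np.hstack([labels, np.setdiff1d(present_labels, labels,
                                             assume_unique=True)])
    print(labels)             # [1 3 2 4] -> requested labels first, the rest appended
    print(labels[:n_labels])  # [1 3]     -> the slice later used to select results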
@@ -916,11 +967,11 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
                              "not meaningful outside multilabel "
                              "classification. See the accuracy_score instead.")
     else:
-        lb = LabelEncoder()
-        lb.fit(labels)
-        y_true = lb.transform(y_true)
-        y_pred = lb.transform(y_pred)
-        labels = lb.classes_
+        le = LabelEncoder()
+        le.fit(labels)
+        y_true = le.transform(y_true)
+        y_pred = le.transform(y_pred)
+        sorted_labels = le.classes_
 
         # labels are now from 0 to len(labels) - 1 -> use bincount
         tp = y_true == y_pred
@@ -943,28 +994,13 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         true_sum = bincount(y_true, weights=sample_weight,
                             minlength=len(labels))
 
-    ### Select labels to keep ###
+        # Retain only selected labels
+        indices = np.searchsorted(sorted_labels, labels[:n_labels])
+        tp_sum = tp_sum[indices]
+        true_sum = true_sum[indices]
+        pred_sum = pred_sum[indices]
 
-    if y_type == 'binary' and average is not None and pos_label is not None:
-        if average != 'binary':
-            warnings.warn('From version 0.18, binary input will not be '
-                          'handled specially when using averaged '
-                          'precision/recall/F-score. '
-                          'Please use average=\'binary\' to report only the '
-                          'positive class performance.', DeprecationWarning)
-        if pos_label not in labels:
-            if len(labels) == 1:
-                # Only negative labels
-                return (0., 0., 0., 0)
-            else:
-                raise ValueError("pos_label=%r is not a valid label: %r" %
-                                 (pos_label, labels))
-        pos_label_idx = labels == pos_label
-        tp_sum = tp_sum[pos_label_idx]
-        pred_sum = pred_sum[pos_label_idx]
-        true_sum = true_sum[pos_label_idx]
-
-    elif average == 'micro':
+    if average == 'micro':
         tp_sum = np.array([tp_sum.sum()])
         pred_sum = np.array([pred_sum.sum()])
         true_sum = np.array([true_sum.sum()])
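After ``bincount`` the per-class counts are in sorted label order (``le.classes_``), so ``searchsorted`` maps the leading ``labels[:n_labels]`` entries back to their positions and only those rows are kept. A sketch continuing the invented values from the previous snippet:

    import numpy as np

    sorted_labels = np.array([1, 2, 3, 4])   # le.classes_ after fitting on all labels
    tp_sum = np.array([5, 1, 2, 0])          # made-up per-class true-positive counts
    labels = np.array([1, 3, 2, 4])          # requested labels first (see above)
    n_labels = 2

    indices = np.searchsorted(sorted_labels, labels[:n_labels])
    print(indices)           # [0 2] -> positions of the requested labels 1 and 3
    print(tp_sum[indices])   # [5 2] -> counts restricted to the requested labels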
@@ -1004,12 +1040,6 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         recall = np.average(recall, weights=weights)
         f_score = np.average(f_score, weights=weights)
         true_sum = None  # return no support
-    elif label_order is not None:
-        indices = np.searchsorted(labels, label_order)
-        precision = precision[indices]
-        recall = recall[indices]
-        f_score = f_score[indices]
-        true_sum = true_sum[indices]
 
     return precision, recall, f_score, true_sum
 
@@ -1035,8 +1065,14 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1,
     y_pred : 1d array-like, or label indicator array / sparse matrix
         Estimated targets as returned by a classifier.
 
-    labels : array
-        Integer array of labels.
+    labels : list, optional
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average. For multilabel targets,
+        labels are column indices. By default, all labels in ``y_true`` and
+        ``y_pred`` are used in sorted order.
 
     pos_label : str or int, 1 by default
         The class to report if ``average='binary'``. Until version 0.18 it is
@@ -1128,8 +1164,14 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
     y_pred : 1d array-like, or label indicator array / sparse matrix
         Estimated targets as returned by a classifier.
 
-    labels : array
-        Integer array of labels.
+    labels : list, optional
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average. For multilabel targets,
+        labels are column indices. By default, all labels in ``y_true`` and
+        ``y_pred`` are used in sorted order.
 
     pos_label : str or int, 1 by default
         The class to report if ``average='binary'``. Until version 0.18 it is

sklearn/metrics/tests/test_classification.py

Lines changed: 69 additions & 1 deletion
@@ -11,6 +11,7 @@
 
 from sklearn.datasets import make_multilabel_classification
 from sklearn.preprocessing import LabelBinarizer, MultiLabelBinarizer
+from sklearn.preprocessing import label_binarize
 from sklearn.utils.fixes import np_version
 from sklearn.utils.validation import check_random_state
 
@@ -173,6 +174,73 @@ def test_precision_recall_f_binary_single_class():
     assert_equal(0., f1_score([-1, -1], [-1, -1]))
 
 
+@ignore_warnings
+def test_precision_recall_f_extra_labels():
+    """Test handling of explicit additional (not in input) labels to PRF
+    """
+    y_true = [1, 3, 3, 2]
+    y_pred = [1, 1, 3, 2]
+    y_true_bin = label_binarize(y_true, classes=np.arange(5))
+    y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
+    data = [(y_true, y_pred),
+            (y_true_bin, y_pred_bin)]
+
+    for i, (y_true, y_pred) in enumerate(data):
+        # No average: zeros in array
+        actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
+                              average=None)
+        assert_array_almost_equal([0., 1., 1., .5, 0.], actual)
+
+        # Macro average is changed
+        actual = recall_score(y_true, y_pred, labels=[0, 1, 2, 3, 4],
+                              average='macro')
+        assert_array_almost_equal(np.mean([0., 1., 1., .5, 0.]), actual)
+
+        # No effect otherwise
+        for average in ['micro', 'weighted', 'samples']:
+            if average == 'samples' and i == 0:
+                continue
+            assert_almost_equal(recall_score(y_true, y_pred,
+                                             labels=[0, 1, 2, 3, 4],
+                                             average=average),
+                                recall_score(y_true, y_pred, labels=None,
+                                             average=average))
+
+    # Error when introducing invalid label in multilabel case
+    # (although it would only affect performance if average='macro'/None)
+    for average in [None, 'macro', 'micro', 'samples']:
+        assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
+                      labels=np.arange(6), average=average)
+        assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
+                      labels=np.arange(-1, 4), average=average)
+
+
+@ignore_warnings
+def test_precision_recall_f_ignored_labels():
+    """Test a subset of labels may be requested for PRF"""
+    y_true = [1, 1, 2, 3]
+    y_pred = [1, 3, 3, 3]
+    y_true_bin = label_binarize(y_true, classes=np.arange(5))
+    y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
+    data = [(y_true, y_pred),
+            (y_true_bin, y_pred_bin)]
+
+    for i, (y_true, y_pred) in enumerate(data):
+        recall_13 = partial(recall_score, y_true, y_pred, labels=[1, 3])
+        recall_all = partial(recall_score, y_true, y_pred, labels=None)
+
+        assert_array_almost_equal([.5, 1.], recall_13(average=None))
+        assert_almost_equal((.5 + 1.) / 2, recall_13(average='macro'))
+        assert_almost_equal((.5 * 2 + 1. * 1) / 3,
+                            recall_13(average='weighted'))
+        assert_almost_equal(2. / 3, recall_13(average='micro'))
+
+        # ensure the above were meaningful tests:
+        for average in ['macro', 'weighted', 'micro']:
+            assert_not_equal(recall_13(average=average),
+                             recall_all(average=average))
+
+
 def test_average_precision_score_score_non_binary_class():
     # Test that average_precision_score function returns an error when trying
     # to compute average_precision_score for multiclass task.
@@ -315,7 +383,7 @@ def test_precision_refcall_f1_score_multilabel_unordered_labels():
     y_pred = np.array([[0, 0, 1, 1]])
     for average in ['samples', 'micro', 'macro', 'weighted', None]:
         p, r, f, s = precision_recall_fscore_support(
-            y_true, y_pred, labels=[4, 1, 2, 3], warn_for=[], average=average)
+            y_true, y_pred, labels=[3, 0, 1, 2], warn_for=[], average=average)
         assert_array_equal(p, 0)
         assert_array_equal(r, 0)
         assert_array_equal(f, 0)

sklearn/metrics/tests/test_common.py

Lines changed: 3 additions & 3 deletions
@@ -1085,9 +1085,9 @@ def test_no_averaging_labels():
     # in multi-class and multi-label cases
     y_true_multilabel = np.array([[1, 1, 0, 0], [1, 1, 0, 0]])
     y_pred_multilabel = np.array([[0, 0, 1, 1], [0, 1, 1, 0]])
-    y_true_multiclass = np.array([1, 2, 3])
-    y_pred_multiclass = np.array([1, 3, 4])
-    labels = np.array([4, 1, 2, 3])
+    y_true_multiclass = np.array([0, 1, 2])
+    y_pred_multiclass = np.array([0, 2, 3])
+    labels = np.array([3, 0, 1, 2])
     _, inverse_labels = np.unique(labels, return_inverse=True)
 
     for name in METRICS_WITH_AVERAGING:
