@@ -517,8 +517,14 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
     y_pred : 1d array-like, or label indicator array / sparse matrix
         Estimated targets as returned by a classifier.
 
-    labels : array
-        Integer array of labels.
+    labels : list, optional
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average. For multilabel targets,
+        labels are column indices. By default, all labels in ``y_true`` and
+        ``y_pred`` are used in sorted order.
 
     pos_label : str or int, 1 by default
         The class to report if ``average='binary'``. Until version 0.18 it is
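
A minimal sketch of the documented behaviour, with made-up data and assuming a
scikit-learn build that includes this change (``f1_score`` accepting
``labels``): restricting ``labels`` lets a macro average ignore a majority
negative class.

    from sklearn.metrics import f1_score

    y_true = [0, 0, 0, 0, 1, 2, 1, 2]
    y_pred = [0, 0, 1, 0, 1, 2, 2, 2]

    # Macro average over every label, including the majority class 0.
    f1_score(y_true, y_pred, average='macro')
    # Macro average restricted to classes 1 and 2, ignoring class 0; a class
    # listed in ``labels`` but absent from the data would contribute 0.
    f1_score(y_true, y_pred, labels=[1, 2], average='macro')
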
@@ -614,8 +620,14 @@ def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1,
     beta: float
         Weight of precision in harmonic mean.
 
-    labels : array
-        Integer array of labels.
+    labels : list, optional
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average. For multilabel targets,
+        labels are column indices. By default, all labels in ``y_true`` and
+        ``y_pred`` are used in sorted order.
 
     pos_label : str or int, 1 by default
         The class to report if ``average='binary'``. Until version 0.18 it is
@@ -784,8 +796,14 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
     beta : float, 1.0 by default
         The strength of recall versus precision in the F-score.
 
-    labels : array
-        Integer array of labels.
+    labels : list, optional
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average. For multilabel targets,
+        labels are column indices. By default, all labels in ``y_true`` and
+        ``y_pred`` are used in sorted order.
 
     pos_label : str or int, 1 by default
         The class to report if ``average='binary'``. Until version 0.18 it is
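
When ``average is None``, ``labels`` also fixes the order of the per-class
results. A small sketch with invented string labels, again assuming this
change is present:

    from sklearn.metrics import precision_recall_fscore_support

    y_true = ['cat', 'dog', 'pig', 'cat', 'dog', 'pig']
    y_pred = ['cat', 'pig', 'dog', 'cat', 'cat', 'dog']

    p, r, f, s = precision_recall_fscore_support(
        y_true, y_pred, labels=['pig', 'dog', 'cat'], average=None)
    # p, r, f and s are reported for 'pig', 'dog', 'cat', in that order.
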
@@ -879,6 +897,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         raise ValueError("beta should be >0 in the F-beta score")
 
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    present_labels = unique_labels(y_true, y_pred)
 
     if average == 'binary' and (y_type != 'binary' or pos_label is None):
         warnings.warn('The default `weighted` averaging is deprecated, '
@@ -891,17 +910,49 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
                       % str(average_options), DeprecationWarning, stacklevel=2)
         average = 'weighted'
 
-    label_order = labels  # save this for later
+    if y_type == 'binary' and pos_label is not None and average is not None:
+        if average != 'binary':
+            warnings.warn('From version 0.18, binary input will not be '
+                          'handled specially when using averaged '
+                          'precision/recall/F-score. '
+                          'Please use average=\'binary\' to report only the '
+                          'positive class performance.', DeprecationWarning)
+        if labels is None or len(labels) <= 2:
+            if pos_label not in present_labels:
+                if len(present_labels) < 2:
+                    # Only negative labels
+                    return (0., 0., 0., 0)
+                else:
+                    raise ValueError("pos_label=%r is not a valid label: %r" %
+                                     (pos_label, present_labels))
+            labels = [pos_label]
     if labels is None:
-        labels = unique_labels(y_true, y_pred)
+        labels = present_labels
+        n_labels = None
     else:
-        labels = np.sort(labels)
+        n_labels = len(labels)
+        labels = np.hstack([labels, np.setdiff1d(present_labels, labels,
+                                                 assume_unique=True)])
 
     ### Calculate tp_sum, pred_sum, true_sum ###
 
     if y_type.startswith('multilabel'):
         sum_axis = 1 if average == 'samples' else 0
 
+        # All labels are index integers for multilabel.
+        # Select labels:
+        if not np.all(labels == present_labels):
+            if np.max(labels) > np.max(present_labels):
+                raise ValueError('All labels must be in [0, n labels). '
+                                 'Got %d > %d' %
+                                 (np.max(labels), np.max(present_labels)))
+            if np.min(labels) < 0:
+                raise ValueError('All labels must be in [0, n labels). '
+                                 'Got %d < 0' % np.min(labels))
+
+            y_true = y_true[:, labels[:n_labels]]
+            y_pred = y_pred[:, labels[:n_labels]]
+
         # calculate weighted counts
         true_and_pred = y_true.multiply(y_pred)
         tp_sum = count_nonzero(true_and_pred, axis=sum_axis,
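
The label bookkeeping introduced above can be followed with plain NumPy
(values invented): the requested labels are kept first, every other label
present in the data is appended, and ``n_labels`` records how many entries
will actually be reported at the end.

    import numpy as np

    present_labels = np.array([0, 1, 2, 3])   # what unique_labels() would find
    labels = [2, 0]                           # user-requested subset, in order

    n_labels = len(labels)
    labels = np.hstack([labels, np.setdiff1d(present_labels, labels,
                                             assume_unique=True)])
    labels             # array([2, 0, 1, 3]) -- requested first, rest appended
    labels[:n_labels]  # array([2, 0])       -- the part that will be reported
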
@@ -916,11 +967,11 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
916967 "not meaningful outside multilabel "
917968 "classification. See the accuracy_score instead." )
918969 else :
919- lb = LabelEncoder ()
920- lb .fit (labels )
921- y_true = lb .transform (y_true )
922- y_pred = lb .transform (y_pred )
923- labels = lb .classes_
970+ le = LabelEncoder ()
971+ le .fit (labels )
972+ y_true = le .transform (y_true )
973+ y_pred = le .transform (y_pred )
974+ sorted_labels = le .classes_
924975
925976 # labels are now from 0 to len(labels) - 1 -> use bincount
926977 tp = y_true == y_pred
@@ -943,28 +994,13 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
             true_sum = bincount(y_true, weights=sample_weight,
                                 minlength=len(labels))
 
-        ### Select labels to keep ###
+        # Retain only selected labels
+        indices = np.searchsorted(sorted_labels, labels[:n_labels])
+        tp_sum = tp_sum[indices]
+        true_sum = true_sum[indices]
+        pred_sum = pred_sum[indices]
 
-    if y_type == 'binary' and average is not None and pos_label is not None:
-        if average != 'binary':
-            warnings.warn('From version 0.18, binary input will not be '
-                          'handled specially when using averaged '
-                          'precision/recall/F-score. '
-                          'Please use average=\'binary\' to report only the '
-                          'positive class performance.', DeprecationWarning)
-        if pos_label not in labels:
-            if len(labels) == 1:
-                # Only negative labels
-                return (0., 0., 0., 0)
-            else:
-                raise ValueError("pos_label=%r is not a valid label: %r" %
-                                 (pos_label, labels))
-        pos_label_idx = labels == pos_label
-        tp_sum = tp_sum[pos_label_idx]
-        pred_sum = pred_sum[pos_label_idx]
-        true_sum = true_sum[pos_label_idx]
-
-    elif average == 'micro':
+    if average == 'micro':
         tp_sum = np.array([tp_sum.sum()])
         pred_sum = np.array([pred_sum.sum()])
         true_sum = np.array([true_sum.sum()])
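
The "retain only selected labels" step works because, after the LabelEncoder
pass, the per-class sums are ordered by ``sorted_labels``; ``searchsorted``
then maps the requested labels back to their positions. A NumPy sketch with
invented counts:

    import numpy as np

    sorted_labels = np.array([0, 1, 2, 3])   # le.classes_
    tp_sum = np.array([5, 2, 7, 1])          # per-class counts, sorted order
    labels = np.array([2, 0, 1, 3])          # requested-first, as built above
    n_labels = 2                             # only [2, 0] were asked for

    indices = np.searchsorted(sorted_labels, labels[:n_labels])
    indices           # array([2, 0])
    tp_sum[indices]   # array([7, 5]) -- counts for labels 2 and 0, in order
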
@@ -1004,12 +1040,6 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         recall = np.average(recall, weights=weights)
         f_score = np.average(f_score, weights=weights)
         true_sum = None  # return no support
-    elif label_order is not None:
-        indices = np.searchsorted(labels, label_order)
-        precision = precision[indices]
-        recall = recall[indices]
-        f_score = f_score[indices]
-        true_sum = true_sum[indices]
 
     return precision, recall, f_score, true_sum
 
@@ -1035,8 +1065,14 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1,
     y_pred : 1d array-like, or label indicator array / sparse matrix
         Estimated targets as returned by a classifier.
 
-    labels : array
-        Integer array of labels.
+    labels : list, optional
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average. For multilabel targets,
+        labels are column indices. By default, all labels in ``y_true`` and
+        ``y_pred`` are used in sorted order.
 
     pos_label : str or int, 1 by default
         The class to report if ``average='binary'``. Until version 0.18 it is
@@ -1128,8 +1164,14 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
     y_pred : 1d array-like, or label indicator array / sparse matrix
         Estimated targets as returned by a classifier.
 
-    labels : array
-        Integer array of labels.
+    labels : list, optional
+        The set of labels to include when ``average != 'binary'``, and their
+        order if ``average is None``. Labels present in the data can be
+        excluded, for example to calculate a multiclass average ignoring a
+        majority negative class, while labels not present in the data will
+        result in 0 components in a macro average. For multilabel targets,
+        labels are column indices. By default, all labels in ``y_true`` and
+        ``y_pred`` are used in sorted order.
 
     pos_label : str or int, 1 by default
         The class to report if ``average='binary'``. Until version 0.18 it is