22import torch .nn as nn
33
44
5- def accuracy (pred , target , topk = 1 , thresh = None ):
5+ def accuracy (pred , target , topk = 1 , thresh = None , ignore_index = None ):
66 """Calculate accuracy according to the prediction and target.
77
88 Args:
99 pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
1010 target (torch.Tensor): The target of each prediction, shape (N, , ...)
11+ ignore_index (int | None): The label index to be ignored. Default: None
1112 topk (int | tuple[int], optional): If the predictions in ``topk``
1213 matches the target, the predictions will be regarded as
1314 correct ones. Defaults to 1.
@@ -43,17 +44,19 @@ def accuracy(pred, target, topk=1, thresh=None):
4344 if thresh is not None :
4445 # Only prediction values larger than thresh are counted as correct
4546 correct = correct & (pred_value > thresh ).t ()
47+ correct = correct [:, target != ignore_index ]
4648 res = []
4749 for k in topk :
4850 correct_k = correct [:k ].reshape (- 1 ).float ().sum (0 , keepdim = True )
49- res .append (correct_k .mul_ (100.0 / target .numel ()))
51+ res .append (
52+ correct_k .mul_ (100.0 / target [target != ignore_index ].numel ()))
5053 return res [0 ] if return_single else res
5154
5255
5356class Accuracy (nn .Module ):
5457 """Accuracy calculation module."""
5558
56- def __init__ (self , topk = (1 , ), thresh = None ):
59+ def __init__ (self , topk = (1 , ), thresh = None , ignore_index = None ):
5760 """Module to calculate the accuracy.
5861
5962 Args:
@@ -65,6 +68,7 @@ def __init__(self, topk=(1, ), thresh=None):
6568 super ().__init__ ()
6669 self .topk = topk
6770 self .thresh = thresh
71+ self .ignore_index = ignore_index
6872
6973 def forward (self , pred , target ):
7074 """Forward function to calculate accuracy.
@@ -76,4 +80,5 @@ def forward(self, pred, target):
7680 Returns:
7781 tuple[float]: The accuracies under different topk criterions.
7882 """
79- return accuracy (pred , target , self .topk , self .thresh )
83+ return accuracy (pred , target , self .topk , self .thresh ,
84+ self .ignore_index )
0 commit comments