@@ -32,17 +32,25 @@ def cross_entropy(pred,
3232 return loss
3333
3434
35- def _expand_onehot_labels (labels , label_weights , label_channels ):
35+ def _expand_onehot_labels (labels , label_weights , target_shape , ignore_index ):
3636 """Expand onehot labels to match the size of prediction."""
37- bin_labels = labels .new_full ((labels .size (0 ), label_channels ), 0 )
38- inds = torch .nonzero (labels >= 1 , as_tuple = False ).squeeze ()
39- if inds .numel () > 0 :
40- bin_labels [inds , labels [inds ] - 1 ] = 1
37+ bin_labels = labels .new_zeros (target_shape )
38+ valid_mask = (labels >= 0 ) & (labels != ignore_index )
39+ inds = torch .nonzero (valid_mask , as_tuple = True )
40+
41+ if inds [0 ].numel () > 0 :
42+ if labels .dim () == 3 :
43+ bin_labels [inds [0 ], labels [valid_mask ], inds [1 ], inds [2 ]] = 1
44+ else :
45+ bin_labels [inds [0 ], labels [valid_mask ]] = 1
46+
47+ valid_mask = valid_mask .unsqueeze (1 ).expand (target_shape ).float ()
4148 if label_weights is None :
42- bin_label_weights = None
49+ bin_label_weights = valid_mask
4350 else :
44- bin_label_weights = label_weights .view (- 1 , 1 ).expand (
45- label_weights .size (0 ), label_channels )
51+ bin_label_weights = label_weights .unsqueeze (1 ).expand (target_shape )
52+ bin_label_weights *= valid_mask
53+
4654 return bin_labels , bin_label_weights
4755
4856
@@ -51,7 +59,8 @@ def binary_cross_entropy(pred,
5159 weight = None ,
5260 reduction = 'mean' ,
5361 avg_factor = None ,
54- class_weight = None ):
62+ class_weight = None ,
63+ ignore_index = 255 ):
5564 """Calculate the binary CrossEntropy loss.
5665
5766 Args:
@@ -63,18 +72,24 @@ def binary_cross_entropy(pred,
6372 avg_factor (int, optional): Average factor that is used to average
6473 the loss. Defaults to None.
6574 class_weight (list[float], optional): The weight for each class.
75+ ignore_index (int | None): The label index to be ignored. Default: 255
6676
6777 Returns:
6878 torch.Tensor: The calculated loss
6979 """
7080 if pred .dim () != label .dim ():
71- label , weight = _expand_onehot_labels (label , weight , pred .size (- 1 ))
81+ assert (pred .dim () == 2 and label .dim () == 1 ) or (
82+ pred .dim () == 4 and label .dim () == 3 ), \
83+ 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
84+ 'H, W], label shape [N, H, W] are supported'
85+ label , weight = _expand_onehot_labels (label , weight , pred .shape ,
86+ ignore_index )
7287
7388 # weighted element-wise losses
7489 if weight is not None :
7590 weight = weight .float ()
7691 loss = F .binary_cross_entropy_with_logits (
77- pred , label .float (), weight = class_weight , reduction = 'none' )
92+ pred , label .float (), pos_weight = class_weight , reduction = 'none' )
7893 # do the reduction for the weighted loss
7994 loss = weight_reduce_loss (
8095 loss , weight , reduction = reduction , avg_factor = avg_factor )
@@ -87,7 +102,8 @@ def mask_cross_entropy(pred,
87102 label ,
88103 reduction = 'mean' ,
89104 avg_factor = None ,
90- class_weight = None ):
105+ class_weight = None ,
106+ ignore_index = None ):
91107 """Calculate the CrossEntropy loss for masks.
92108
93109 Args:
@@ -103,10 +119,13 @@ def mask_cross_entropy(pred,
103119 avg_factor (int, optional): Average factor that is used to average
104120 the loss. Defaults to None.
105121 class_weight (list[float], optional): The weight for each class.
122+ ignore_index (None): Placeholder, to be consistent with other loss.
123+ Default: None.
106124
107125 Returns:
108126 torch.Tensor: The calculated loss
109127 """
128+ assert ignore_index is None , 'BCE loss does not support ignore_index'
110129 # TODO: handle these two reserved arguments
111130 assert reduction == 'mean' and avg_factor is None
112131 num_rois = pred .size ()[0 ]
0 commit comments