Skip to content

Commit a82ebad

Browse files
authored
[Fix] Add avg_non_ignore in cross entropy loss (open-mmlab#1409)
* [Fix] Add avg_non_ignore in cross entropy loss * [Fix] Add avg_non_ignore in cross entropy loss * add docstring * fix ut * fix docstring and comments * fix * fix bce * fix avg_factor in BCE and add more ut * add avg_non_ignore * add more ut * fix part of ut * fix part of ut * test avg_non_ignore would not affect ce/bce when reduction none/sum * test avg_non_ignore would not affect ce/bce when reduction none/sum/mean * re-organize ut * re-organize ut * re-organize ut * re-organize hardcode case * fix parts of comments * fix another parts of comments * fix
1 parent 24f1563 commit a82ebad

File tree

5 files changed

+289
-36
lines changed

5 files changed

+289
-36
lines changed

docs/en/tutorials/training_tricks.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,23 @@ model = dict(
6868
In this way, `loss_weight` and `loss_name` will be weight and name in training log of corresponding loss, respectively.
6969

7070
Note: If you want this loss item to be included into the backward graph, `loss_` must be the prefix of the name.
71+
72+
## Ignore specified label index in loss calculation
73+
74+
In default setting, `avg_non_ignore=False` which means each pixel counts for loss calculation although some of them belong to ignore-index labels.
75+
76+
For loss calculation, we support ignore index of certain label by `avg_non_ignore` and `ignore_index`. In this way, the average loss would only be calculated in non-ignored labels which may achieve better performance, and here is the [reference](https://github.com/open-mmlab/mmsegmentation/pull/1409). Here is an example config of training `unet` on `Cityscapes` dataset: in loss calculation it would ignore label 0 which is background and loss average is only calculated on non-ignore labels:
77+
78+
```python
79+
_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
80+
model = dict(
81+
decode_head=dict(
82+
ignore_index=0,
83+
loss_decode=dict(
84+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True),
85+
auxiliary_head=dict(
86+
ignore_index=0,
87+
loss_decode=dict(
88+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
89+
))
90+
```

docs/zh_cn/tutorials/training_tricks.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,28 @@ model = dict(
6868
通过这种方式,确定训练过程中损失函数的权重 `loss_weight` 和在训练日志里的名字 `loss_name`
6969

7070
注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。
71+
72+
## 在损失函数中忽略特定的 label 类别
73+
74+
默认设置 `avg_non_ignore=False`, 即每个像素都用来计算损失函数。尽管其中的一些像素属于需要被忽略的类别。
75+
76+
对于训练时损失函数的计算,我们目前支持使用 `avg_non_ignore``ignore_index` 来忽略 label 特定的类别。 这样损失函数将只在非忽略类别像素中求平均值,会获得更好的表现。这里是[相关 PR](https://github.com/open-mmlab/mmsegmentation/pull/1409)。以 `unet` 使用 `Cityscapes` 数据集训练为例,
77+
在计算损失函数时,忽略 label 为0的背景,并且仅在不被忽略的像素上计算均值。配置文件写为:
78+
79+
```python
80+
_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
81+
model = dict(
82+
decode_head=dict(
83+
ignore_index=0,
84+
loss_decode=dict(
85+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True),
86+
auxiliary_head=dict(
87+
ignore_index=0,
88+
loss_decode=dict(
89+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
90+
))
91+
```
92+
93+
通过这种方式,确定训练过程中损失函数的权重 `loss_weight` 和在训练日志里的名字 `loss_name`
94+
95+
注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。

mmseg/models/losses/cross_entropy_loss.py

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import warnings
3+
24
import torch
35
import torch.nn as nn
46
import torch.nn.functional as F
@@ -13,8 +15,31 @@ def cross_entropy(pred,
1315
class_weight=None,
1416
reduction='mean',
1517
avg_factor=None,
16-
ignore_index=-100):
17-
"""The wrapper function for :func:`F.cross_entropy`"""
18+
ignore_index=-100,
19+
avg_non_ignore=False):
20+
"""cross_entropy. The wrapper function for :func:`F.cross_entropy`
21+
22+
Args:
23+
pred (torch.Tensor): The prediction with shape (N, 1).
24+
label (torch.Tensor): The learning label of the prediction.
25+
weight (torch.Tensor, optional): Sample-wise loss weight.
26+
Default: None.
27+
class_weight (list[float], optional): The weight for each class.
28+
Default: None.
29+
reduction (str, optional): The method used to reduce the loss.
30+
Options are 'none', 'mean' and 'sum'. Default: 'mean'.
31+
avg_factor (int, optional): Average factor that is used to average
32+
the loss. Default: None.
33+
ignore_index (int): Specifies a target value that is ignored and
34+
does not contribute to the input gradients. When
35+
``avg_non_ignore `` is ``True``, and the ``reduction`` is
36+
``''mean''``, the loss is averaged over non-ignored targets.
37+
Defaults: -100.
38+
avg_non_ignore (bool): The flag decides to whether the loss is
39+
only averaged over non-ignored targets. Default: False.
40+
`New in version 0.23.0.`
41+
"""
42+
1843
# class_weight is a manual rescaling weight given to each class.
1944
# If given, has to be a Tensor of size C element-wise losses
2045
loss = F.cross_entropy(
@@ -25,6 +50,11 @@ def cross_entropy(pred,
2550
ignore_index=ignore_index)
2651

2752
# apply weights and do the reduction
53+
# average loss over non-ignored elements
54+
# pytorch's official cross_entropy average loss over non-ignored elements
55+
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
56+
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
57+
avg_factor = label.numel() - (label == ignore_index).sum().item()
2858
if weight is not None:
2959
weight = weight.float()
3060
loss = weight_reduce_loss(
@@ -46,13 +76,14 @@ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
4676
bin_labels[inds[0], labels[valid_mask]] = 1
4777

4878
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
79+
4980
if label_weights is None:
5081
bin_label_weights = valid_mask
5182
else:
5283
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
5384
bin_label_weights *= valid_mask
5485

55-
return bin_labels, bin_label_weights
86+
return bin_labels, bin_label_weights, valid_mask
5687

5788

5889
def binary_cross_entropy(pred,
@@ -61,19 +92,25 @@ def binary_cross_entropy(pred,
6192
reduction='mean',
6293
avg_factor=None,
6394
class_weight=None,
64-
ignore_index=255):
95+
ignore_index=-100,
96+
avg_non_ignore=False,
97+
**kwargs):
6598
"""Calculate the binary CrossEntropy loss.
6699
67100
Args:
68101
pred (torch.Tensor): The prediction with shape (N, 1).
69102
label (torch.Tensor): The learning label of the prediction.
103+
Note: In bce loss, label < 0 is invalid.
70104
weight (torch.Tensor, optional): Sample-wise loss weight.
71105
reduction (str, optional): The method used to reduce the loss.
72106
Options are "none", "mean" and "sum".
73107
avg_factor (int, optional): Average factor that is used to average
74108
the loss. Defaults to None.
75109
class_weight (list[float], optional): The weight for each class.
76-
ignore_index (int | None): The label index to be ignored. Default: 255
110+
ignore_index (int): The label index to be ignored. Default: -100.
111+
avg_non_ignore (bool): The flag decides to whether the loss is
112+
only averaged over non-ignored targets. Default: False.
113+
`New in version 0.23.0.`
77114
78115
Returns:
79116
torch.Tensor: The calculated loss
@@ -83,12 +120,21 @@ def binary_cross_entropy(pred,
83120
pred.dim() == 4 and label.dim() == 3), \
84121
'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
85122
'H, W], label shape [N, H, W] are supported'
86-
label, weight = _expand_onehot_labels(label, weight, pred.shape,
87-
ignore_index)
123+
# `weight` returned from `_expand_onehot_labels`
124+
# has been treated for valid (non-ignore) pixels
125+
label, weight, valid_mask = _expand_onehot_labels(
126+
label, weight, pred.shape, ignore_index)
127+
else:
128+
# should mask out the ignored elements
129+
valid_mask = ((label >= 0) & (label != ignore_index)).float()
130+
if weight is not None:
131+
weight *= valid_mask
132+
else:
133+
weight = valid_mask
134+
# average loss over non-ignored and valid elements
135+
if reduction == 'mean' and avg_factor is None and avg_non_ignore:
136+
avg_factor = valid_mask.sum().item()
88137

89-
# weighted element-wise losses
90-
if weight is not None:
91-
weight = weight.float()
92138
loss = F.binary_cross_entropy_with_logits(
93139
pred, label.float(), pos_weight=class_weight, reduction='none')
94140
# do the reduction for the weighted loss
@@ -104,7 +150,8 @@ def mask_cross_entropy(pred,
104150
reduction='mean',
105151
avg_factor=None,
106152
class_weight=None,
107-
ignore_index=None):
153+
ignore_index=None,
154+
**kwargs):
108155
"""Calculate the CrossEntropy loss for masks.
109156
110157
Args:
@@ -153,6 +200,9 @@ class CrossEntropyLoss(nn.Module):
153200
loss_name (str, optional): Name of the loss item. If you want this loss
154201
item to be included into the backward graph, `loss_` must be the
155202
prefix of the name. Defaults to 'loss_ce'.
203+
avg_non_ignore (bool): The flag decides to whether the loss is
204+
only averaged over non-ignored targets. Default: False.
205+
`New in version 0.23.0.`
156206
"""
157207

158208
def __init__(self,
@@ -161,14 +211,22 @@ def __init__(self,
161211
reduction='mean',
162212
class_weight=None,
163213
loss_weight=1.0,
164-
loss_name='loss_ce'):
214+
loss_name='loss_ce',
215+
avg_non_ignore=False):
165216
super(CrossEntropyLoss, self).__init__()
166217
assert (use_sigmoid is False) or (use_mask is False)
167218
self.use_sigmoid = use_sigmoid
168219
self.use_mask = use_mask
169220
self.reduction = reduction
170221
self.loss_weight = loss_weight
171222
self.class_weight = get_class_weight(class_weight)
223+
self.avg_non_ignore = avg_non_ignore
224+
if not self.avg_non_ignore and self.reduction == 'mean':
225+
warnings.warn(
226+
'Default ``avg_non_ignore`` is False, if you would like to '
227+
'ignore the certain label and average loss over non-ignore '
228+
'labels, which is the same with PyTorch official '
229+
'cross_entropy, set ``avg_non_ignore=True``.')
172230

173231
if self.use_sigmoid:
174232
self.cls_criterion = binary_cross_entropy
@@ -178,12 +236,18 @@ def __init__(self,
178236
self.cls_criterion = cross_entropy
179237
self._loss_name = loss_name
180238

239+
def extra_repr(self):
240+
"""Extra repr."""
241+
s = f'avg_non_ignore={self.avg_non_ignore}'
242+
return s
243+
181244
def forward(self,
182245
cls_score,
183246
label,
184247
weight=None,
185248
avg_factor=None,
186249
reduction_override=None,
250+
ignore_index=-100,
187251
**kwargs):
188252
"""Forward function."""
189253
assert reduction_override in (None, 'none', 'mean', 'sum')
@@ -193,13 +257,16 @@ def forward(self,
193257
class_weight = cls_score.new_tensor(self.class_weight)
194258
else:
195259
class_weight = None
260+
# Note: for BCE loss, label < 0 is invalid.
196261
loss_cls = self.loss_weight * self.cls_criterion(
197262
cls_score,
198263
label,
199264
weight,
200265
class_weight=class_weight,
201266
reduction=reduction,
202267
avg_factor=avg_factor,
268+
avg_non_ignore=self.avg_non_ignore,
269+
ignore_index=ignore_index,
203270
**kwargs)
204271
return loss_cls
205272

@@ -212,6 +279,7 @@ def loss_name(self):
212279
by simple sum operation. In addition, if you want this loss item to be
213280
included into the backward graph, `loss_` must be the prefix of the
214281
name.
282+
215283
Returns:
216284
str: The name of this loss item.
217285
"""

mmseg/models/losses/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import mmcv
55
import numpy as np
6+
import torch
67
import torch.nn.functional as F
78

89

@@ -69,7 +70,10 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
6970
else:
7071
# if reduction is mean, then average the loss by avg_factor
7172
if reduction == 'mean':
72-
loss = loss.sum() / avg_factor
73+
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
74+
# i.e., all labels of an image belong to ignore index.
75+
eps = torch.finfo(torch.float32).eps
76+
loss = loss.sum() / (avg_factor + eps)
7377
# if reduction is 'none', then do nothing, otherwise raise an error
7478
elif reduction != 'none':
7579
raise ValueError('avg_factor can not be used with reduction="sum"')

0 commit comments

Comments
 (0)