Skip to content

Commit 61e1d5c

Browse files
authored
[Enhancement] Support ignore_index for sigmoid BCE (open-mmlab#210)
* [Enhancement] Add args check for ignore_index
* Support ignore_index
1 parent c2608b2 commit 61e1d5c

File tree

5 files changed, with 48 additions (+48) and 18 deletions (-18).

configs/_base_/models/fast_scnn.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525
norm_cfg=norm_cfg,
2626
align_corners=False,
2727
loss_decode=dict(
28-
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.)),
28+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
2929
auxiliary_head=[
3030
dict(
3131
type='FCNHead',
@@ -38,7 +38,7 @@
3838
concat_input=False,
3939
align_corners=False,
4040
loss_decode=dict(
41-
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
41+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
4242
dict(
4343
type='FCNHead',
4444
in_channels=64,
@@ -50,7 +50,7 @@
5050
concat_input=False,
5151
align_corners=False,
5252
loss_decode=dict(
53-
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
53+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
5454
])
5555

5656
# model training and testing settings

configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
]
55

66
# Re-config the data sampler.
7-
data = dict(samples_per_gpu=8, workers_per_gpu=4)
7+
data = dict(samples_per_gpu=2, workers_per_gpu=4)
88

99
# Re-config the optimizer.
1010
optimizer = dict(type='SGD', lr=0.12, momentum=0.9, weight_decay=4e-5)

mmseg/models/decode_heads/decode_head.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,8 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
3535
Default: None.
3636
loss_decode (dict): Config of decode loss.
3737
Default: dict(type='CrossEntropyLoss').
38-
ignore_index (int): The label index to be ignored. Default: 255
38+
ignore_index (int | None): The label index to be ignored. When using
39+
masked BCE loss, ignore_index should be set to None. Default: 255
3940
sampler (dict|None): The config of segmentation map sampler.
4041
Default: None.
4142
align_corners (bool): align_corners argument of F.interpolate.

mmseg/models/losses/cross_entropy_loss.py

Lines changed: 31 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -32,17 +32,25 @@ def cross_entropy(pred,
3232
return loss
3333

3434

35-
def _expand_onehot_labels(labels, label_weights, label_channels):
35+
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
3636
"""Expand onehot labels to match the size of prediction."""
37-
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
38-
inds = torch.nonzero(labels >= 1, as_tuple=False).squeeze()
39-
if inds.numel() > 0:
40-
bin_labels[inds, labels[inds] - 1] = 1
37+
bin_labels = labels.new_zeros(target_shape)
38+
valid_mask = (labels >= 0) & (labels != ignore_index)
39+
inds = torch.nonzero(valid_mask, as_tuple=True)
40+
41+
if inds[0].numel() > 0:
42+
if labels.dim() == 3:
43+
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
44+
else:
45+
bin_labels[inds[0], labels[valid_mask]] = 1
46+
47+
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
4148
if label_weights is None:
42-
bin_label_weights = None
49+
bin_label_weights = valid_mask
4350
else:
44-
bin_label_weights = label_weights.view(-1, 1).expand(
45-
label_weights.size(0), label_channels)
51+
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
52+
bin_label_weights *= valid_mask
53+
4654
return bin_labels, bin_label_weights
4755

4856

@@ -51,7 +59,8 @@ def binary_cross_entropy(pred,
5159
weight=None,
5260
reduction='mean',
5361
avg_factor=None,
54-
class_weight=None):
62+
class_weight=None,
63+
ignore_index=255):
5564
"""Calculate the binary CrossEntropy loss.
5665
5766
Args:
@@ -63,18 +72,24 @@ def binary_cross_entropy(pred,
6372
avg_factor (int, optional): Average factor that is used to average
6473
the loss. Defaults to None.
6574
class_weight (list[float], optional): The weight for each class.
75+
ignore_index (int | None): The label index to be ignored. Default: 255
6676
6777
Returns:
6878
torch.Tensor: The calculated loss
6979
"""
7080
if pred.dim() != label.dim():
71-
label, weight = _expand_onehot_labels(label, weight, pred.size(-1))
81+
assert (pred.dim() == 2 and label.dim() == 1) or (
82+
pred.dim() == 4 and label.dim() == 3), \
83+
'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
84+
'H, W], label shape [N, H, W] are supported'
85+
label, weight = _expand_onehot_labels(label, weight, pred.shape,
86+
ignore_index)
7287

7388
# weighted element-wise losses
7489
if weight is not None:
7590
weight = weight.float()
7691
loss = F.binary_cross_entropy_with_logits(
77-
pred, label.float(), weight=class_weight, reduction='none')
92+
pred, label.float(), pos_weight=class_weight, reduction='none')
7893
# do the reduction for the weighted loss
7994
loss = weight_reduce_loss(
8095
loss, weight, reduction=reduction, avg_factor=avg_factor)
@@ -87,7 +102,8 @@ def mask_cross_entropy(pred,
87102
label,
88103
reduction='mean',
89104
avg_factor=None,
90-
class_weight=None):
105+
class_weight=None,
106+
ignore_index=None):
91107
"""Calculate the CrossEntropy loss for masks.
92108
93109
Args:
@@ -103,10 +119,13 @@ def mask_cross_entropy(pred,
103119
avg_factor (int, optional): Average factor that is used to average
104120
the loss. Defaults to None.
105121
class_weight (list[float], optional): The weight for each class.
122+
ignore_index (None): Placeholder, to be consistent with other loss.
123+
Default: None.
106124
107125
Returns:
108126
torch.Tensor: The calculated loss
109127
"""
128+
assert ignore_index is None, 'BCE loss does not support ignore_index'
110129
# TODO: handle these two reserved arguments
111130
assert reduction == 'mean' and avg_factor is None
112131
num_rois = pred.size()[0]

tests/test_models/test_losses.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,17 @@ def test_ce_loss():
7171
loss_cls_cfg = dict(
7272
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
7373
loss_cls = build_loss(loss_cls_cfg)
74-
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(0.))
74+
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.))
75+
76+
fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
77+
fake_label = torch.ones(2, 8, 8).long()
78+
assert torch.allclose(
79+
loss_cls(fake_pred, fake_label), torch.tensor(0.9503), atol=1e-4)
80+
fake_label[:, 0, 0] = 255
81+
assert torch.allclose(
82+
loss_cls(fake_pred, fake_label, ignore_index=255),
83+
torch.tensor(0.9354),
84+
atol=1e-4)
7585

7686
# TODO test use_mask
7787

0 commit comments

Comments
 (0)