Skip to content

Commit 349fc2d

Browse files
authored
[Fix] Change self.loss_decode back to dict in Single Loss situation. (open-mmlab#1002)
* fix single loss type * fix error in ohem & point_head * fix coverage miss * fix uncovered error of PointHead loss * fix coverage miss * fix uncovered error of PointHead loss * nn.modules.container.ModuleList to nn.ModuleList * simpler format * merge unittest def
1 parent 845098b commit 349fc2d

File tree

5 files changed

+98
-5
lines changed

5 files changed

+98
-5
lines changed

mmseg/core/seg/sampler/ohem_pixel_sampler.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import torch
3+
import torch.nn as nn
34
import torch.nn.functional as F
45

56
from ..builder import PIXEL_SAMPLERS
@@ -62,14 +63,19 @@ def sample(self, seg_logit, seg_label):
6263
threshold = max(min_threshold, self.thresh)
6364
valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
6465
else:
66+
if not isinstance(self.context.loss_decode, nn.ModuleList):
67+
losses_decode = [self.context.loss_decode]
68+
else:
69+
losses_decode = self.context.loss_decode
6570
losses = 0.0
66-
for loss_module in self.context.loss_decode:
71+
for loss_module in losses_decode:
6772
losses += loss_module(
6873
seg_logit,
6974
seg_label,
7075
weight=None,
7176
ignore_index=self.context.ignore_index,
7277
reduction_override='none')
78+
7379
# faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
7480
_, sort_indices = losses[valid_mask].sort(descending=True)
7581
valid_seg_weight[sort_indices[:batch_kept]] = 1.

mmseg/models/decode_heads/decode_head.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ def __init__(self,
8383

8484
self.ignore_index = ignore_index
8585
self.align_corners = align_corners
86-
self.loss_decode = nn.ModuleList()
8786

8887
if isinstance(loss_decode, dict):
89-
self.loss_decode.append(build_loss(loss_decode))
88+
self.loss_decode = build_loss(loss_decode)
9089
elif isinstance(loss_decode, (list, tuple)):
90+
self.loss_decode = nn.ModuleList()
9191
for loss in loss_decode:
9292
self.loss_decode.append(build_loss(loss))
9393
else:
@@ -242,7 +242,12 @@ def losses(self, seg_logit, seg_label):
242242
else:
243243
seg_weight = None
244244
seg_label = seg_label.squeeze(1)
245-
for loss_decode in self.loss_decode:
245+
246+
if not isinstance(self.loss_decode, nn.ModuleList):
247+
losses_decode = [self.loss_decode]
248+
else:
249+
losses_decode = self.loss_decode
250+
for loss_decode in losses_decode:
246251
if loss_decode.loss_name not in loss:
247252
loss[loss_decode.loss_name] = loss_decode(
248253
seg_logit,

mmseg/models/decode_heads/point_head.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,14 @@ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
249249
def losses(self, point_logits, point_label):
250250
"""Compute segmentation loss."""
251251
loss = dict()
252-
for loss_module in self.loss_decode:
252+
if not isinstance(self.loss_decode, nn.ModuleList):
253+
losses_decode = [self.loss_decode]
254+
else:
255+
losses_decode = self.loss_decode
256+
for loss_module in losses_decode:
253257
loss['point' + loss_module.loss_name] = loss_module(
254258
point_logits, point_label, ignore_index=self.ignore_index)
259+
255260
loss['acc_point'] = accuracy(point_logits, point_label)
256261
return loss
257262

tests/test_models/test_heads/test_point_head.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,41 @@ def test_point_head():
2121
subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
2222
output = point_head.forward_test(inputs, prev_output, None, test_cfg)
2323
assert output.shape == (1, point_head.num_classes, 180, 180)
24+
25+
# test multiple losses case
26+
inputs = [torch.randn(1, 32, 45, 45)]
27+
point_head_multiple_losses = PointHead(
28+
in_channels=[32],
29+
in_index=[0],
30+
channels=16,
31+
num_classes=19,
32+
loss_decode=[
33+
dict(type='CrossEntropyLoss', loss_name='loss_1'),
34+
dict(type='CrossEntropyLoss', loss_name='loss_2')
35+
])
36+
assert len(point_head_multiple_losses.fcs) == 3
37+
fcn_head_multiple_losses = FCNHead(
38+
in_channels=32,
39+
channels=16,
40+
num_classes=19,
41+
loss_decode=[
42+
dict(type='CrossEntropyLoss', loss_name='loss_1'),
43+
dict(type='CrossEntropyLoss', loss_name='loss_2')
44+
])
45+
if torch.cuda.is_available():
46+
head, inputs = to_cuda(point_head_multiple_losses, inputs)
47+
head, inputs = to_cuda(fcn_head_multiple_losses, inputs)
48+
prev_output = fcn_head_multiple_losses(inputs)
49+
test_cfg = ConfigDict(
50+
subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
51+
output = point_head_multiple_losses.forward_test(inputs, prev_output, None,
52+
test_cfg)
53+
assert output.shape == (1, point_head.num_classes, 180, 180)
54+
55+
fake_label = torch.ones([1, 180, 180], dtype=torch.long)
56+
57+
if torch.cuda.is_available():
58+
fake_label = fake_label.cuda()
59+
loss = point_head_multiple_losses.losses(output, fake_label)
60+
assert 'pointloss_1' in loss
61+
assert 'pointloss_2' in loss

tests/test_sampler.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@ def _context_for_ohem():
1010
return FCNHead(in_channels=32, channels=16, num_classes=19)
1111

1212

13+
def _context_for_ohem_multiple_loss():
14+
return FCNHead(
15+
in_channels=32,
16+
channels=16,
17+
num_classes=19,
18+
loss_decode=[
19+
dict(type='CrossEntropyLoss', loss_name='loss_1'),
20+
dict(type='CrossEntropyLoss', loss_name='loss_2')
21+
])
22+
23+
1324
def test_ohem_sampler():
1425

1526
with pytest.raises(AssertionError):
@@ -37,3 +48,31 @@ def test_ohem_sampler():
3748
assert seg_weight.shape[0] == seg_logit.shape[0]
3849
assert seg_weight.shape[1:] == seg_logit.shape[2:]
3950
assert seg_weight.sum() == 200
51+
52+
# test multiple losses case
53+
with pytest.raises(AssertionError):
54+
# seg_logit and seg_label must be of the same size
55+
sampler = OHEMPixelSampler(context=_context_for_ohem_multiple_loss())
56+
seg_logit = torch.randn(1, 19, 45, 45)
57+
seg_label = torch.randint(0, 19, size=(1, 1, 89, 89))
58+
sampler.sample(seg_logit, seg_label)
59+
60+
# test with thresh in multiple losses case
61+
sampler = OHEMPixelSampler(
62+
context=_context_for_ohem_multiple_loss(), thresh=0.7, min_kept=200)
63+
seg_logit = torch.randn(1, 19, 45, 45)
64+
seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
65+
seg_weight = sampler.sample(seg_logit, seg_label)
66+
assert seg_weight.shape[0] == seg_logit.shape[0]
67+
assert seg_weight.shape[1:] == seg_logit.shape[2:]
68+
assert seg_weight.sum() > 200
69+
70+
# test w.o thresh in multiple losses case
71+
sampler = OHEMPixelSampler(
72+
context=_context_for_ohem_multiple_loss(), min_kept=200)
73+
seg_logit = torch.randn(1, 19, 45, 45)
74+
seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
75+
seg_weight = sampler.sample(seg_logit, seg_label)
76+
assert seg_weight.shape[0] == seg_logit.shape[0]
77+
assert seg_weight.shape[1:] == seg_logit.shape[2:]
78+
assert seg_weight.sum() == 200

0 commit comments

Comments
 (0)