Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion mmseg/core/seg/sampler/ohem_pixel_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import PIXEL_SAMPLERS
Expand Down Expand Up @@ -62,14 +63,19 @@ def sample(self, seg_logit, seg_label):
threshold = max(min_threshold, self.thresh)
valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
else:
if not isinstance(self.context.loss_decode, nn.ModuleList):
losses_decode = [self.context.loss_decode]
else:
losses_decode = self.context.loss_decode
losses = 0.0
for loss_module in self.context.loss_decode:
for loss_module in losses_decode:
losses += loss_module(
seg_logit,
seg_label,
weight=None,
ignore_index=self.context.ignore_index,
reduction_override='none')

# faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
_, sort_indices = losses[valid_mask].sort(descending=True)
valid_seg_weight[sort_indices[:batch_kept]] = 1.
Expand Down
11 changes: 8 additions & 3 deletions mmseg/models/decode_heads/decode_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ def __init__(self,

self.ignore_index = ignore_index
self.align_corners = align_corners
self.loss_decode = nn.ModuleList()

if isinstance(loss_decode, dict):
self.loss_decode.append(build_loss(loss_decode))
self.loss_decode = build_loss(loss_decode)
elif isinstance(loss_decode, (list, tuple)):
self.loss_decode = nn.ModuleList()
for loss in loss_decode:
self.loss_decode.append(build_loss(loss))
else:
Expand Down Expand Up @@ -242,7 +242,12 @@ def losses(self, seg_logit, seg_label):
else:
seg_weight = None
seg_label = seg_label.squeeze(1)
for loss_decode in self.loss_decode:

if not isinstance(self.loss_decode, nn.ModuleList):
losses_decode = [self.loss_decode]
else:
losses_decode = self.loss_decode
for loss_decode in losses_decode:
if loss_decode.loss_name not in loss:
loss[loss_decode.loss_name] = loss_decode(
seg_logit,
Expand Down
7 changes: 6 additions & 1 deletion mmseg/models/decode_heads/point_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,14 @@ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
def losses(self, point_logits, point_label):
"""Compute segmentation loss."""
loss = dict()
for loss_module in self.loss_decode:
if not isinstance(self.loss_decode, nn.ModuleList):
losses_decode = [self.loss_decode]
else:
losses_decode = self.loss_decode
for loss_module in losses_decode:
loss['point' + loss_module.loss_name] = loss_module(
point_logits, point_label, ignore_index=self.ignore_index)

loss['acc_point'] = accuracy(point_logits, point_label)
return loss

Expand Down
38 changes: 38 additions & 0 deletions tests/test_models/test_heads/test_point_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,41 @@ def test_point_head():
subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
output = point_head.forward_test(inputs, prev_output, None, test_cfg)
assert output.shape == (1, point_head.num_classes, 180, 180)

# test multiple losses case
inputs = [torch.randn(1, 32, 45, 45)]
point_head_multiple_losses = PointHead(
in_channels=[32],
in_index=[0],
channels=16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])
assert len(point_head_multiple_losses.fcs) == 3
fcn_head_multiple_losses = FCNHead(
in_channels=32,
channels=16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])
if torch.cuda.is_available():
head, inputs = to_cuda(point_head_multiple_losses, inputs)
head, inputs = to_cuda(fcn_head_multiple_losses, inputs)
prev_output = fcn_head_multiple_losses(inputs)
test_cfg = ConfigDict(
subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
output = point_head_multiple_losses.forward_test(inputs, prev_output, None,
test_cfg)
assert output.shape == (1, point_head.num_classes, 180, 180)

fake_label = torch.ones([1, 180, 180], dtype=torch.long)

if torch.cuda.is_available():
fake_label = fake_label.cuda()
loss = point_head_multiple_losses.losses(output, fake_label)
assert 'pointloss_1' in loss
assert 'pointloss_2' in loss
39 changes: 39 additions & 0 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@ def _context_for_ohem():
return FCNHead(in_channels=32, channels=16, num_classes=19)


def _context_for_ohem_multiple_loss():
return FCNHead(
in_channels=32,
channels=16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])


def test_ohem_sampler():

with pytest.raises(AssertionError):
Expand Down Expand Up @@ -37,3 +48,31 @@ def test_ohem_sampler():
assert seg_weight.shape[0] == seg_logit.shape[0]
assert seg_weight.shape[1:] == seg_logit.shape[2:]
assert seg_weight.sum() == 200

# test multiple losses case
with pytest.raises(AssertionError):
# seg_logit and seg_label must be of the same size
sampler = OHEMPixelSampler(context=_context_for_ohem_multiple_loss())
seg_logit = torch.randn(1, 19, 45, 45)
seg_label = torch.randint(0, 19, size=(1, 1, 89, 89))
sampler.sample(seg_logit, seg_label)

# test with thresh in multiple losses case
sampler = OHEMPixelSampler(
context=_context_for_ohem_multiple_loss(), thresh=0.7, min_kept=200)
seg_logit = torch.randn(1, 19, 45, 45)
seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
seg_weight = sampler.sample(seg_logit, seg_label)
assert seg_weight.shape[0] == seg_logit.shape[0]
assert seg_weight.shape[1:] == seg_logit.shape[2:]
assert seg_weight.sum() > 200

# test w.o thresh in multiple losses case
sampler = OHEMPixelSampler(
context=_context_for_ohem_multiple_loss(), min_kept=200)
seg_logit = torch.randn(1, 19, 45, 45)
seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
seg_weight = sampler.sample(seg_logit, seg_label)
assert seg_weight.shape[0] == seg_logit.shape[0]
assert seg_weight.shape[1:] == seg_logit.shape[2:]
assert seg_weight.sum() == 200