Skip to content

Commit c1c942e

Browse files
authored
[Fix] Fix binary segmentation when num_classes==1 (open-mmlab#2016)
* fix binary * add ut * fix ut * restore metric computation * remove metric ut update * set out_channels by num_classes * replace num_classes in encoder_decoder * update props setting and fix ut * update ut * minor change * update warning
1 parent d8ea8f7 commit c1c942e

File tree

5 files changed

+67
-7
lines changed

5 files changed

+67
-7
lines changed

mmseg/models/decode_heads/decode_head.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import warnings
23
from abc import ABCMeta, abstractmethod
34

45
import torch
@@ -18,6 +19,9 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
1819
in_channels (int|Sequence[int]): Input channels.
1920
channels (int): Channels after modules, before conv_seg.
2021
num_classes (int): Number of classes.
22+
out_channels (int): Output channels of conv_seg.
23+
threshold (float): Threshold for binary segmentation in the case of
24+
`num_classes==1`. Default: None.
2125
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
2226
conv_cfg (dict|None): Config of conv layers. Default: None.
2327
norm_cfg (dict|None): Config of norm layers. Default: None.
@@ -56,6 +60,8 @@ def __init__(self,
5660
channels,
5761
*,
5862
num_classes,
63+
out_channels=None,
64+
threshold=None,
5965
dropout_ratio=0.1,
6066
conv_cfg=None,
6167
norm_cfg=None,
@@ -74,7 +80,6 @@ def __init__(self,
7480
super(BaseDecodeHead, self).__init__(init_cfg)
7581
self._init_inputs(in_channels, in_index, input_transform)
7682
self.channels = channels
77-
self.num_classes = num_classes
7883
self.dropout_ratio = dropout_ratio
7984
self.conv_cfg = conv_cfg
8085
self.norm_cfg = norm_cfg
@@ -84,6 +89,30 @@ def __init__(self,
8489
self.ignore_index = ignore_index
8590
self.align_corners = align_corners
8691

92+
if out_channels is None:
93+
if num_classes == 2:
94+
warnings.warn('For binary segmentation, we suggest using'
95+
'`out_channels = 1` to define the output'
96+
'channels of segmentor, and use `threshold`'
97+
'to convert seg_logist into a prediction'
98+
'applying a threshold')
99+
out_channels = num_classes
100+
101+
if out_channels != num_classes and out_channels != 1:
102+
raise ValueError(
103+
'out_channels should be equal to num_classes,'
104+
'except binary segmentation set out_channels == 1 and'
105+
f'num_classes == 2, but got out_channels={out_channels}'
106+
f'and num_classes={num_classes}')
107+
108+
if out_channels == 1 and threshold is None:
109+
threshold = 0.3
110+
warnings.warn('threshold is not defined for binary, and defaults'
111+
'to 0.3')
112+
self.num_classes = num_classes
113+
self.out_channels = out_channels
114+
self.threshold = threshold
115+
87116
if isinstance(loss_decode, dict):
88117
self.loss_decode = build_loss(loss_decode)
89118
elif isinstance(loss_decode, (list, tuple)):
@@ -99,7 +128,7 @@ def __init__(self,
99128
else:
100129
self.sampler = None
101130

102-
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
131+
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
103132
if dropout_ratio > 0:
104133
self.dropout = nn.Dropout2d(dropout_ratio)
105134
else:

mmseg/models/segmentors/cascade_encoder_decoder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def _init_decode_head(self, decode_head):
4747
self.decode_head.append(builder.build_head(decode_head[i]))
4848
self.align_corners = self.decode_head[-1].align_corners
4949
self.num_classes = self.decode_head[-1].num_classes
50+
self.out_channels = self.decode_head[-1].out_channels
5051

5152
def encode_decode(self, img, img_metas):
5253
"""Encode images with backbone and decode into a semantic segmentation

mmseg/models/segmentors/encoder_decoder.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def _init_decode_head(self, decode_head):
4949
self.decode_head = builder.build_head(decode_head)
5050
self.align_corners = self.decode_head.align_corners
5151
self.num_classes = self.decode_head.num_classes
52+
self.out_channels = self.decode_head.out_channels
5253

5354
def _init_auxiliary_head(self, auxiliary_head):
5455
"""Initialize ``auxiliary_head``"""
@@ -162,10 +163,10 @@ def slide_inference(self, img, img_meta, rescale):
162163
h_stride, w_stride = self.test_cfg.stride
163164
h_crop, w_crop = self.test_cfg.crop_size
164165
batch_size, _, h_img, w_img = img.size()
165-
num_classes = self.num_classes
166+
out_channels = self.out_channels
166167
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
167168
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
168-
preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
169+
preds = img.new_zeros((batch_size, out_channels, h_img, w_img))
169170
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
170171
for h_idx in range(h_grids):
171172
for w_idx in range(w_grids):
@@ -245,7 +246,10 @@ def inference(self, img, img_meta, rescale):
245246
seg_logit = self.slide_inference(img, img_meta, rescale)
246247
else:
247248
seg_logit = self.whole_inference(img, img_meta, rescale)
248-
output = F.softmax(seg_logit, dim=1)
249+
if self.out_channels == 1:
250+
output = F.sigmoid(seg_logit)
251+
else:
252+
output = F.softmax(seg_logit, dim=1)
249253
flip = img_meta[0]['flip']
250254
if flip:
251255
flip_direction = img_meta[0]['flip_direction']
@@ -260,7 +264,11 @@ def inference(self, img, img_meta, rescale):
260264
def simple_test(self, img, img_meta, rescale=True):
261265
"""Simple test with single image."""
262266
seg_logit = self.inference(img, img_meta, rescale)
263-
seg_pred = seg_logit.argmax(dim=1)
267+
if self.out_channels == 1:
268+
seg_pred = (seg_logit >
269+
self.decode_head.threshold).to(seg_logit).squeeze(1)
270+
else:
271+
seg_pred = seg_logit.argmax(dim=1)
264272
if torch.onnx.is_in_onnx_export():
265273
# our inference backend only support 4D output
266274
seg_pred = seg_pred.unsqueeze(0)
@@ -283,7 +291,11 @@ def aug_test(self, imgs, img_metas, rescale=True):
283291
cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
284292
seg_logit += cur_seg_logit
285293
seg_logit /= len(imgs)
286-
seg_pred = seg_logit.argmax(dim=1)
294+
if self.out_channels == 1:
295+
seg_pred = (seg_logit >
296+
self.decode_head.threshold).to(seg_logit).squeeze(1)
297+
else:
298+
seg_pred = seg_logit.argmax(dim=1)
287299
seg_pred = seg_pred.cpu().numpy()
288300
# unravel batch dim
289301
seg_pred = list(seg_pred)

tests/test_models/test_heads/test_decode_head.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,18 @@ def test_decode_head():
4343
in_index=[-1],
4444
input_transform='resize_concat')
4545

46+
with pytest.raises(ValueError):
47+
# out_channels should be equal to num_classes
48+
BaseDecodeHead(32, 16, num_classes=19, out_channels=18)
49+
50+
# test out_channels
51+
head = BaseDecodeHead(32, 16, num_classes=2)
52+
assert head.out_channels == 2
53+
54+
# test out_channels == 1 and num_classes == 2
55+
head = BaseDecodeHead(32, 16, num_classes=2, out_channels=1)
56+
assert head.out_channels == 1 and head.num_classes == 2
57+
4658
# test default dropout
4759
head = BaseDecodeHead(32, 16, num_classes=19)
4860
assert hasattr(head, 'dropout') and head.dropout.p == 0.1

tests/test_models/test_segmentors/test_encoder_decoder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ def test_encoder_decoder():
1818
segmentor = build_segmentor(cfg)
1919
_segmentor_forward_train_test(segmentor)
2020

21+
# test out_channels == 1
22+
segmentor.out_channels = 1
23+
segmentor.decode_head.out_channels = 1
24+
segmentor.decode_head.threshold = 0.3
25+
_segmentor_forward_train_test(segmentor)
26+
2127
# test slide mode
2228
cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
2329
segmentor = build_segmentor(cfg)

0 commit comments

Comments
 (0)