Skip to content

Commit aa438f5

Browse files
authored
[Fix] The interface multiscale_output is defined but not used (open-mmlab#830)
* Add interface multiscale_output * Add space between args and their types * Fix default value
1 parent 4ca42a3 commit aa438f5

File tree

2 files changed

+129
-23
lines changed

2 files changed

+129
-23
lines changed

mmseg/models/backbones/hrnet.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -218,26 +218,41 @@ def forward(self, x):
218218
class HRNet(BaseModule):
219219
"""HRNet backbone.
220220
221-
High-Resolution Representations for Labeling Pixels and Regions
222-
arXiv: https://arxiv.org/abs/1904.04514
221+
`High-Resolution Representations for Labeling Pixels and Regions
222+
arXiv: <https://arxiv.org/abs/1904.04514>`_.
223223
224224
Args:
225-
extra (dict): detailed configuration for each stage of HRNet.
225+
extra (dict): Detailed configuration for each stage of HRNet.
226+
There must be 4 stages, the configuration for each stage must have
227+
5 keys:
228+
229+
- num_modules (int): The number of HRModule in this stage.
230+
- num_branches (int): The number of branches in the HRModule.
231+
- block (str): The type of convolution block.
232+
- num_blocks (tuple): The number of blocks in each branch.
233+
The length must be equal to num_branches.
234+
- num_channels (tuple): The number of channels in each branch.
235+
The length must be equal to num_branches.
226236
in_channels (int): Number of input image channels. Normally 3.
227-
conv_cfg (dict): dictionary to construct and config conv layer.
228-
norm_cfg (dict): dictionary to construct and config norm layer.
237+
conv_cfg (dict): Dictionary to construct and config conv layer.
238+
Default: None.
239+
norm_cfg (dict): Dictionary to construct and config norm layer.
240+
Use `BN` by default.
229241
norm_eval (bool): Whether to set norm layers to eval mode, namely,
230242
freeze running stats (mean and var). Note: Effect on Batch Norm
231-
and its variants only.
243+
and its variants only. Default: False.
232244
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
233-
memory while slowing down the training speed.
245+
memory while slowing down the training speed. Default: False.
234246
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
235247
-1 means not freezing any parameters. Default: -1.
236-
zero_init_residual (bool): whether to use zero init for last norm layer
237-
in resblocks to let them behave as identity.
238-
pretrained (str, optional): model pretrained path. Default: None
248+
zero_init_residual (bool): Whether to use zero init for last norm layer
249+
in resblocks to let them behave as identity. Default: False.
250+
multiscale_output (bool): Whether to output multi-level features
251+
produced by multiple branches. If False, only the first level
252+
feature will be output. Default: True.
253+
pretrained (str, optional): Model pretrained path. Default: None.
239254
init_cfg (dict or list[dict], optional): Initialization config dict.
240-
Default: None
255+
Default: None.
241256
242257
Example:
243258
>>> from mmseg.models import HRNet
@@ -290,6 +305,7 @@ def __init__(self,
290305
with_cp=False,
291306
frozen_stages=-1,
292307
zero_init_residual=False,
308+
multiscale_output=True,
293309
pretrained=None,
294310
init_cfg=None):
295311
super(HRNet, self).__init__(init_cfg)
@@ -299,7 +315,7 @@ def __init__(self,
299315
assert not (init_cfg and pretrained), \
300316
'init_cfg and pretrained cannot be setting at the same time'
301317
if isinstance(pretrained, str):
302-
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
318+
warnings.warn('DeprecationWarning: pretrained is deprecated, '
303319
'please use "init_cfg" instead')
304320
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
305321
elif pretrained is None:
@@ -314,6 +330,16 @@ def __init__(self,
314330
else:
315331
raise TypeError('pretrained must be a str or None')
316332

333+
# Assert configurations of 4 stages are in extra
334+
assert 'stage1' in extra and 'stage2' in extra \
335+
and 'stage3' in extra and 'stage4' in extra
336+
# Assert whether the length of `num_blocks` and `num_channels` are
337+
# equal to `num_branches`
338+
for i in range(4):
339+
cfg = extra[f'stage{i + 1}']
340+
assert len(cfg['num_blocks']) == cfg['num_branches'] and \
341+
len(cfg['num_channels']) == cfg['num_branches']
342+
317343
self.extra = extra
318344
self.conv_cfg = conv_cfg
319345
self.norm_cfg = norm_cfg
@@ -391,7 +417,7 @@ def __init__(self,
391417
self.transition3 = self._make_transition_layer(pre_stage_channels,
392418
num_channels)
393419
self.stage4, pre_stage_channels = self._make_stage(
394-
self.stage4_cfg, num_channels)
420+
self.stage4_cfg, num_channels, multiscale_output=multiscale_output)
395421

396422
self._freeze_stages()
397423

tests/test_models/test_backbones/test_hrnet.py

Lines changed: 90 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,59 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import pytest
3+
import torch
24
from mmcv.utils.parrots_wrapper import _BatchNorm
35

4-
from mmseg.models.backbones import HRNet
6+
from mmseg.models.backbones.hrnet import HRModule, HRNet
7+
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
58

69

7-
def test_hrnet_backbone():
8-
# Test HRNET with two stage frozen
10+
@pytest.mark.parametrize('block', [BasicBlock, Bottleneck])
11+
def test_hrmodule(block):
12+
# Test multiscale forward
13+
num_channles = (32, 64)
14+
in_channels = [c * block.expansion for c in num_channles]
15+
hrmodule = HRModule(
16+
num_branches=2,
17+
blocks=block,
18+
in_channels=in_channels,
19+
num_blocks=(4, 4),
20+
num_channels=num_channles,
21+
)
22+
23+
feats = [
24+
torch.randn(1, in_channels[0], 64, 64),
25+
torch.randn(1, in_channels[1], 32, 32)
26+
]
27+
feats = hrmodule(feats)
28+
29+
assert len(feats) == 2
30+
assert feats[0].shape == torch.Size([1, in_channels[0], 64, 64])
31+
assert feats[1].shape == torch.Size([1, in_channels[1], 32, 32])
32+
33+
# Test single scale forward
34+
num_channles = (32, 64)
35+
in_channels = [c * block.expansion for c in num_channles]
36+
hrmodule = HRModule(
37+
num_branches=2,
38+
blocks=block,
39+
in_channels=in_channels,
40+
num_blocks=(4, 4),
41+
num_channels=num_channles,
42+
multiscale_output=False,
43+
)
44+
45+
feats = [
46+
torch.randn(1, in_channels[0], 64, 64),
47+
torch.randn(1, in_channels[1], 32, 32)
48+
]
49+
feats = hrmodule(feats)
950

51+
assert len(feats) == 1
52+
assert feats[0].shape == torch.Size([1, in_channels[0], 64, 64])
53+
54+
55+
def test_hrnet_backbone():
56+
# only have 3 stages
1057
extra = dict(
1158
stage1=dict(
1259
num_modules=1,
@@ -25,13 +72,46 @@ def test_hrnet_backbone():
2572
num_branches=3,
2673
block='BASIC',
2774
num_blocks=(4, 4, 4),
28-
num_channels=(32, 64, 128)),
29-
stage4=dict(
30-
num_modules=3,
31-
num_branches=4,
32-
block='BASIC',
33-
num_blocks=(4, 4, 4, 4),
34-
num_channels=(32, 64, 128, 256)))
75+
num_channels=(32, 64, 128)))
76+
77+
with pytest.raises(AssertionError):
78+
# HRNet now only support 4 stages
79+
HRNet(extra=extra)
80+
extra['stage4'] = dict(
81+
num_modules=3,
82+
num_branches=3, # should be 4
83+
block='BASIC',
84+
num_blocks=(4, 4, 4, 4),
85+
num_channels=(32, 64, 128, 256))
86+
87+
with pytest.raises(AssertionError):
88+
# len(num_blocks) should equal num_branches
89+
HRNet(extra=extra)
90+
91+
extra['stage4']['num_branches'] = 4
92+
93+
# Test hrnetv2p_w32
94+
model = HRNet(extra=extra)
95+
model.init_weights()
96+
model.train()
97+
98+
imgs = torch.randn(1, 3, 256, 256)
99+
feats = model(imgs)
100+
assert len(feats) == 4
101+
assert feats[0].shape == torch.Size([1, 32, 64, 64])
102+
assert feats[3].shape == torch.Size([1, 256, 8, 8])
103+
104+
# Test single scale output
105+
model = HRNet(extra=extra, multiscale_output=False)
106+
model.init_weights()
107+
model.train()
108+
109+
imgs = torch.randn(1, 3, 256, 256)
110+
feats = model(imgs)
111+
assert len(feats) == 1
112+
assert feats[0].shape == torch.Size([1, 32, 64, 64])
113+
114+
# Test HRNET with two stage frozen
35115
frozen_stages = 2
36116
model = HRNet(extra, frozen_stages=frozen_stages)
37117
model.init_weights()

0 commit comments

Comments
 (0)