@@ -218,26 +218,41 @@ def forward(self, x):
218218class HRNet (BaseModule ):
219219 """HRNet backbone.
220220
221- High-Resolution Representations for Labeling Pixels and Regions
222- arXiv: https://arxiv.org/abs/1904.04514
221+ `High-Resolution Representations for Labeling Pixels and Regions
222+ arXiv: <https://arxiv.org/abs/1904.04514>`_.
223223
224224 Args:
225- extra (dict): detailed configuration for each stage of HRNet.
225+ extra (dict): Detailed configuration for each stage of HRNet.
226+ There must be 4 stages, and the configuration for each stage must
227+ have 5 keys:
228+
229+ - num_modules (int): The number of HRModule in this stage.
230+ - num_branches (int): The number of branches in the HRModule.
231+ - block (str): The type of convolution block.
232+ - num_blocks (tuple): The number of blocks in each branch.
233+ The length must be equal to num_branches.
234+ - num_channels (tuple): The number of channels in each branch.
235+ The length must be equal to num_branches.
226236 in_channels (int): Number of input image channels. Normally 3.
227- conv_cfg (dict): dictionary to construct and config conv layer.
228- norm_cfg (dict): dictionary to construct and config norm layer.
237+ conv_cfg (dict): Dictionary to construct and config conv layer.
238+ Default: None.
239+ norm_cfg (dict): Dictionary to construct and config norm layer.
240+ Use `BN` by default.
229241 norm_eval (bool): Whether to set norm layers to eval mode, namely,
230242 freeze running stats (mean and var). Note: Effect on Batch Norm
231- and its variants only.
243+ and its variants only. Default: False.
232244 with_cp (bool): Use checkpoint or not. Using checkpoint will save some
233- memory while slowing down the training speed.
245+ memory while slowing down the training speed. Default: False.
234246 frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
235247 -1 means not freezing any parameters. Default: -1.
236- zero_init_residual (bool): whether to use zero init for last norm layer
237- in resblocks to let them behave as identity.
238- pretrained (str, optional): model pretrained path. Default: None
248+ zero_init_residual (bool): Whether to use zero init for last norm layer
249+ in resblocks to let them behave as identity. Default: False.
250+ multiscale_output (bool): Whether to output multi-level features
251+ produced by multiple branches. If False, only the first level
252+ feature will be output. Default: True.
253+ pretrained (str, optional): Model pretrained path. Default: None.
239254 init_cfg (dict or list[dict], optional): Initialization config dict.
240- Default: None
255+ Default: None.
241256
242257 Example:
243258 >>> from mmseg.models import HRNet
@@ -290,6 +305,7 @@ def __init__(self,
290305 with_cp = False ,
291306 frozen_stages = - 1 ,
292307 zero_init_residual = False ,
308+ multiscale_output = True ,
293309 pretrained = None ,
294310 init_cfg = None ):
295311 super (HRNet , self ).__init__ (init_cfg )
@@ -299,7 +315,7 @@ def __init__(self,
299315 assert not (init_cfg and pretrained ), \
300316 'init_cfg and pretrained cannot be setting at the same time'
301317 if isinstance (pretrained , str ):
302- warnings .warn ('DeprecationWarning: pretrained is a deprecated, '
318+ warnings .warn ('DeprecationWarning: pretrained is deprecated, '
303319 'please use "init_cfg" instead' )
304320 self .init_cfg = dict (type = 'Pretrained' , checkpoint = pretrained )
305321 elif pretrained is None :
@@ -314,6 +330,16 @@ def __init__(self,
314330 else :
315331 raise TypeError ('pretrained must be a str or None' )
316332
333+ # Assert configurations of 4 stages are in extra
334+ assert 'stage1' in extra and 'stage2' in extra \
335+ and 'stage3' in extra and 'stage4' in extra
336+ # Assert whether the length of `num_blocks` and `num_channels` are
337+ # equal to `num_branches`
338+ for i in range (4 ):
339+ cfg = extra [f'stage{ i + 1 } ' ]
340+ assert len (cfg ['num_blocks' ]) == cfg ['num_branches' ] and \
341+ len (cfg ['num_channels' ]) == cfg ['num_branches' ]
342+
317343 self .extra = extra
318344 self .conv_cfg = conv_cfg
319345 self .norm_cfg = norm_cfg
@@ -391,7 +417,7 @@ def __init__(self,
391417 self .transition3 = self ._make_transition_layer (pre_stage_channels ,
392418 num_channels )
393419 self .stage4 , pre_stage_channels = self ._make_stage (
394- self .stage4_cfg , num_channels )
420+ self .stage4_cfg , num_channels , multiscale_output = multiscale_output )
395421
396422 self ._freeze_stages ()
397423
0 commit comments