+import warnings
+
 import torch.nn as nn
-from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
-                      kaiming_init)
-from mmcv.runner import load_checkpoint
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.runner import BaseModule, ModuleList, Sequential
 from mmcv.utils.parrots_wrapper import _BatchNorm

 from mmseg.ops import Upsample, resize
-from mmseg.utils import get_root_logger
 from ..builder import BACKBONES
 from .resnet import BasicBlock, Bottleneck


-class HRModule(nn.Module):
+class HRModule(BaseModule):
     """High-Resolution Module for HRNet.

     In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
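The import swap above drives the rest of this change: `BaseModule`, `ModuleList` and `Sequential` from `mmcv.runner` are initialization-aware replacements for the `torch.nn` containers, and weight init is now declared through `init_cfg` and executed by `BaseModule.init_weights()`, which is why `kaiming_init`, `constant_init`, `load_checkpoint` and `get_root_logger` drop out. A minimal sketch of the pattern (`ToyBlock` is hypothetical, assuming an mmcv version that ships `BaseModule`, i.e. >= 1.3):

    import torch.nn as nn
    from mmcv.runner import BaseModule

    class ToyBlock(BaseModule):  # hypothetical module, for illustration only
        def __init__(self, init_cfg=None):
            super(ToyBlock, self).__init__(init_cfg)
            self.conv = nn.Conv2d(3, 3, 3)

    blk = ToyBlock(init_cfg=dict(type='Kaiming', layer='Conv2d'))
    blk.init_weights()  # walks the module tree and applies the declared init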
@@ -26,8 +26,11 @@ def __init__(self,
                  multiscale_output=True,
                  with_cp=False,
                  conv_cfg=None,
-                 norm_cfg=dict(type='BN', requires_grad=True)):
-        super(HRModule, self).__init__()
+                 norm_cfg=dict(type='BN', requires_grad=True),
+                 block_init_cfg=None,
+                 init_cfg=None):
+        super(HRModule, self).__init__(init_cfg)
+        self.block_init_cfg = block_init_cfg
         self._check_branches(num_branches, num_blocks, in_channels,
                              num_channels)

@@ -92,7 +95,8 @@ def _make_one_branch(self,
                 downsample=downsample,
                 with_cp=self.with_cp,
                 norm_cfg=self.norm_cfg,
-                conv_cfg=self.conv_cfg))
+                conv_cfg=self.conv_cfg,
+                init_cfg=self.block_init_cfg))
         self.in_channels[branch_index] = \
             num_channels[branch_index] * block.expansion
         for i in range(1, num_blocks[branch_index]):
@@ -102,9 +106,10 @@ def _make_one_branch(self,
                     num_channels[branch_index],
                     with_cp=self.with_cp,
                     norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg))
+                    conv_cfg=self.conv_cfg,
+                    init_cfg=self.block_init_cfg))

-        return nn.Sequential(*layers)
+        return Sequential(*layers)

     def _make_branches(self, num_branches, block, num_blocks, num_channels):
         """Build multiple branch."""
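With the two hunks above, every block built by `_make_one_branch` receives `init_cfg=self.block_init_cfg`, so a per-block override such as zeroing the last norm layer is declared on the block rather than patched in afterwards by a loop over `self.modules()`. A sketch of what such an override does, assuming mmcv's `initialize` semantics (`ToyBasicBlock` is a stand-in, not the real `BasicBlock`):

    import torch.nn as nn
    from mmcv.cnn import initialize

    class ToyBasicBlock(nn.Module):  # stand-in whose last norm layer is norm2
        def __init__(self):
            super(ToyBasicBlock, self).__init__()
            self.norm1 = nn.BatchNorm2d(4)
            self.norm2 = nn.BatchNorm2d(4)

    blk = ToyBasicBlock()
    initialize(blk, dict(type='Constant', val=0, override=dict(name='norm2')))
    # Only norm2 is zeroed; norm1 keeps PyTorch's default BatchNorm init.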
@@ -114,7 +119,7 @@ def _make_branches(self, num_branches, block, num_blocks, num_channels):
             branches.append(
                 self._make_one_branch(i, block, num_blocks, num_channels))

-        return nn.ModuleList(branches)
+        return ModuleList(branches)

     def _make_fuse_layers(self):
         """Build fuse layer."""
@@ -209,7 +214,7 @@ def forward(self, x):


 @BACKBONES.register_module()
-class HRNet(nn.Module):
+class HRNet(BaseModule):
     """HRNet backbone.

     High-Resolution Representations for Labeling Pixels and Regions
@@ -227,6 +232,9 @@ class HRNet(nn.Module):
             memory while slowing down the training speed.
         zero_init_residual (bool): whether to use zero init for last norm layer
             in resblocks to let them behave as identity.
+        pretrained (str, optional): model pretrained path. Default: None
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None

     Example:
         >>> from mmseg.models import HRNet
@@ -277,14 +285,36 @@ def __init__(self,
                  norm_cfg=dict(type='BN', requires_grad=True),
                  norm_eval=False,
                  with_cp=False,
-                 zero_init_residual=False):
-        super(HRNet, self).__init__()
+                 zero_init_residual=False,
+                 pretrained=None,
+                 init_cfg=None):
+        super(HRNet, self).__init__(init_cfg)
+
+        self.pretrained = pretrained
+        self.zero_init_residual = zero_init_residual
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be set at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
+                          'please use "init_cfg" instead')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is None:
+            if init_cfg is None:
+                self.init_cfg = [
+                    dict(type='Kaiming', layer='Conv2d'),
+                    dict(
+                        type='Constant',
+                        val=1,
+                        layer=['_BatchNorm', 'GroupNorm'])
+                ]
+        else:
+            raise TypeError('pretrained must be a str or None')
+
         self.extra = extra
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
         self.norm_eval = norm_eval
         self.with_cp = with_cp
-        self.zero_init_residual = zero_init_residual

         # stem net
         self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
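To make the new constructor logic concrete: at most one of `pretrained` and `init_cfg` may be given; a string `pretrained` is rewritten into a `Pretrained` init_cfg (with a deprecation warning); and when neither is given the backbone falls back to Kaiming for convolutions plus constant 1 for norm layers. A doctest-style illustration (it assumes `extra` is the stage config from the class docstring's Example, and the checkpoint path is a placeholder):

    >>> # deprecated spelling, still accepted and rewritten into init_cfg
    >>> model = HRNet(extra, pretrained='checkpoints/hrnet_w18.pth')
    >>> model.init_cfg
    {'type': 'Pretrained', 'checkpoint': 'checkpoints/hrnet_w18.pth'}
    >>> # preferred spelling of the same thing
    >>> model = HRNet(
    ...     extra,
    ...     init_cfg=dict(
    ...         type='Pretrained', checkpoint='checkpoints/hrnet_w18.pth'))
    >>> # neither given: the default Kaiming/Constant list shown above is used
    >>> model = HRNet(extra)
    >>> model.init_weights()  # inherited from BaseModule; takes no arguments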
@@ -430,6 +460,16 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
                 build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

         layers = []
+        block_init_cfg = None
+        if self.pretrained is None and not hasattr(
+                self, 'init_cfg') and self.zero_init_residual:
+            if block is BasicBlock:
+                block_init_cfg = dict(
+                    type='Constant', val=0, override=dict(name='norm2'))
+            elif block is Bottleneck:
+                block_init_cfg = dict(
+                    type='Constant', val=0, override=dict(name='norm3'))
+
         layers.append(
             block(
                 inplanes,
@@ -438,7 +478,8 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
                 downsample=downsample,
                 with_cp=self.with_cp,
                 norm_cfg=self.norm_cfg,
-                conv_cfg=self.conv_cfg))
+                conv_cfg=self.conv_cfg,
+                init_cfg=block_init_cfg))
         inplanes = planes * block.expansion
         for i in range(1, blocks):
             layers.append(
@@ -447,9 +488,10 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
                     planes,
                     with_cp=self.with_cp,
                     norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg))
+                    conv_cfg=self.conv_cfg,
+                    init_cfg=block_init_cfg))

-        return nn.Sequential(*layers)
+        return Sequential(*layers)

     def _make_stage(self, layer_config, in_channels, multiscale_output=True):
         """Make each stage."""
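The `block_init_cfg` computed above is the declarative form of `zero_init_residual`: zeroing the last norm layer (`norm2` in `BasicBlock`, `norm3` in `Bottleneck`) makes the residual branch output zero at the start of training, so each block initially behaves as the identity mapping y = x + F(x) = x. A quick check of that property on a bare BatchNorm (a sketch; `Constant` with `val=0` sets the affine weight to 0 and the bias to 0, matching mmcv's `constant_init`):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(8)
    nn.init.constant_(bn.weight, 0)  # what the Constant(val=0) override does
    nn.init.constant_(bn.bias, 0)
    x = torch.randn(2, 8, 4, 4)
    assert torch.allclose(x + bn(x), x)  # residual branch adds nothing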
@@ -460,6 +502,16 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
         block = self.blocks_dict[layer_config['block']]

         hr_modules = []
+        block_init_cfg = None
+        if self.pretrained is None and not hasattr(
+                self, 'init_cfg') and self.zero_init_residual:
+            if block is BasicBlock:
+                block_init_cfg = dict(
+                    type='Constant', val=0, override=dict(name='norm2'))
+            elif block is Bottleneck:
+                block_init_cfg = dict(
+                    type='Constant', val=0, override=dict(name='norm3'))
+
         for i in range(num_modules):
             # multi_scale_output is only used for the last module
             if not multiscale_output and i == num_modules - 1:
@@ -477,35 +529,10 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
                     reset_multiscale_output,
                     with_cp=self.with_cp,
                     norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg))
-
-        return nn.Sequential(*hr_modules), in_channels
+                    conv_cfg=self.conv_cfg,
+                    block_init_cfg=block_init_cfg))

-    def init_weights(self, pretrained=None):
-        """Initialize the weights in backbone.
-
-        Args:
-            pretrained (str, optional): Path to pre-trained weights.
-                Defaults to None.
-        """
-        if isinstance(pretrained, str):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    kaiming_init(m)
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
-
-            if self.zero_init_residual:
-                for m in self.modules():
-                    if isinstance(m, Bottleneck):
-                        constant_init(m.norm3, 0)
-                    elif isinstance(m, BasicBlock):
-                        constant_init(m.norm2, 0)
-        else:
-            raise TypeError('pretrained must be a str or None')
+        return Sequential(*hr_modules), in_channels

     def forward(self, x):
         """Forward function."""
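The deleted `init_weights(pretrained=...)` is fully subsumed by the declarative setup: checkpoint loading moves to the `Pretrained` initializer, the Kaiming/constant pass to the default `init_cfg` list in `__init__`, and the zero-init-residual pass to the per-block overrides built in `_make_layer`/`_make_stage`. The call site changes accordingly (a sketch; the checkpoint path is a placeholder):

    # before this change
    model.init_weights(pretrained='checkpoints/hrnet_w18.pth')

    # after this change
    model = HRNet(
        extra, init_cfg=dict(
            type='Pretrained', checkpoint='checkpoints/hrnet_w18.pth'))
    model.init_weights()  # BaseModule.init_weights() takes no arguments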