+import warnings
+
 import torch.nn as nn
-from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
-                      kaiming_init)
-from mmcv.runner import load_checkpoint
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.runner import BaseModule, ModuleList, Sequential
 from mmcv.utils.parrots_wrapper import _BatchNorm
 
 from mmseg.ops import Upsample, resize
-from mmseg.utils import get_root_logger
 from ..builder import BACKBONES
 from .resnet import BasicBlock, Bottleneck
 
 
-class HRModule(nn.Module):
+class HRModule(BaseModule):
     """High-Resolution Module for HRNet.
 
     In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
@@ -26,8 +26,11 @@ def __init__(self,
                  multiscale_output=True,
                  with_cp=False,
                  conv_cfg=None,
-                 norm_cfg=dict(type='BN', requires_grad=True)):
-        super(HRModule, self).__init__()
+                 norm_cfg=dict(type='BN', requires_grad=True),
+                 block_init_cfg=None,
+                 init_cfg=None):
+        super(HRModule, self).__init__(init_cfg)
+        self.block_init_cfg = block_init_cfg
         self._check_branches(num_branches, num_blocks, in_channels,
                              num_channels)
 
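For context, a minimal sketch of the init_cfg mechanism this change adopts. TinyBlock is a hypothetical module, and the sketch assumes an mmcv version that ships mmcv.runner.BaseModule (as imported above): instead of a hand-written init_weights(), the module declares how its layers should be initialized, and BaseModule applies that declaration later.

    import torch.nn as nn
    from mmcv.runner import BaseModule

    class TinyBlock(BaseModule):
        def __init__(self, init_cfg=None):
            # BaseModule stores init_cfg; weights are only initialized when
            # init_weights() is called (typically by the runner).
            super(TinyBlock, self).__init__(init_cfg)
            self.conv = nn.Conv2d(3, 8, 3, padding=1)  # matched by layer='Conv2d'
            self.norm = nn.BatchNorm2d(8)              # matched by layer='BatchNorm2d'

    block = TinyBlock(init_cfg=[
        dict(type='Kaiming', layer='Conv2d'),
        dict(type='Constant', val=1, layer='BatchNorm2d'),
    ])
    block.init_weights()  # applies each initializer to the matching layers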
@@ -92,7 +95,8 @@ def _make_one_branch(self,
                 downsample=downsample,
                 with_cp=self.with_cp,
                 norm_cfg=self.norm_cfg,
-                conv_cfg=self.conv_cfg))
+                conv_cfg=self.conv_cfg,
+                init_cfg=self.block_init_cfg))
         self.in_channels[branch_index] = \
             num_channels[branch_index] * block.expansion
         for i in range(1, num_blocks[branch_index]):
@@ -102,9 +106,10 @@ def _make_one_branch(self,
                     num_channels[branch_index],
                     with_cp=self.with_cp,
                     norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg))
+                    conv_cfg=self.conv_cfg,
+                    init_cfg=self.block_init_cfg))
 
-        return nn.Sequential(*layers)
+        return Sequential(*layers)
 
     def _make_branches(self, num_branches, block, num_blocks, num_channels):
         """Build multiple branches."""
@@ -114,7 +119,7 @@ def _make_branches(self, num_branches, block, num_blocks, num_channels):
             branches.append(
                 self._make_one_branch(i, block, num_blocks, num_channels))
 
-        return nn.ModuleList(branches)
+        return ModuleList(branches)
 
     def _make_fuse_layers(self):
         """Build fuse layer."""
@@ -209,7 +214,7 @@ def forward(self, x):
 
 
 @BACKBONES.register_module()
-class HRNet(nn.Module):
+class HRNet(BaseModule):
     """HRNet backbone.
 
     High-Resolution Representations for Labeling Pixels and Regions
@@ -227,6 +232,9 @@ class HRNet(nn.Module):
             memory while slowing down the training speed.
         zero_init_residual (bool): Whether to zero-init the last norm layer
             in resblocks so that they initially behave as identity mappings.
+        pretrained (str, optional): Path to a pretrained model. Default: None
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None
 
     Example:
         >>> from mmseg.models import HRNet
@@ -277,14 +285,36 @@ def __init__(self,
                  norm_cfg=dict(type='BN', requires_grad=True),
                  norm_eval=False,
                  with_cp=False,
-                 zero_init_residual=False):
-        super(HRNet, self).__init__()
+                 zero_init_residual=False,
+                 pretrained=None,
+                 init_cfg=None):
+        super(HRNet, self).__init__(init_cfg)
+
+        self.pretrained = pretrained
+        self.zero_init_residual = zero_init_residual
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be set at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
+                          'please use "init_cfg" instead')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is None:
+            if init_cfg is None:
+                self.init_cfg = [
+                    dict(type='Kaiming', layer='Conv2d'),
+                    dict(
+                        type='Constant',
+                        val=1,
+                        layer=['_BatchNorm', 'GroupNorm'])
+                ]
+        else:
+            raise TypeError('pretrained must be a str or None')
+
         self.extra = extra
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
         self.norm_eval = norm_eval
         self.with_cp = with_cp
-        self.zero_init_residual = zero_init_residual
 
         # stem net
         self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
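For illustration, hypothetical usage of the new constructor interface; `extra` stands for a valid HRNet stage config (see the class docstring) and the checkpoint path is a placeholder:

    # deprecated spelling: still accepted, but emits the warning above
    model = HRNet(extra=extra, pretrained='path/to/checkpoint.pth')

    # preferred spelling: the same intent expressed through init_cfg
    model = HRNet(
        extra=extra,
        init_cfg=dict(type='Pretrained', checkpoint='path/to/checkpoint.pth'))

Passing both arguments at once now trips the assert instead of silently preferring one of them.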
@@ -430,6 +460,16 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
             build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
 
         layers = []
+        block_init_cfg = None
+        if self.pretrained is None and not hasattr(
+                self, 'init_cfg') and self.zero_init_residual:
+            if block is BasicBlock:
+                block_init_cfg = dict(
+                    type='Constant', val=0, override=dict(name='norm2'))
+            elif block is Bottleneck:
+                block_init_cfg = dict(
+                    type='Constant', val=0, override=dict(name='norm3'))
+
         layers.append(
             block(
                 inplanes,
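The override form above is the declarative replacement for the old zero-init loop: assuming mmcv resolves override=dict(name='norm3') against the submodule of that name, the effect on a Bottleneck is equivalent to zero-filling its last norm layer, as this small sketch shows:

    import torch.nn as nn

    norm3 = nn.BatchNorm2d(64)  # stand-in for a Bottleneck's last norm layer
    # effect of dict(type='Constant', val=0, override=dict(name='norm3')):
    nn.init.constant_(norm3.weight, 0)
    nn.init.constant_(norm3.bias, 0)
    # the residual branch now outputs zeros, so the block starts as an identity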
@@ -438,7 +478,8 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
                 downsample=downsample,
                 with_cp=self.with_cp,
                 norm_cfg=self.norm_cfg,
-                conv_cfg=self.conv_cfg))
+                conv_cfg=self.conv_cfg,
+                init_cfg=block_init_cfg))
         inplanes = planes * block.expansion
         for i in range(1, blocks):
             layers.append(
@@ -447,9 +488,10 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
                     planes,
                     with_cp=self.with_cp,
                     norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg))
+                    conv_cfg=self.conv_cfg,
+                    init_cfg=block_init_cfg))
 
-        return nn.Sequential(*layers)
+        return Sequential(*layers)
 
     def _make_stage(self, layer_config, in_channels, multiscale_output=True):
         """Make each stage."""
@@ -460,6 +502,16 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
         block = self.blocks_dict[layer_config['block']]
 
         hr_modules = []
+        block_init_cfg = None
+        if self.pretrained is None and not hasattr(
+                self, 'init_cfg') and self.zero_init_residual:
+            if block is BasicBlock:
+                block_init_cfg = dict(
+                    type='Constant', val=0, override=dict(name='norm2'))
+            elif block is Bottleneck:
+                block_init_cfg = dict(
+                    type='Constant', val=0, override=dict(name='norm3'))
+
         for i in range(num_modules):
             # multi_scale_output is only used for the last module
             if not multiscale_output and i == num_modules - 1:
@@ -477,35 +529,10 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
                     reset_multiscale_output,
                     with_cp=self.with_cp,
                     norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg))
-
-        return nn.Sequential(*hr_modules), in_channels
+                    conv_cfg=self.conv_cfg,
+                    block_init_cfg=block_init_cfg))
 
-    def init_weights(self, pretrained=None):
-        """Initialize the weights in backbone.
-
-        Args:
-            pretrained (str, optional): Path to pre-trained weights.
-                Defaults to None.
-        """
-        if isinstance(pretrained, str):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    kaiming_init(m)
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
-
-            if self.zero_init_residual:
-                for m in self.modules():
-                    if isinstance(m, Bottleneck):
-                        constant_init(m.norm3, 0)
-                    elif isinstance(m, BasicBlock):
-                        constant_init(m.norm2, 0)
-        else:
-            raise TypeError('pretrained must be a str or None')
+        return Sequential(*hr_modules), in_channels
 
     def forward(self, x):
         """Forward function."""