@@ -230,6 +230,8 @@ class HRNet(BaseModule):
             and its variants only.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
             memory while slowing down the training speed.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Default: -1.
         zero_init_residual (bool): whether to use zero init for last norm layer
             in resblocks to let them behave as identity.
         pretrained (str, optional): model pretrained path. Default: None
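
For context (not part of the patch): a minimal usage sketch of the new argument, in the doctest style this file's class docstring already uses. It assumes this is the mmdet HRNet backbone (adjust the import if the patch targets another repo) and reuses the HRNet-W32 `extra` config from that docstring; `frozen_stages=1` freezes the stem plus `layer1`/`transition1`:

>>> from mmdet.models import HRNet
>>> extra = dict(
...     stage1=dict(
...         num_modules=1,
...         num_branches=1,
...         block='BOTTLENECK',
...         num_blocks=(4, ),
...         num_channels=(64, )),
...     stage2=dict(
...         num_modules=1,
...         num_branches=2,
...         block='BASIC',
...         num_blocks=(4, 4),
...         num_channels=(32, 64)),
...     stage3=dict(
...         num_modules=4,
...         num_branches=3,
...         block='BASIC',
...         num_blocks=(4, 4, 4),
...         num_channels=(32, 64, 128)),
...     stage4=dict(
...         num_modules=3,
...         num_branches=4,
...         block='BASIC',
...         num_blocks=(4, 4, 4, 4),
...         num_channels=(32, 64, 128, 256)))
>>> model = HRNet(extra, frozen_stages=1)
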
@@ -285,6 +287,7 @@ def __init__(self,
                  norm_cfg=dict(type='BN', requires_grad=True),
                  norm_eval=False,
                  with_cp=False,
+                 frozen_stages=-1,
                  zero_init_residual=False,
                  pretrained=None,
                  init_cfg=None):
@@ -315,6 +318,7 @@ def __init__(self,
         self.norm_cfg = norm_cfg
         self.norm_eval = norm_eval
         self.with_cp = with_cp
+        self.frozen_stages = frozen_stages

         # stem net
         self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
@@ -388,6 +392,8 @@ def __init__(self,
         self.stage4, pre_stage_channels = self._make_stage(
             self.stage4_cfg, num_channels)

+        self._freeze_stages()
+
     @property
     def norm1(self):
         """nn.Module: the normalization layer named "norm1" """
@@ -534,6 +540,32 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True):

         return Sequential(*hr_modules), in_channels

+    def _freeze_stages(self):
+        """Freeze stages param and norm stats."""
+        if self.frozen_stages >= 0:
+
+            self.norm1.eval()
+            self.norm2.eval()
+            for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
+                for param in m.parameters():
+                    param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            if i == 1:
+                m = getattr(self, f'layer{i}')
+                t = getattr(self, f'transition{i}')
+            elif i == 4:
+                m = getattr(self, f'stage{i}')
+            else:
+                m = getattr(self, f'stage{i}')
+                t = getattr(self, f'transition{i}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+            t.eval()
+            for param in t.parameters():
+                param.requires_grad = False
+
     def forward(self, x):
         """Forward function."""
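
What the new method does, in short: whenever `frozen_stages >= 0`, the stem (`conv1`/`norm1`/`conv2`/`norm2`) is switched to eval mode with gradients disabled, and stages `1..frozen_stages` plus their transitions are frozen the same way. The `i == 4` branch reassigns only `m` because there is no `transition4`; `t` still holds `transition3` from the previous iteration, so the final `t.eval()` pass is a harmless re-freeze. A quick check, assuming the `extra` config sketched above:

>>> model = HRNet(extra, frozen_stages=1)
>>> all(not p.requires_grad for p in model.layer1.parameters())
True
>>> model.transition1.training  # set to eval mode by _freeze_stages()
False
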
@@ -575,6 +607,7 @@ def train(self, mode=True):
         """Convert the model into training mode while keeping the normalization
         layers frozen."""
         super(HRNet, self).train(mode)
+        self._freeze_stages()
         if mode and self.norm_eval:
             for m in self.modules():
                 # trick: eval has an effect on BatchNorm only
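
Re-running `_freeze_stages()` inside `train()` is what keeps the freeze sticky: `super().train(mode)` flips every submodule back to training mode, and the extra call immediately restores eval mode on the frozen parts. A sketch, again assuming the config above:

>>> model = HRNet(extra, frozen_stages=0)  # freeze the stem only
>>> _ = model.train()
>>> model.norm1.training  # stem BN stays in eval mode after train()
False
>>> next(model.conv1.parameters()).requires_grad
False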