
Commit f934084

[Enhancement] Support hrnet frozen stage (open-mmlab#743)
* support hrnet frozen stage
1 parent 52b4fa5 commit f934084

File tree

2 files changed: +96 -0 lines changed

mmseg/models/backbones/hrnet.py

Lines changed: 33 additions & 0 deletions
@@ -230,6 +230,8 @@ class HRNet(BaseModule):
             and its variants only.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
             memory while slowing down the training speed.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Default: -1.
         zero_init_residual (bool): whether to use zero init for last norm layer
             in resblocks to let them behave as identity.
         pretrained (str, optional): model pretrained path. Default: None
@@ -285,6 +287,7 @@ def __init__(self,
                  norm_cfg=dict(type='BN', requires_grad=True),
                  norm_eval=False,
                  with_cp=False,
+                 frozen_stages=-1,
                  zero_init_residual=False,
                  pretrained=None,
                  init_cfg=None):
@@ -315,6 +318,7 @@ def __init__(self,
         self.norm_cfg = norm_cfg
         self.norm_eval = norm_eval
         self.with_cp = with_cp
+        self.frozen_stages = frozen_stages
 
         # stem net
         self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
@@ -388,6 +392,8 @@ def __init__(self,
         self.stage4, pre_stage_channels = self._make_stage(
             self.stage4_cfg, num_channels)
 
+        self._freeze_stages()
+
     @property
     def norm1(self):
         """nn.Module: the normalization layer named "norm1" """
@@ -534,6 +540,32 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
 
         return Sequential(*hr_modules), in_channels
 
+    def _freeze_stages(self):
+        """Freeze stages param and norm stats."""
+        if self.frozen_stages >= 0:
+
+            self.norm1.eval()
+            self.norm2.eval()
+            for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
+                for param in m.parameters():
+                    param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            if i == 1:
+                m = getattr(self, f'layer{i}')
+                t = getattr(self, f'transition{i}')
+            elif i == 4:
+                m = getattr(self, f'stage{i}')
+            else:
+                m = getattr(self, f'stage{i}')
+                t = getattr(self, f'transition{i}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+            t.eval()
+            for param in t.parameters():
+                param.requires_grad = False
+
     def forward(self, x):
         """Forward function."""
 
@@ -575,6 +607,7 @@ def train(self, mode=True):
         """Convert the model into training mode will keeping the normalization
         layer freezed."""
         super(HRNet, self).train(mode)
+        self._freeze_stages()
         if mode and self.norm_eval:
             for m in self.modules():
                 # trick: eval have effect on BatchNorm only
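For context, a minimal usage sketch (not part of the commit) that exercises the new frozen_stages argument by instantiating the backbone directly. The small extra config and the 64x64 input are chosen only for illustration; stage3 and stage4 use a single module each to keep the model light.

# Sketch only: freeze the stem plus the first two stages of HRNet.
import torch

from mmseg.models.backbones import HRNet

extra = dict(
    stage1=dict(
        num_modules=1,
        num_branches=1,
        block='BOTTLENECK',
        num_blocks=(4, ),
        num_channels=(64, )),
    stage2=dict(
        num_modules=1,
        num_branches=2,
        block='BASIC',
        num_blocks=(4, 4),
        num_channels=(32, 64)),
    stage3=dict(
        num_modules=1,
        num_branches=3,
        block='BASIC',
        num_blocks=(4, 4, 4),
        num_channels=(32, 64, 128)),
    stage4=dict(
        num_modules=1,
        num_branches=4,
        block='BASIC',
        num_blocks=(4, 4, 4, 4),
        num_channels=(32, 64, 128, 256)))

model = HRNet(extra, frozen_stages=2)
model.train()  # _freeze_stages() keeps the frozen parts in eval mode

x = torch.randn(1, 3, 64, 64)
outs = model(x)  # one feature map per branch, at 1/4, 1/8, 1/16, 1/32 resolution
print([o.shape for o in outs])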
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+from mmcv.utils.parrots_wrapper import _BatchNorm
+
+from mmseg.models.backbones import HRNet
+
+
+def test_hrnet_backbone():
+    # Test HRNET with two stage frozen
+
+    extra = dict(
+        stage1=dict(
+            num_modules=1,
+            num_branches=1,
+            block='BOTTLENECK',
+            num_blocks=(4, ),
+            num_channels=(64, )),
+        stage2=dict(
+            num_modules=1,
+            num_branches=2,
+            block='BASIC',
+            num_blocks=(4, 4),
+            num_channels=(32, 64)),
+        stage3=dict(
+            num_modules=4,
+            num_branches=3,
+            block='BASIC',
+            num_blocks=(4, 4, 4),
+            num_channels=(32, 64, 128)),
+        stage4=dict(
+            num_modules=3,
+            num_branches=4,
+            block='BASIC',
+            num_blocks=(4, 4, 4, 4),
+            num_channels=(32, 64, 128, 256)))
+    frozen_stages = 2
+    model = HRNet(extra, frozen_stages=frozen_stages)
+    model.init_weights()
+    model.train()
+    assert model.norm1.training is False
+
+    for layer in [model.conv1, model.norm1]:
+        for param in layer.parameters():
+            assert param.requires_grad is False
+    for i in range(1, frozen_stages + 1):
+        if i == 1:
+            layer = getattr(model, f'layer{i}')
+            transition = getattr(model, f'transition{i}')
+        elif i == 4:
+            layer = getattr(model, f'stage{i}')
+        else:
+            layer = getattr(model, f'stage{i}')
+            transition = getattr(model, f'transition{i}')
+
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
+
+        for mod in transition.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in transition.parameters():
+            assert param.requires_grad is False
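Since _freeze_stages() only switches the affected modules to eval mode and clears requires_grad on their parameters, frozen weights can also be skipped when constructing an optimizer. A short sketch in plain PyTorch, continuing from the usage example above (not part of the commit):

# Sketch only: build an optimizer over the still-trainable parameters.
import torch

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01, momentum=0.9)

num_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f'{num_frozen} parameters are frozen and excluded from the optimizer')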
