Commit 0c5b026

[Refactor]: Unified parameter initialization (open-mmlab#567)
* [Refactor]: Unified parameter initialization
* fixed pretrained
1 parent 5d46314 commit 0c5b026

File tree

19 files changed: +329 -298 lines changed

mmseg/models/backbones/cgnet.py

Lines changed: 32 additions & 28 deletions
@@ -1,12 +1,12 @@
+import warnings
+
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
-from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
-                      constant_init, kaiming_init)
-from mmcv.runner import load_checkpoint
+from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
+from mmcv.runner import BaseModule
 from mmcv.utils.parrots_wrapper import _BatchNorm
 
-from mmseg.utils import get_root_logger
 from ..builder import BACKBONES
 
 
@@ -183,7 +183,7 @@ def forward(self, x):
 
 
 @BACKBONES.register_module()
-class CGNet(nn.Module):
+class CGNet(BaseModule):
     """CGNet backbone.
 
     A Light-weight Context Guided Network for Semantic Segmentation
@@ -210,6 +210,9 @@ class CGNet(nn.Module):
             and its variants only. Default: False.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
+        pretrained (str, optional): model pretrained path. Default: None
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None
     """
 
     def __init__(self,
@@ -222,9 +225,31 @@ def __init__(self,
                  norm_cfg=dict(type='BN', requires_grad=True),
                  act_cfg=dict(type='PReLU'),
                  norm_eval=False,
-                 with_cp=False):
+                 with_cp=False,
+                 pretrained=None,
+                 init_cfg=None):
+
+        super(CGNet, self).__init__(init_cfg)
+
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be setting at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+                          'please use "init_cfg" instead')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is None:
+            if init_cfg is None:
+                self.init_cfg = [
+                    dict(type='Kaiming', layer=['Conv2d', 'Linear']),
+                    dict(
+                        type='Constant',
+                        val=1,
+                        layer=['_BatchNorm', 'GroupNorm']),
+                    dict(type='Constant', val=0, layer='PReLU')
+                ]
+        else:
+            raise TypeError('pretrained must be a str or None')
 
-        super(CGNet, self).__init__()
         self.in_channels = in_channels
         self.num_channels = num_channels
         assert isinstance(self.num_channels, tuple) and len(
@@ -335,27 +360,6 @@ def forward(self, x):
 
         return output
 
-    def init_weights(self, pretrained=None):
-        """Initialize the weights in backbone.
-
-        Args:
-            pretrained (str, optional): Path to pre-trained weights.
-                Defaults to None.
-        """
-        if isinstance(pretrained, str):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, (nn.Conv2d, nn.Linear)):
-                    kaiming_init(m)
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
-                elif isinstance(m, nn.PReLU):
-                    constant_init(m, 0)
-        else:
-            raise TypeError('pretrained must be a str or None')
-
     def train(self, mode=True):
         """Convert the model into training mode will keeping the normalization
         layer freezed."""
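
The change above is the template for all 19 files in this commit: the backbone now subclasses mmcv's BaseModule, the hand-rolled init_weights() is deleted, and the same defaults are restated declaratively as init_cfg entries, with the old pretrained argument kept as a deprecated alias that gets rewritten into a Pretrained init_cfg. A minimal usage sketch follows; the checkpoint path is a placeholder and an mmcv version that provides BaseModule is assumed.

from mmseg.models import CGNet

# Default init_cfg: Kaiming for Conv2d/Linear, constant 1 for norm
# layers, constant 0 for PReLU -- exactly the list built in __init__.
model = CGNet()
model.init_weights()  # inherited from BaseModule, applies self.init_cfg

# Pretrained weights now flow through init_cfg; pretrained='...' still
# works but emits the deprecation warning shown in the diff.
model = CGNet(
    init_cfg=dict(type='Pretrained', checkpoint='path/to/checkpoint.pth'))
model.init_weights()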

mmseg/models/backbones/fast_scnn.py

Lines changed: 16 additions & 13 deletions
@@ -1,8 +1,7 @@
 import torch
 import torch.nn as nn
-from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,
-                      kaiming_init)
-from torch.nn.modules.batchnorm import _BatchNorm
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from mmcv.runner import BaseModule
 
 from mmseg.models.decode_heads.psp_head import PPM
 from mmseg.ops import resize
@@ -247,7 +246,7 @@ def forward(self, higher_res_feature, lower_res_feature):
 
 
 @BACKBONES.register_module()
-class FastSCNN(nn.Module):
+class FastSCNN(BaseModule):
     """Fast-SCNN Backbone.
 
     Args:
@@ -291,6 +290,8 @@ class FastSCNN(nn.Module):
            dict(type='ReLU')
         align_corners (bool): align_corners argument of F.interpolate.
            Default: False
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None
     """
 
     def __init__(self,
@@ -307,9 +308,18 @@ def __init__(self,
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
                  act_cfg=dict(type='ReLU'),
-                 align_corners=False):
+                 align_corners=False,
+                 init_cfg=None):
+
+        super(FastSCNN, self).__init__(init_cfg)
+
+        if init_cfg is None:
+            self.init_cfg = [
+                dict(type='Kaiming', layer='Conv2d'),
+                dict(
+                    type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
+            ]
 
-        super(FastSCNN, self).__init__()
         if global_in_channels != higher_in_channels:
             raise AssertionError('Global Input Channels must be the same \
                 with Higher Input Channels!')
@@ -357,13 +367,6 @@ def __init__(self,
             act_cfg=self.act_cfg,
             align_corners=self.align_corners)
 
-    def init_weights(self, pretrained=None):
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                kaiming_init(m)
-            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                constant_init(m, 1)
-
     def forward(self, x):
         higher_res_features = self.learning_to_downsample(x)
         lower_res_features = self.global_feature_extractor(higher_res_features)
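
FastSCNN follows the same template minus the pretrained path: the deleted init_weights() loop becomes a two-entry init_cfg default. A short sketch, assuming the default constructor arguments:

from mmseg.models import FastSCNN

model = FastSCNN()
# BaseModule.init_weights() walks self.init_cfg: Kaiming init for Conv2d,
# constant 1 for _BatchNorm/GroupNorm, matching the removed loop.
model.init_weights()

# Passing an explicit init_cfg replaces the built-in default wholesale:
model = FastSCNN(init_cfg=dict(type='Xavier', layer='Conv2d'))
model.init_weights()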

mmseg/models/backbones/hrnet.py

Lines changed: 73 additions & 46 deletions
@@ -1,16 +1,16 @@
+import warnings
+
 import torch.nn as nn
-from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
-                      kaiming_init)
-from mmcv.runner import load_checkpoint
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.runner import BaseModule, ModuleList, Sequential
 from mmcv.utils.parrots_wrapper import _BatchNorm
 
 from mmseg.ops import Upsample, resize
-from mmseg.utils import get_root_logger
 from ..builder import BACKBONES
 from .resnet import BasicBlock, Bottleneck
 
 
-class HRModule(nn.Module):
+class HRModule(BaseModule):
     """High-Resolution Module for HRNet.
 
     In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
@@ -26,8 +26,11 @@ def __init__(self,
                  multiscale_output=True,
                  with_cp=False,
                  conv_cfg=None,
-                 norm_cfg=dict(type='BN', requires_grad=True)):
-        super(HRModule, self).__init__()
+                 norm_cfg=dict(type='BN', requires_grad=True),
+                 block_init_cfg=None,
+                 init_cfg=None):
+        super(HRModule, self).__init__(init_cfg)
+        self.block_init_cfg = block_init_cfg
         self._check_branches(num_branches, num_blocks, in_channels,
                              num_channels)
 
@@ -92,7 +95,8 @@ def _make_one_branch(self,
                 downsample=downsample,
                 with_cp=self.with_cp,
                 norm_cfg=self.norm_cfg,
-                conv_cfg=self.conv_cfg))
+                conv_cfg=self.conv_cfg,
+                init_cfg=self.block_init_cfg))
         self.in_channels[branch_index] = \
             num_channels[branch_index] * block.expansion
         for i in range(1, num_blocks[branch_index]):
@@ -102,9 +106,10 @@ def _make_one_branch(self,
                     num_channels[branch_index],
                     with_cp=self.with_cp,
                     norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg))
+                    conv_cfg=self.conv_cfg,
+                    init_cfg=self.block_init_cfg))
 
-        return nn.Sequential(*layers)
+        return Sequential(*layers)
 
     def _make_branches(self, num_branches, block, num_blocks, num_channels):
         """Build multiple branch."""
@@ -114,7 +119,7 @@ def _make_branches(self, num_branches, block, num_blocks, num_channels):
             branches.append(
                 self._make_one_branch(i, block, num_blocks, num_channels))
 
-        return nn.ModuleList(branches)
+        return ModuleList(branches)
 
     def _make_fuse_layers(self):
         """Build fuse layer."""
@@ -209,7 +214,7 @@ def forward(self, x):
 
 
 @BACKBONES.register_module()
-class HRNet(nn.Module):
+class HRNet(BaseModule):
     """HRNet backbone.
 
     High-Resolution Representations for Labeling Pixels and Regions
@@ -227,6 +232,9 @@ class HRNet(nn.Module):
            memory while slowing down the training speed.
         zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.
+        pretrained (str, optional): model pretrained path. Default: None
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None
 
     Example:
         >>> from mmseg.models import HRNet
@@ -277,14 +285,36 @@ def __init__(self,
                  norm_cfg=dict(type='BN', requires_grad=True),
                  norm_eval=False,
                  with_cp=False,
-                 zero_init_residual=False):
-        super(HRNet, self).__init__()
+                 zero_init_residual=False,
+                 pretrained=None,
+                 init_cfg=None):
+        super(HRNet, self).__init__(init_cfg)
+
+        self.pretrained = pretrained
+        self.zero_init_residual = zero_init_residual
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be setting at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+                          'please use "init_cfg" instead')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is None:
+            if init_cfg is None:
+                self.init_cfg = [
+                    dict(type='Kaiming', layer='Conv2d'),
+                    dict(
+                        type='Constant',
+                        val=1,
+                        layer=['_BatchNorm', 'GroupNorm'])
+                ]
+        else:
+            raise TypeError('pretrained must be a str or None')
+
         self.extra = extra
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
         self.norm_eval = norm_eval
         self.with_cp = with_cp
-        self.zero_init_residual = zero_init_residual
 
         # stem net
         self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
@@ -430,6 +460,16 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
                 build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
 
         layers = []
+        block_init_cfg = None
+        if self.pretrained is None and not hasattr(
+                self, 'init_cfg') and self.zero_init_residual:
+            if block is BasicBlock:
+                block_init_cfg = dict(
+                    type='Constant', val=0, override=dict(name='norm2'))
+            elif block is Bottleneck:
+                block_init_cfg = dict(
+                    type='Constant', val=0, override=dict(name='norm3'))
+
         layers.append(
             block(
                 inplanes,
@@ -438,7 +478,8 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
                 downsample=downsample,
                 with_cp=self.with_cp,
                 norm_cfg=self.norm_cfg,
-                conv_cfg=self.conv_cfg))
+                conv_cfg=self.conv_cfg,
+                init_cfg=block_init_cfg))
         inplanes = planes * block.expansion
         for i in range(1, blocks):
             layers.append(
@@ -447,9 +488,10 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
                     planes,
                     with_cp=self.with_cp,
                     norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg))
+                    conv_cfg=self.conv_cfg,
+                    init_cfg=block_init_cfg))
 
-        return nn.Sequential(*layers)
+        return Sequential(*layers)
 
     def _make_stage(self, layer_config, in_channels, multiscale_output=True):
         """Make each stage."""
@@ -460,6 +502,16 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
         block = self.blocks_dict[layer_config['block']]
 
         hr_modules = []
+        block_init_cfg = None
+        if self.pretrained is None and not hasattr(
+                self, 'init_cfg') and self.zero_init_residual:
+            if block is BasicBlock:
+                block_init_cfg = dict(
+                    type='Constant', val=0, override=dict(name='norm2'))
+            elif block is Bottleneck:
+                block_init_cfg = dict(
+                    type='Constant', val=0, override=dict(name='norm3'))
+
         for i in range(num_modules):
             # multi_scale_output is only used for the last module
             if not multiscale_output and i == num_modules - 1:
@@ -477,35 +529,10 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
                     reset_multiscale_output,
                     with_cp=self.with_cp,
                     norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg))
-
-        return nn.Sequential(*hr_modules), in_channels
+                    conv_cfg=self.conv_cfg,
+                    block_init_cfg=block_init_cfg))
 
-    def init_weights(self, pretrained=None):
-        """Initialize the weights in backbone.
-
-        Args:
-            pretrained (str, optional): Path to pre-trained weights.
-                Defaults to None.
-        """
-        if isinstance(pretrained, str):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    kaiming_init(m)
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
-
-            if self.zero_init_residual:
-                for m in self.modules():
-                    if isinstance(m, Bottleneck):
-                        constant_init(m.norm3, 0)
-                    elif isinstance(m, BasicBlock):
-                        constant_init(m.norm2, 0)
-        else:
-            raise TypeError('pretrained must be a str or None')
+        return Sequential(*hr_modules), in_channels
 
     def forward(self, x):
         """Forward function."""
