
Commit 3057ef6

[Fix] Fix wrong init usage in transformer models (open-mmlab#1069)
* fix wrong trunc_normal_init usage
* fix mit init weights
* fix vit init weights
* fix mit init weights
* fix typo
* fix swin init weights
1 parent 2918220 commit 3057ef6
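
What this commit fixes: mmcv's module-level init helpers (trunc_normal_init, constant_init, normal_init, kaiming_init) expect an nn.Module and guard their work with hasattr(module, 'weight') checks, so passing a raw tensor such as m.weight makes the call a silent no-op and the layer keeps PyTorch's default initialization. Bare parameters need the tensor-level, in-place trunc_normal_ instead. A minimal sketch of the broken and corrected patterns (assuming a recent mmcv 1.x):

    import torch
    import torch.nn as nn
    from mmcv.cnn.utils.weight_init import trunc_normal_, trunc_normal_init

    linear = nn.Linear(4, 4)

    # Pre-fix pattern: a tensor has no .weight attribute, so the helper's
    # hasattr guard skips it and nothing is initialized.
    trunc_normal_init(linear.weight, std=.02)

    # Fixed pattern: pass the module itself; weight and bias are set in
    # one call.
    trunc_normal_init(linear, std=.02, bias=0.)

    # For a bare parameter (e.g. a position embedding), use the in-place
    # tensor-level variant.
    pos_embed = nn.Parameter(torch.zeros(1, 16, 4))
    trunc_normal_(pos_embed, std=.02)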

File tree: 3 files changed (+25 additions, -32 deletions)

mmseg/models/backbones/mit.py
Lines changed: 8 additions & 12 deletions

@@ -4,10 +4,11 @@
 
 import torch
 import torch.nn as nn
-from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer,
-                      constant_init, normal_init, trunc_normal_init)
+from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
 from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import MultiheadAttention
+from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
+                                        trunc_normal_init)
 from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint
 
 from ...utils import get_root_logger
@@ -343,7 +344,7 @@ def __init__(self,
                  norm_cfg=dict(type='LN', eps=1e-6),
                  pretrained=None,
                  init_cfg=None):
-        super().__init__()
+        super().__init__(init_cfg=init_cfg)
 
         if isinstance(pretrained, str) or pretrained is None:
             warnings.warn('DeprecationWarning: pretrained is a deprecated, '
@@ -365,7 +366,6 @@ def __init__(self,
         self.out_indices = out_indices
         assert max(out_indices) < self.num_stages
         self.pretrained = pretrained
-        self.init_cfg = init_cfg
 
         # transformer encoder
         dpr = [
@@ -407,19 +407,15 @@ def init_weights(self):
         if self.pretrained is None:
             for m in self.modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    trunc_normal_init(m, std=.02, bias=0.)
                 elif isinstance(m, nn.LayerNorm):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)
                 elif isinstance(m, nn.Conv2d):
                     fan_out = m.kernel_size[0] * m.kernel_size[
                         1] * m.out_channels
                     fan_out //= m.groups
-                    normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    normal_init(
+                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
         elif isinstance(self.pretrained, str):
             logger = get_root_logger()
             checkpoint = _load_checkpoint(
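
The Conv2d branch above keeps the He-style fan-out rule (std = sqrt(2 / fan_out)) used by the MiT/SegFormer code; only the call style changes. As a cross-check, roughly what the single module-level normal_init call now does, spelled out with plain torch.nn.init (a sketch with hypothetical layer sizes):

    import math
    import torch.nn as nn

    m = nn.Conv2d(64, 128, kernel_size=3, padding=1)  # hypothetical sizes
    fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    fan_out //= m.groups
    nn.init.normal_(m.weight, mean=0, std=math.sqrt(2.0 / fan_out))
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)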

mmseg/models/backbones/swin.py
Lines changed: 7 additions & 8 deletions

@@ -7,8 +7,10 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as cp
-from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
+from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, build_dropout
+from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
+                                        trunc_normal_init)
 from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from mmcv.utils import to_2tuple
 
@@ -73,7 +75,7 @@ def __init__(self,
         self.softmax = nn.Softmax(dim=-1)
 
     def init_weights(self):
-        trunc_normal_init(self.relative_position_bias_table, std=0.02)
+        trunc_normal_(self.relative_position_bias_table, std=0.02)
 
     def forward(self, x, mask=None):
         """
@@ -665,15 +667,12 @@ def init_weights(self):
                 f'{self.__class__.__name__}, '
                 f'training start from scratch')
             if self.use_abs_pos_embed:
-                trunc_normal_init(self.absolute_pos_embed, std=0.02)
+                trunc_normal_(self.absolute_pos_embed, std=0.02)
             for m in self.modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    trunc_normal_init(m, std=.02, bias=0.)
                 elif isinstance(m, nn.LayerNorm):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)
         else:
             assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                   f'specify `Pretrained` in ' \
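
A quick sanity check that the corrected calls actually take effect (a hypothetical snippet, not part of the commit): after the module-level init, the weight spread should follow the requested std rather than PyTorch's default Kaiming-uniform init, and the bias should be zeroed by the same call.

    import torch.nn as nn
    from mmcv.cnn.utils.weight_init import trunc_normal_init

    m = nn.Linear(256, 256)
    default_std = m.weight.std().item()        # default init, roughly 0.036 here
    trunc_normal_init(m, std=.02, bias=0.)
    print(default_std, m.weight.std().item())  # second value close to 0.02
    assert m.bias.abs().sum().item() == 0      # bias zeroed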

mmseg/models/backbones/vit.py
Lines changed: 10 additions & 12 deletions

@@ -4,9 +4,10 @@
 
 import torch
 import torch.nn as nn
-from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
-                      normal_init, trunc_normal_init)
+from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
+                                        trunc_normal_)
 from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
@@ -292,23 +293,20 @@ def init_weights(self):
         else:
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
-            trunc_normal_init(self.pos_embed, std=.02)
-            trunc_normal_init(self.cls_token, std=.02)
+            trunc_normal_(self.pos_embed, std=.02)
+            trunc_normal_(self.cls_token, std=.02)
             for n, m in self.named_modules():
                 if isinstance(m, nn.Linear):
-                    trunc_normal_init(m.weight, std=.02)
+                    trunc_normal_(m.weight, std=.02)
                     if m.bias is not None:
                         if 'ffn' in n:
-                            normal_init(m.bias, std=1e-6)
+                            nn.init.normal_(m.bias, mean=0., std=1e-6)
                         else:
-                            constant_init(m.bias, 0)
+                            nn.init.constant_(m.bias, 0)
                 elif isinstance(m, nn.Conv2d):
-                    kaiming_init(m.weight, mode='fan_in')
-                    if m.bias is not None:
-                        constant_init(m.bias, 0)
+                    kaiming_init(m, mode='fan_in', bias=0.)
                 elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
-                    constant_init(m.bias, 0)
-                    constant_init(m.weight, 1.0)
+                    constant_init(m, val=1.0, bias=0.)
 
     def _pos_embeding(self, patched_img, hw_shape, pos_embed):
         """Positiong embeding method.
