Commit c23e902

[Fix] Fix the bug that mit cannot process init_cfg (open-mmlab#1102)
* [Fix] Fix the bug that mit cannot process init_cfg

* fix error
1 parent dec5bf0 commit c23e902
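
In short: before this commit, MixVisionTransformer.init_weights() keyed only on the deprecated pretrained argument and ignored init_cfg; after it, pretrained is folded into init_cfg and checkpoint loading is deferred to BaseModule. A minimal usage sketch under that assumption ('mit_b0.pth' is an illustrative path, not a shipped file):

    from mmseg.models.backbones import MixVisionTransformer

    # After the fix, a Pretrained init_cfg is honored: init_weights() defers
    # to BaseModule, which loads the checkpoint.
    model = MixVisionTransformer(
        init_cfg=dict(type='Pretrained', checkpoint='mit_b0.pth'))
    model.init_weights()  # raises OSError if the checkpoint does not exist

    # The deprecated argument still works: it is converted into an
    # equivalent init_cfg and a DeprecationWarning is emitted.
    model = MixVisionTransformer(pretrained='mit_b0.pth')
    assert model.init_cfg == dict(type='Pretrained', checkpoint='mit_b0.pth')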

File tree

2 files changed: +67 -19 lines changed

mmseg/models/backbones/mit.py

Lines changed: 11 additions & 19 deletions
@@ -9,9 +9,8 @@
 from mmcv.cnn.bricks.transformer import MultiheadAttention
 from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
                                         trunc_normal_init)
-from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint
+from mmcv.runner import BaseModule, ModuleList, Sequential
 
-from ...utils import get_root_logger
 from ..builder import BACKBONES
 from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw
 
@@ -341,16 +340,18 @@ def __init__(self,
                  norm_cfg=dict(type='LN', eps=1e-6),
                  pretrained=None,
                  init_cfg=None):
-        super().__init__(init_cfg=init_cfg)
+        super(MixVisionTransformer, self).__init__(init_cfg=init_cfg)
 
-        if isinstance(pretrained, str) or pretrained is None:
-            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be set at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                           'please use "init_cfg" instead')
-        else:
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is not None:
             raise TypeError('pretrained must be a str or None')
 
         self.embed_dims = embed_dims
-
         self.num_stages = num_stages
         self.num_layers = num_layers
         self.num_heads = num_heads
@@ -362,7 +363,6 @@ def __init__(self,
 
         self.out_indices = out_indices
         assert max(out_indices) < self.num_stages
-        self.pretrained = pretrained
 
         # transformer encoder
         dpr = [
@@ -401,7 +401,7 @@ def __init__(self,
             cur += num_layer
 
     def init_weights(self):
-        if self.pretrained is None:
+        if self.init_cfg is None:
             for m in self.modules():
                 if isinstance(m, nn.Linear):
                     trunc_normal_init(m, std=.02, bias=0.)
@@ -413,16 +413,8 @@ def init_weights(self):
                     fan_out //= m.groups
                 normal_init(
                     m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
-        elif isinstance(self.pretrained, str):
-            logger = get_root_logger()
-            checkpoint = _load_checkpoint(
-                self.pretrained, logger=logger, map_location='cpu')
-            if 'state_dict' in checkpoint:
-                state_dict = checkpoint['state_dict']
-            else:
-                state_dict = checkpoint
-
-            self.load_state_dict(state_dict, False)
+        else:
+            super(MixVisionTransformer, self).init_weights()
 
     def forward(self, x):
         outs = []
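
The constructor logic above follows the usual mmcv convention: pretrained and init_cfg are mutually exclusive, a string pretrained is converted into an equivalent init_cfg, and any other non-None value is rejected. A self-contained sketch of that pattern (resolve_init_cfg is a hypothetical helper, not part of this commit):

    import warnings

    def resolve_init_cfg(pretrained=None, init_cfg=None):
        # Hypothetical standalone mirror of the __init__ logic above.
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            return dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')
        return init_cfg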

tests/test_models/test_backbones/test_mit.py

Lines changed: 56 additions & 0 deletions
@@ -55,3 +55,59 @@ def test_mit():
     # Out identity
     outs = MHA(temp, hw_shape, temp)
     assert out.shape == (1, token_len, 64)
+
+
+def test_mit_init():
+    path = 'PATH_THAT_DO_NOT_EXIST'
+    # Test all combinations of pretrained and init_cfg
+    # pretrained=None, init_cfg=None
+    model = MixVisionTransformer(pretrained=None, init_cfg=None)
+    assert model.init_cfg is None
+    model.init_weights()
+
+    # pretrained=None
+    # init_cfg loads pretrained weights from a non-existent file
+    model = MixVisionTransformer(
+        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
+    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+    # Test loading a checkpoint from a non-existent file
+    with pytest.raises(OSError):
+        model.init_weights()
+
+    # pretrained=None
+    # init_cfg=123, whose type is unsupported
+    model = MixVisionTransformer(pretrained=None, init_cfg=123)
+    with pytest.raises(TypeError):
+        model.init_weights()
+
+    # pretrained loads pretrained weights from a non-existent file
+    # init_cfg=None
+    model = MixVisionTransformer(pretrained=path, init_cfg=None)
+    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+    # Test loading a checkpoint from a non-existent file
+    with pytest.raises(OSError):
+        model.init_weights()
+
+    # pretrained loads pretrained weights from a non-existent file
+    # init_cfg loads pretrained weights from a non-existent file
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(
+            pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(pretrained=path, init_cfg=123)
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg=None
+    with pytest.raises(TypeError):
+        MixVisionTransformer(pretrained=123, init_cfg=None)
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg loads pretrained weights from a non-existent file
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(
+            pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg=123, whose type is unsupported
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(pretrained=123, init_cfg=123)
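
The new cases can be run in isolation with pytest tests/test_models/test_backbones/test_mit.py::test_mit_init (assuming a source checkout of mmsegmentation with pytest installed).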
