Skip to content

Commit 7a1c9a5

Browse files
[Fix] Fix the bug that vit cannot load pretrain properly when using init_cfg to specify the pretrain scheme (open-mmlab#999)
* [Fix] Fix the bug that vit cannot load pretrain properly when using init_cfg to specify the pretrain scheme * [Fix] fix the coverage problem * Update mmseg/models/backbones/vit.py Co-authored-by: Junjun2016 <[email protected]> * [Fix] make the predicate more concise and clearer * [Fix] Modified the judgement logic * Update tests/test_models/test_backbones/test_vit.py Co-authored-by: Junjun2016 <[email protected]> * add comments Co-authored-by: Junjun2016 <[email protected]>
1 parent 14dc00a commit 7a1c9a5

File tree

2 files changed

+69
-9
lines changed

2 files changed

+69
-9
lines changed

mmseg/models/backbones/vit.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def __init__(self,
170170
with_cp=False,
171171
pretrained=None,
172172
init_cfg=None):
173-
super(VisionTransformer, self).__init__()
173+
super(VisionTransformer, self).__init__(init_cfg=init_cfg)
174174

175175
if isinstance(img_size, int):
176176
img_size = to_2tuple(img_size)
@@ -185,10 +185,13 @@ def __init__(self,
185185
assert with_cls_token is True, f'with_cls_token must be True if' \
186186
f'set output_cls_token to True, but got {with_cls_token}'
187187

188-
if isinstance(pretrained, str) or pretrained is None:
189-
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
188+
assert not (init_cfg and pretrained), \
189+
'init_cfg and pretrained cannot be set at the same time'
190+
if isinstance(pretrained, str):
191+
warnings.warn('DeprecationWarning: pretrained is deprecated, '
190192
'please use "init_cfg" instead')
191-
else:
193+
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
194+
elif pretrained is not None:
192195
raise TypeError('pretrained must be a str or None')
193196

194197
self.img_size = img_size
@@ -197,7 +200,6 @@ def __init__(self,
197200
self.norm_eval = norm_eval
198201
self.with_cp = with_cp
199202
self.pretrained = pretrained
200-
self.init_cfg = init_cfg
201203

202204
self.patch_embed = PatchEmbed(
203205
in_channels=in_channels,
@@ -260,10 +262,12 @@ def norm1(self):
260262
return getattr(self, self.norm1_name)
261263

262264
def init_weights(self):
263-
if isinstance(self.pretrained, str):
265+
if (isinstance(self.init_cfg, dict)
266+
and self.init_cfg.get('type') == 'Pretrained'):
264267
logger = get_root_logger()
265268
checkpoint = _load_checkpoint(
266-
self.pretrained, logger=logger, map_location='cpu')
269+
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
270+
267271
if 'state_dict' in checkpoint:
268272
state_dict = checkpoint['state_dict']
269273
else:
@@ -283,9 +287,9 @@ def init_weights(self):
283287
(pos_size, pos_size), self.interpolate_mode)
284288

285289
self.load_state_dict(state_dict, False)
286-
287-
elif self.pretrained is None:
290+
elif self.init_cfg is not None:
288291
super(VisionTransformer, self).init_weights()
292+
else:
289293
# We only implement the 'jax_impl' initialization implemented at
290294
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
291295
trunc_normal_init(self.pos_embed, std=.02)

tests/test_models/test_backbones/test_vit.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,59 @@ def test_vit_backbone():
118118
feat = model(imgs)
119119
assert feat[0][0].shape == (1, 768, 14, 14)
120120
assert feat[0][1].shape == (1, 768)
121+
122+
123+
def test_vit_init():
124+
path = 'PATH_THAT_DO_NOT_EXIST'
125+
# Test all combinations of pretrained and init_cfg
126+
# pretrained=None, init_cfg=None
127+
model = VisionTransformer(pretrained=None, init_cfg=None)
128+
assert model.init_cfg is None
129+
model.init_weights()
130+
131+
# pretrained=None
132+
# init_cfg loads pretrain from an non-existent file
133+
model = VisionTransformer(
134+
pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
135+
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
136+
# Test loading a checkpoint from an non-existent file
137+
with pytest.raises(OSError):
138+
model.init_weights()
139+
140+
# pretrained=None
141+
# init_cfg=123, whose type is unsupported
142+
model = VisionTransformer(pretrained=None, init_cfg=123)
143+
with pytest.raises(TypeError):
144+
model.init_weights()
145+
146+
# pretrained loads pretrain from an non-existent file
147+
# init_cfg=None
148+
model = VisionTransformer(pretrained=path, init_cfg=None)
149+
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
150+
# Test loading a checkpoint from an non-existent file
151+
with pytest.raises(OSError):
152+
model.init_weights()
153+
154+
# pretrained loads pretrain from an non-existent file
155+
# init_cfg loads pretrain from an non-existent file
156+
with pytest.raises(AssertionError):
157+
model = VisionTransformer(
158+
pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
159+
with pytest.raises(AssertionError):
160+
model = VisionTransformer(pretrained=path, init_cfg=123)
161+
162+
# pretrain=123, whose type is unsupported
163+
# init_cfg=None
164+
with pytest.raises(TypeError):
165+
model = VisionTransformer(pretrained=123, init_cfg=None)
166+
167+
# pretrain=123, whose type is unsupported
168+
# init_cfg loads pretrain from an non-existent file
169+
with pytest.raises(AssertionError):
170+
model = VisionTransformer(
171+
pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
172+
173+
# pretrain=123, whose type is unsupported
174+
# init_cfg=123, whose type is unsupported
175+
with pytest.raises(AssertionError):
176+
model = VisionTransformer(pretrained=123, init_cfg=123)

0 commit comments

Comments (0)