[Fix] Fix the bug that vit cannot load pretrain properly when using i… #999

Merged (9 commits, Nov 3, 2021)
[Fix] Modified the judgement logic
RockeyCoss committed Oct 28, 2021
commit b7aaa6a5972b4e0fb8930047fca0b7be07e2d79c
19 changes: 6 additions & 13 deletions mmseg/models/backbones/vit.py
@@ -262,18 +262,11 @@ def norm1(self):
         return getattr(self, self.norm1_name)
 
     def init_weights(self):
-        if (isinstance(self.pretrained, str)
-                or (isinstance(self.init_cfg, dict)
-                    and self.init_cfg.get('type') == 'Pretrained')):
+        if (isinstance(self.init_cfg, dict)
+                and self.init_cfg.get('type') == 'Pretrained'):
             logger = get_root_logger()
-            if self.pretrained:
-                checkpoint = _load_checkpoint(
-                    self.pretrained, logger=logger, map_location='cpu')
-            else:
-                checkpoint = _load_checkpoint(
-                    self.init_cfg['checkpoint'],
-                    logger=logger,
-                    map_location='cpu')
+            checkpoint = _load_checkpoint(
+                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
 
             if 'state_dict' in checkpoint:
                 state_dict = checkpoint['state_dict']
@@ -294,9 +287,9 @@ def init_weights(self):
                         (pos_size, pos_size), self.interpolate_mode)
 
             self.load_state_dict(state_dict, False)
-
-        elif self.pretrained is None:
+        elif self.init_cfg is not None:
             super(VisionTransformer, self).init_weights()
         else:
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
             trunc_normal_init(self.pos_embed, std=.02)
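With this change, init_weights branches only on init_cfg: a dict with type='Pretrained' loads the checkpoint named in init_cfg['checkpoint'], any other non-None init_cfg defers to the parent class's init_weights, and a missing init_cfg falls through to the timm-style 'jax_impl' initialization. Below is a minimal usage sketch of the fixed 'Pretrained' path; the img_size, patch_size, and checkpoint path are illustrative placeholders, not values taken from this PR.

# Minimal sketch of the fixed judgement logic in use. The checkpoint path
# is a hypothetical placeholder; any mmcv-loadable checkpoint would do.
from mmseg.models.backbones import VisionTransformer

backbone = VisionTransformer(
    img_size=(224, 224),
    patch_size=16,
    init_cfg=dict(
        type='Pretrained',
        checkpoint='pretrain/vit_base_patch16_224.pth'))  # hypothetical path

# With the fix, this takes the 'Pretrained' branch above and calls
# _load_checkpoint on init_cfg['checkpoint'], resizing pos_embed if the
# pretrained position embedding shape does not match the model's.
backbone.init_weights()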