Skip to content

Commit 10886b0

Browse files
authored
fix load ckpt bug in swin (open-mmlab#928)
1 parent c1dcf91 commit 10886b0

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

mmseg/models/backbones/swin.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ def init_weights(self):
680680
f'`init_cfg` in ' \
681681
f'{self.__class__.__name__} '
682682
ckpt = _load_checkpoint(
683-
self.init_cfg.checkpoint, logger=logger, map_location='cpu')
683+
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
684684
if 'state_dict' in ckpt:
685685
_state_dict = ckpt['state_dict']
686686
elif 'model' in ckpt:
@@ -692,6 +692,8 @@ def init_weights(self):
692692
for k, v in _state_dict.items():
693693
if k.startswith('backbone.'):
694694
state_dict[k[9:]] = v
695+
else:
696+
state_dict[k] = v
695697

696698
# strip prefix of state_dict
697699
if list(state_dict.keys())[0].startswith('module.'):

tools/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def main():
9696
else:
9797
distributed = True
9898
init_dist(args.launcher, **cfg.dist_params)
99-
# gpu_ids is used to calculate iter when resuming checkpoint,
99+
# gpu_ids is used to calculate iter when resuming checkpoint
100100
_, world_size = get_dist_info()
101101
cfg.gpu_ids = range(world_size)
102102

0 commit comments

Comments
 (0)