
Commit 66b778c

[Improve] Use MMCV load_state_dict func in ViT/Swin (open-mmlab#1272)
* [Improve] Use MMCV load_state_dict func in ViT/Swin
* use CheckpointLoader instead
1 parent b4314f9 commit 66b778c

2 files changed: 8 additions and 6 deletions
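Both backbones make the same two swaps: the private `_load_checkpoint` helper becomes the public `CheckpointLoader.load_checkpoint`, and `nn.Module.load_state_dict` becomes mmcv's `load_state_dict`, which reports key mismatches through the logger. A minimal sketch of the resulting loading pattern, assuming an mmcv install with `mmcv.runner`; the toy module and checkpoint file are placeholders, not part of this commit:

```python
import torch
import torch.nn as nn
from mmcv.runner import CheckpointLoader, load_state_dict
from mmcv.utils import get_logger

logger = get_logger('mmseg')
model = nn.Linear(4, 4)  # stand-in for a ViT/Swin backbone

# Save a toy checkpoint so the sketch is self-contained.
torch.save({'state_dict': model.state_dict()}, 'toy_ckpt.pth')

# CheckpointLoader picks a loader by scheme (local path, http(s)://,
# open-mmlab://, ...) and returns the raw checkpoint dict.
ckpt = CheckpointLoader.load_checkpoint(
    'toy_ckpt.pth', logger=logger, map_location='cpu')
state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt

# mmcv's load_state_dict logs missing/unexpected keys instead of
# silently returning them like nn.Module.load_state_dict(strict=False).
load_state_dict(model, state_dict, strict=False, logger=logger)
```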

mmseg/models/backbones/swin.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -11,7 +11,8 @@
 from mmcv.cnn.bricks.transformer import FFN, build_dropout
 from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
                                         trunc_normal_init)
-from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
+from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
+                         load_state_dict)
 from mmcv.utils import to_2tuple
 
 from ...utils import get_root_logger
@@ -678,7 +679,7 @@ def init_weights(self):
                                                   f'specify `Pretrained` in ' \
                                                   f'`init_cfg` in ' \
                                                   f'{self.__class__.__name__} '
-            ckpt = _load_checkpoint(
+            ckpt = CheckpointLoader.load_checkpoint(
                 self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
             if 'state_dict' in ckpt:
                 _state_dict = ckpt['state_dict']
@@ -732,7 +733,7 @@ def init_weights(self):
                         nH2, L2).permute(1, 0).contiguous()
 
             # load state_dict
-            self.load_state_dict(state_dict, False)
+            load_state_dict(self, state_dict, strict=False, logger=logger)
 
     def forward(self, x):
         x, hw_shape = self.patch_embed(x)
```
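Why the last hunk matters: torch's `load_state_dict(..., strict=False)` returns its missing/unexpected-key report, which the old call silently discarded, whereas mmcv's `load_state_dict` writes it to the logger. A toy comparison (illustrative module and keys, not from the commit):

```python
import torch.nn as nn
from mmcv.runner import load_state_dict

model = nn.Sequential(nn.Linear(2, 2))
partial = {'0.weight': model[0].weight.data.clone()}  # no '0.bias'

# torch: strict=False just returns the mismatch report...
result = model.load_state_dict(partial, strict=False)
print(result.missing_keys)  # ['0.bias'] -- easy to drop on the floor

# mmcv: same non-strict load, but mismatches are logged (or printed
# when no logger is given), so they surface in training logs.
load_state_dict(model, partial, strict=False)
```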

mmseg/models/backbones/vit.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -8,7 +8,8 @@
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
 from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
                                         trunc_normal_)
-from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
+from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
+                         load_state_dict)
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
 
@@ -266,7 +267,7 @@ def init_weights(self):
         if (isinstance(self.init_cfg, dict)
                 and self.init_cfg.get('type') == 'Pretrained'):
             logger = get_root_logger()
-            checkpoint = _load_checkpoint(
+            checkpoint = CheckpointLoader.load_checkpoint(
                 self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
 
             if 'state_dict' in checkpoint:
@@ -287,7 +288,7 @@ def init_weights(self):
                     (h // self.patch_size, w // self.patch_size),
                     (pos_size, pos_size), self.interpolate_mode)
 
-            self.load_state_dict(state_dict, False)
+            load_state_dict(self, state_dict, strict=False, logger=logger)
         elif self.init_cfg is not None:
             super(VisionTransformer, self).init_weights()
         else:
```
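For context, the unchanged lines in the last hunk resize a pretrained `pos_embed` grid before the non-strict load. A minimal sketch of that kind of resize, assuming a `[1, 1 + h*w, C]` layout with a leading cls token; this is an illustration, not the exact `resize_pos_embed` from mmseg:

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, input_shape, pos_shape, mode='bicubic'):
    """Interpolate a [1, 1 + h*w, C] position embedding to a new grid."""
    cls_token, grid = pos_embed[:, :1], pos_embed[:, 1:]
    h, w = pos_shape
    # [1, h*w, C] -> [1, C, h, w] so F.interpolate can resize spatially.
    grid = grid.reshape(1, h, w, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=input_shape, mode=mode, align_corners=False)
    # Back to [1, H*W, C] and re-attach the cls token.
    grid = grid.flatten(2).transpose(1, 2)
    return torch.cat((cls_token, grid), dim=1)

# e.g. a 24x24 pretrained grid resized for 512x512 inputs with patch size 16:
pe = torch.randn(1, 1 + 24 * 24, 768)
print(resize_pos_embed(pe, (32, 32), (24, 24)).shape)  # [1, 1025, 768]
```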
