8
8
from mmcv .cnn .bricks .transformer import FFN , MultiheadAttention
9
9
from mmcv .cnn .utils .weight_init import (constant_init , kaiming_init ,
10
10
trunc_normal_ )
11
- from mmcv .runner import BaseModule , ModuleList , _load_checkpoint
11
+ from mmcv .runner import (BaseModule , CheckpointLoader , ModuleList ,
12
+ load_state_dict )
12
13
from torch .nn .modules .batchnorm import _BatchNorm
13
14
from torch .nn .modules .utils import _pair as to_2tuple
14
15
@@ -266,7 +267,7 @@ def init_weights(self):
266
267
if (isinstance (self .init_cfg , dict )
267
268
and self .init_cfg .get ('type' ) == 'Pretrained' ):
268
269
logger = get_root_logger ()
269
- checkpoint = _load_checkpoint (
270
+ checkpoint = CheckpointLoader . load_checkpoint (
270
271
self .init_cfg ['checkpoint' ], logger = logger , map_location = 'cpu' )
271
272
272
273
if 'state_dict' in checkpoint :
@@ -287,7 +288,7 @@ def init_weights(self):
287
288
(h // self .patch_size , w // self .patch_size ),
288
289
(pos_size , pos_size ), self .interpolate_mode )
289
290
290
- self . load_state_dict (state_dict , False )
291
+ load_state_dict (self , state_dict , strict = False , logger = logger )
291
292
elif self .init_cfg is not None :
292
293
super (VisionTransformer , self ).init_weights ()
293
294
else :
0 commit comments