@@ -9,9 +9,8 @@
 from mmcv.cnn.bricks.transformer import MultiheadAttention
 from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
                                         trunc_normal_init)
-from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint
+from mmcv.runner import BaseModule, ModuleList, Sequential
 
-from ...utils import get_root_logger
 from ..builder import BACKBONES
 from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw
 
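With checkpoint loading delegated to MMCV's `BaseModule`, the backbone no longer needs `_load_checkpoint` or `get_root_logger`. A minimal sketch of how pretrained weights are requested after this change (the `build_backbone` usage and checkpoint path are illustrative, not part of this diff):

```python
from mmseg.models import build_backbone

# The checkpoint path below is a placeholder.
backbone = build_backbone(
    dict(
        type='MixVisionTransformer',
        init_cfg=dict(type='Pretrained', checkpoint='path/to/mit_b0.pth')))
backbone.init_weights()  # BaseModule resolves the 'Pretrained' init_cfg
```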
@@ -341,16 +340,18 @@ def __init__(self,
                  norm_cfg=dict(type='LN', eps=1e-6),
                  pretrained=None,
                  init_cfg=None):
-        super().__init__(init_cfg=init_cfg)
+        super(MixVisionTransformer, self).__init__(init_cfg=init_cfg)
 
-        if isinstance(pretrained, str) or pretrained is None:
-            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be set at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                           'please use "init_cfg" instead')
-        else:
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is not None:
             raise TypeError('pretrained must be a str or None')
 
         self.embed_dims = embed_dims
-
         self.num_stages = num_stages
         self.num_layers = num_layers
         self.num_heads = num_heads
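The shim above keeps the legacy `pretrained` argument working while funneling it into the new `init_cfg` mechanism. The two call styles below should therefore be equivalent (checkpoint path is a placeholder):

```python
from mmseg.models.backbones import MixVisionTransformer

# Legacy style: emits a DeprecationWarning, then is rewritten internally
# into init_cfg=dict(type='Pretrained', checkpoint=...).
model_old = MixVisionTransformer(pretrained='path/to/mit_b0.pth')

# Preferred style after this change.
model_new = MixVisionTransformer(
    init_cfg=dict(type='Pretrained', checkpoint='path/to/mit_b0.pth'))
```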
@@ -362,7 +363,6 @@ def __init__(self,
 
         self.out_indices = out_indices
         assert max(out_indices) < self.num_stages
-        self.pretrained = pretrained
 
         # transformer encoder
         dpr = [
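The `dpr` list is truncated by the hunk; it typically holds one stochastic-depth (drop path) rate per transformer layer, spaced linearly from 0 up to the configured maximum. A sketch with assumed values:

```python
import torch

drop_path_rate = 0.1        # assumed config value
num_layers = (3, 4, 6, 3)   # assumed per-stage depths
# One rate per layer, growing linearly with depth (stochastic depth decay).
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(num_layers))]
```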
@@ -401,7 +401,7 @@ def __init__(self,
             cur += num_layer
 
     def init_weights(self):
-        if self.pretrained is None:
+        if self.init_cfg is None:
             for m in self.modules():
                 if isinstance(m, nn.Linear):
                     trunc_normal_init(m, std=.02, bias=0.)
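For context, `trunc_normal_init` in the default branch draws Linear weights from a truncated normal and zeroes the bias. A quick standalone check:

```python
import torch.nn as nn
from mmcv.cnn.utils.weight_init import trunc_normal_init

fc = nn.Linear(64, 64)
trunc_normal_init(fc, std=.02, bias=0.)
# Weights are tightly clustered around 0; bias is exactly zero.
assert fc.weight.std().item() < 0.03
assert fc.bias.abs().sum().item() == 0
```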
@@ -413,16 +413,8 @@ def init_weights(self):
                     fan_out //= m.groups
                     normal_init(
                         m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
-        elif isinstance(self.pretrained, str):
-            logger = get_root_logger()
-            checkpoint = _load_checkpoint(
-                self.pretrained, logger=logger, map_location='cpu')
-            if 'state_dict' in checkpoint:
-                state_dict = checkpoint['state_dict']
-            else:
-                state_dict = checkpoint
-
-            self.load_state_dict(state_dict, False)
+        else:
+            super(MixVisionTransformer, self).init_weights()
 
     def forward(self, x):
         outs = []
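The Conv2d branch kept above is Kaiming-style fan-out initialization: fan_out = kh * kw * out_channels / groups, std = sqrt(2 / fan_out). A standalone sketch of the same computation outside the backbone:

```python
import math

import torch.nn as nn
from mmcv.cnn.utils.weight_init import normal_init

conv = nn.Conv2d(32, 64, kernel_size=3)
fan_out = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels
fan_out //= conv.groups
normal_init(conv, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
```

Delegating the pretrained case to `super().init_weights()` means `BaseModule` handles the `Pretrained` init_cfg, which is why the hand-rolled `_load_checkpoint` path could be deleted.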