44
55import torch
66import torch .nn as nn
7- from mmcv .cnn import (Conv2d , build_activation_layer , build_norm_layer ,
8- constant_init , normal_init , trunc_normal_init )
7+ from mmcv .cnn import Conv2d , build_activation_layer , build_norm_layer
98from mmcv .cnn .bricks .drop import build_dropout
109from mmcv .cnn .bricks .transformer import MultiheadAttention
10+ from mmcv .cnn .utils .weight_init import (constant_init , normal_init ,
11+ trunc_normal_init )
1112from mmcv .runner import BaseModule , ModuleList , Sequential , _load_checkpoint
1213
1314from ...utils import get_root_logger
@@ -343,7 +344,7 @@ def __init__(self,
343344 norm_cfg = dict (type = 'LN' , eps = 1e-6 ),
344345 pretrained = None ,
345346 init_cfg = None ):
346- super ().__init__ ()
347+ super ().__init__ (init_cfg = init_cfg )
347348
348349 if isinstance (pretrained , str ) or pretrained is None :
349350 warnings .warn ('DeprecationWarning: pretrained is a deprecated, '
@@ -365,7 +366,6 @@ def __init__(self,
365366 self .out_indices = out_indices
366367 assert max (out_indices ) < self .num_stages
367368 self .pretrained = pretrained
368- self .init_cfg = init_cfg
369369
370370 # transformer encoder
371371 dpr = [
@@ -407,19 +407,15 @@ def init_weights(self):
407407 if self .pretrained is None :
408408 for m in self .modules ():
409409 if isinstance (m , nn .Linear ):
410- trunc_normal_init (m .weight , std = .02 )
411- if m .bias is not None :
412- constant_init (m .bias , 0 )
410+ trunc_normal_init (m , std = .02 , bias = 0. )
413411 elif isinstance (m , nn .LayerNorm ):
414- constant_init (m .bias , 0 )
415- constant_init (m .weight , 1.0 )
412+ constant_init (m , val = 1.0 , bias = 0. )
416413 elif isinstance (m , nn .Conv2d ):
417414 fan_out = m .kernel_size [0 ] * m .kernel_size [
418415 1 ] * m .out_channels
419416 fan_out //= m .groups
420- normal_init (m .weight , 0 , math .sqrt (2.0 / fan_out ))
421- if m .bias is not None :
422- constant_init (m .bias , 0 )
417+ normal_init (
418+ m , mean = 0 , std = math .sqrt (2.0 / fan_out ), bias = 0 )
423419 elif isinstance (self .pretrained , str ):
424420 logger = get_root_logger ()
425421 checkpoint = _load_checkpoint (
0 commit comments