 import math
+import warnings
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
                       kaiming_init, normal_init, trunc_normal_init)
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
-from mmcv.runner import _load_checkpoint
-from mmcv.runner.base_module import BaseModule, ModuleList
+from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple
 
@@ -140,12 +140,6 @@ def __init__(self,
             self.norm = None
 
     def forward(self, x):
-        B, C, H, W = x.shape
-        # FIXME look at relaxing size constraints
-        # assert H == self.img_size[0] and W == self.img_size[1], \
-        #     f"Input image size ({H}*{W}) doesn't " \
-        #     f'match model ({self.img_size[0]}*{self.img_size[1]}).'
-        # The output size is (B, N, D), where N=H*W/P/P, D is embid_dim
         x = self.projection(x).flatten(2).transpose(1, 2)
 
         if self.norm is not None:
@@ -185,8 +179,12 @@ class VisionTransformer(BaseModule):
             Default: dict(type='LN')
         act_cfg (dict): The activation config for FFNs.
             Defalut: dict(type='GELU').
-        final_norm (bool): Whether to add a additional layer to normalize
+        patch_norm (bool): Whether to add a norm layer in the PatchEmbed
+            block. Default: False.
+        final_norm (bool): Whether to add an additional layer to normalize
             final feature map. Default: False.
+        out_shape (str): Select the output shape of the feature maps,
+            'NLC' or 'NCHW'. Default: 'NCHW'.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Default: bicubic.
         num_fcs (int): The number of fully-connected layers for FFNs.
@@ -198,6 +196,9 @@ class VisionTransformer(BaseModule):
             some memory while slowing down the training speed. Default: False.
         pretrain_style (str): Choose to use timm or mmcls pretrain weights.
             Default: timm.
+        pretrained (str, optional): Model pretrained path. Default: None.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
     """
 
     def __init__(self,
@@ -216,12 +217,16 @@ def __init__(self,
                  with_cls_token=True,
                  norm_cfg=dict(type='LN'),
                  act_cfg=dict(type='GELU'),
+                 patch_norm=False,
                  final_norm=False,
+                 out_shape='NCHW',
                  interpolate_mode='bicubic',
                  num_fcs=2,
                  norm_eval=False,
                  with_cp=False,
-                 pretrain_style='timm'):
+                 pretrain_style='timm',
+                 pretrained=None,
+                 init_cfg=None):
         super(VisionTransformer, self).__init__()
 
         if isinstance(img_size, int):
@@ -235,16 +240,32 @@ def __init__(self,
 
         assert pretrain_style in ['timm', 'mmcls']
 
-        self.pretrain_style = pretrain_style
+        assert out_shape in ['NLC',
+                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
+
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
+                          'please use "init_cfg" instead')
+        elif pretrained is not None:
+            raise TypeError('pretrained must be a str or None')
+
         self.img_size = img_size
         self.patch_size = patch_size
+        self.out_shape = out_shape
+        self.interpolate_mode = interpolate_mode
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+        self.pretrain_style = pretrain_style
+        self.pretrained = pretrained
+        self.init_cfg = init_cfg
 
         self.patch_embed = PatchEmbed(
             img_size=img_size,
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=embed_dims,
-            norm_cfg=norm_cfg)
+            norm_cfg=norm_cfg if patch_norm else None)
+
         num_patches = self.patch_embed.num_patches
 
         self.with_cls_token = with_cls_token
@@ -280,24 +301,20 @@ def __init__(self,
                     norm_cfg=norm_cfg,
                     batch_first=True))
 
-        self.interpolate_mode = interpolate_mode
         self.final_norm = final_norm
         if final_norm:
             self.norm1_name, norm1 = build_norm_layer(
                 norm_cfg, embed_dims, postfix=1)
             self.add_module(self.norm1_name, norm1)
 
-        self.norm_eval = norm_eval
-        self.with_cp = with_cp
-
     @property
     def norm1(self):
         return getattr(self, self.norm1_name)
 
-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str):
+    def init_weights(self):
+        if isinstance(self.pretrained, str):
             logger = get_root_logger()
-            checkpoint = _load_checkpoint(pretrained, logger=logger)
+            checkpoint = _load_checkpoint(self.pretrained, logger=logger)
             if 'state_dict' in checkpoint:
                 state_dict = checkpoint['state_dict']
             elif 'model' in checkpoint:
@@ -325,7 +342,8 @@ def init_weights(self, pretrained=None):
 
             self.load_state_dict(state_dict, False)
 
-        elif pretrained is None:
+        elif self.pretrained is None:
+            super(VisionTransformer, self).init_weights()
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
             trunc_normal_init(self.pos_embed, std=.02)
@@ -345,8 +363,6 @@ def init_weights(self, pretrained=None):
                 elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                     constant_init(m.bias, 0)
                     constant_init(m.weight, 1.0)
-        else:
-            raise TypeError('pretrained must be a str or None')
 
     def _pos_embeding(self, img, patched_img, pos_embed):
         """Positiong embeding method.
@@ -436,10 +452,11 @@ def forward(self, inputs):
                     out = x[:, 1:]
                 else:
                     out = x
-                B, _, C = out.shape
-                out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                                  inputs.shape[3] // self.patch_size,
-                                  C).permute(0, 3, 1, 2)
+                if self.out_shape == 'NCHW':
+                    B, _, C = out.shape
+                    out = out.reshape(B, inputs.shape[2] // self.patch_size,
+                                      inputs.shape[3] // self.patch_size,
+                                      C).permute(0, 3, 1, 2)
                 outs.append(out)
 
         return tuple(outs)
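
A minimal usage sketch of the new options, assuming this diff is applied to mmseg.models.backbones.vit and that VisionTransformer is exported from mmseg.models.backbones; the defaults referenced in the comments (embed_dims=768, with_cls_token=True, out_indices=11) come from the full file, not from this diff:

import torch

from mmseg.models.backbones import VisionTransformer

# patch_norm=True applies norm_cfg inside PatchEmbed; out_shape selects the
# layout of each feature map in the returned tuple. Passing a str to
# `pretrained` still loads the checkpoint in init_weights(), but is
# deprecated in favour of `init_cfg`.
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    patch_norm=True,
    out_shape='NLC',
    pretrained=None,
    init_cfg=None)
model.init_weights()  # no longer accepts a `pretrained` argument

x = torch.randn(1, 3, 224, 224)
outs = model(x)
# out_shape='NLC'  -> outs[0].shape == (1, 196, 768)     i.e. (B, H*W/P**2, C)
# out_shape='NCHW' -> outs[0].shape == (1, 768, 14, 14)  i.e. (B, C, H/P, W/P)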