 import torch.nn.functional as F
 import torch.utils.checkpoint as cp
 from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
-                      constant_init, kaiming_init, normal_init, xavier_init)
+                      constant_init, kaiming_init, normal_init)
 from mmcv.runner import _load_checkpoint
 from mmcv.utils.parrots_wrapper import _BatchNorm
 
 from mmseg.utils import get_root_logger
 from ..builder import BACKBONES
+from ..utils import DropPath, trunc_normal_
 
 
 class Mlp(nn.Module):
@@ -114,10 +115,14 @@ class Block(nn.Module):
             Default: 0.
         proj_drop (float): Drop rate for attn layer output weights.
             Default: 0.
+        drop_path (float): Drop rate for paths of model.
+            Default: 0.
         act_cfg (dict): Config dict for activation layer.
             Default: dict(type='GELU').
         norm_cfg (dict): Config dict for normalization layer.
             Default: dict(type='LN', requires_grad=True).
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
     """
 
     def __init__(self,
@@ -129,14 +134,17 @@ def __init__(self,
                  drop=0.,
                  attn_drop=0.,
                  proj_drop=0.,
+                 drop_path=0.,
                  act_cfg=dict(type='GELU'),
-                 norm_cfg=dict(type='LN'),
+                 norm_cfg=dict(type='LN', eps=1e-6),
                  with_cp=False):
         super(Block, self).__init__()
         self.with_cp = with_cp
         _, self.norm1 = build_norm_layer(norm_cfg, dim)
         self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop,
                               proj_drop)
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
         _, self.norm2 = build_norm_layer(norm_cfg, dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(
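
The norm_cfg change above switches the block's LayerNorm epsilon from PyTorch's default 1e-5 to 1e-6, matching the original ViT/timm implementation. Since mmcv's build_norm_layer forwards extra cfg keys to the layer constructor, the effect can be checked with a quick sketch (values are illustrative):

    from mmcv.cnn import build_norm_layer

    # extra keys in norm_cfg (here eps) are passed through to nn.LayerNorm
    name, ln = build_norm_layer(dict(type='LN', eps=1e-6), 768)
    print(type(ln).__name__, ln.eps)  # LayerNorm 1e-06
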
@@ -148,8 +156,8 @@ def __init__(self,
     def forward(self, x):
 
         def _inner_forward(x):
-            out = x + self.attn(self.norm1(x))
-            out = out + self.mlp(self.norm2(out))
+            out = x + self.drop_path(self.attn(self.norm1(x)))
+            out = out + self.drop_path(self.mlp(self.norm2(out)))
             return out
 
         if self.with_cp and x.requires_grad:
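
DropPath is imported from ..utils and its implementation is not part of this diff. For readers unfamiliar with stochastic depth: the residual branches above are randomly zeroed per sample during training and rescaled otherwise, roughly along these lines (a sketch, not the mmseg code):

    import torch
    import torch.nn as nn

    class DropPathSketch(nn.Module):
        """Illustrative stochastic depth: drop whole residual branches per sample."""

        def __init__(self, drop_prob=0.):
            super().__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            if self.drop_prob == 0. or not self.training:
                return x
            keep_prob = 1 - self.drop_prob
            # one Bernoulli mask value per sample, broadcast over token/channel dims
            shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
            mask = x.new_empty(shape).bernoulli_(keep_prob)
            return x.div(keep_prob) * mask
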
@@ -164,7 +172,7 @@ class PatchEmbed(nn.Module):
     """Image to Patch Embedding.
 
     Args:
-        img_size (int, tuple): Input image size.
+        img_size (int | tuple): Input image size.
             default: 224.
         patch_size (int): Width and height for a patch.
             default: 16.
@@ -202,24 +210,34 @@ class VisionTransformer(nn.Module):
         Image Recognition at Scale` - https://arxiv.org/abs/2010.11929
 
     Args:
-        img_size (tuple): input image size. Default: (224, 224) .
+        img_size (tuple): input image size. Default: (224, 224).
         patch_size (int, tuple): patch size. Default: 16.
         in_channels (int): number of input channels. Default: 3.
         embed_dim (int): embedding dimension. Default: 768.
         depth (int): depth of transformer. Default: 12.
         num_heads (int): number of attention heads. Default: 12.
-        mlp_ratio (int): ratio of mlp hidden dim to embedding dim. Default: 4.
+        mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
+            Default: 4.
+        out_indices (list | tuple | int): Output from which stages.
+            Default: 11.
         qkv_bias (bool): enable bias for qkv if True. Default: True.
         qk_scale (float): override default qk scale of head_dim ** -0.5 if set.
         drop_rate (float): dropout rate. Default: 0.
         attn_drop_rate (float): attention dropout rate. Default: 0.
+        drop_path_rate (float): Rate of DropPath. Default: 0.
         norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='LN', requires_grad=True).
+            Default: dict(type='LN', eps=1e-6, requires_grad=True).
         act_cfg (dict): Config dict for activation layer.
             Default: dict(type='GELU').
         norm_eval (bool): Whether to set norm layers to eval mode, namely,
             freeze running stats (mean and var). Note: Effect on Batch Norm
             and its variants only. Default: False.
+        final_norm (bool): Whether to add an additional layer to normalize
+            final feature map. Default: False.
+        interpolate_mode (str): Select the interpolate mode for position
+            embedding vector resize. Default: bicubic.
+        with_cls_token (bool): Whether to concatenate class token into image
+            tokens as transformer input. Default: True.
         with_cp (bool): Use checkpoint or not. Using checkpoint
             will save some memory while slowing down the training speed.
             Default: False.
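
For context, these new arguments are typically driven from a config dict; an illustrative backbone entry (values chosen as an example, not defaults from this commit) could look like:

    # hypothetical config fragment for an mmseg model dict
    backbone = dict(
        type='VisionTransformer',
        img_size=(512, 512),
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        out_indices=(2, 5, 8, 11),  # take features from several blocks
        drop_path_rate=0.1,
        with_cls_token=True,
        final_norm=True,
        interpolate_mode='bicubic')
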
@@ -233,13 +251,18 @@ def __init__(self,
                  depth=12,
                  num_heads=12,
                  mlp_ratio=4,
+                 out_indices=11,
                  qkv_bias=True,
                  qk_scale=None,
                  drop_rate=0.,
                  attn_drop_rate=0.,
-                 norm_cfg=dict(type='LN'),
+                 drop_path_rate=0.,
+                 norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True),
                  act_cfg=dict(type='GELU'),
                  norm_eval=False,
+                 final_norm=False,
+                 with_cls_token=True,
+                 interpolate_mode='bicubic',
                  with_cp=False):
         super(VisionTransformer, self).__init__()
         self.img_size = img_size
@@ -251,24 +274,39 @@ def __init__(self,
             in_channels=in_channels,
             embed_dim=embed_dim)
 
+        self.with_cls_token = with_cls_token
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
         self.pos_embed = nn.Parameter(
-            torch.zeros(1, self.patch_embed.num_patches, embed_dim))
+            torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
         self.pos_drop = nn.Dropout(p=drop_rate)
 
-        self.blocks = nn.Sequential(*[
+        if isinstance(out_indices, int):
+            self.out_indices = [out_indices]
+        elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
+            self.out_indices = out_indices
+        else:
+            raise TypeError('out_indices must be type of int, list or tuple')
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
+               ]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
             Block(
                 dim=embed_dim,
                 num_heads=num_heads,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 qk_scale=qk_scale,
-                drop=drop_rate,
+                drop=dpr[i],
                 attn_drop=attn_drop_rate,
                 act_cfg=act_cfg,
                 norm_cfg=norm_cfg,
                 with_cp=with_cp) for i in range(depth)
         ])
-        _, self.norm = build_norm_layer(norm_cfg, embed_dim)
+
+        self.interpolate_mode = interpolate_mode
+        self.final_norm = final_norm
+        if final_norm:
+            _, self.norm = build_norm_layer(norm_cfg, embed_dim)
 
         self.norm_eval = norm_eval
         self.with_cp = with_cp
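
The linspace above implements the stochastic depth decay rule: the drop probability grows linearly from 0 in the first block to drop_path_rate in the last. For example, with depth and rate picked for illustration:

    import torch

    depth, drop_path_rate = 12, 0.1
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    print([round(v, 3) for v in dpr])
    # [0.0, 0.009, 0.018, 0.027, 0.036, 0.045, 0.055, 0.064, 0.073, 0.082, 0.091, 0.1]
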
@@ -283,28 +321,26 @@ def init_weights(self, pretrained=None):
                 state_dict = checkpoint
 
             if 'pos_embed' in state_dict.keys():
-                state_dict['pos_embed'] = state_dict['pos_embed'][:, 1:, :]
-                logger.info(
-                    msg='Remove the "cls_token" dimension from the checkpoint')
-
                 if self.pos_embed.shape != state_dict['pos_embed'].shape:
                     logger.info(msg=f'Resize the pos_embed shape from \
-{state_dict["pos_embed"].shape} to \
-{self.pos_embed.shape}')
+{state_dict["pos_embed"].shape} to {self.pos_embed.shape}')
                     h, w = self.img_size
-                    pos_size = int(math.sqrt(state_dict['pos_embed'].shape[1]))
+                    pos_size = int(
+                        math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                     state_dict['pos_embed'] = self.resize_pos_embed(
                         state_dict['pos_embed'], (h, w), (pos_size, pos_size),
-                        self.patch_size)
+                        self.patch_size, self.interpolate_mode)
+
             self.load_state_dict(state_dict, False)
 
         elif pretrained is None:
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
-            normal_init(self.pos_embed)
+            trunc_normal_(self.pos_embed, std=.02)
+            trunc_normal_(self.cls_token, std=.02)
             for n, m in self.named_modules():
                 if isinstance(m, Linear):
-                    xavier_init(m.weight, distribution='uniform')
+                    trunc_normal_(m.weight, std=.02)
                     if m.bias is not None:
                         if 'mlp' in n:
                             normal_init(m.bias, std=1e-6)
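
trunc_normal_ also lives in ..utils and follows the timm/JAX ViT initialization: weights are drawn from a normal distribution with values far in the tails rejected. A rough sketch of that intent (not the mmseg implementation, which uses a different sampling scheme):

    import torch

    def trunc_normal_sketch(tensor, mean=0., std=0.02):
        # illustrative only: redraw samples that land outside mean +/- 2*std
        with torch.no_grad():
            tensor.normal_(mean, std)
            bad = (tensor - mean).abs() > 2 * std
            while bad.any():
                tensor[bad] = torch.empty(
                    int(bad.sum()), device=tensor.device,
                    dtype=tensor.dtype).normal_(mean, std)
                bad = (tensor - mean).abs() > 2 * std
        return tensor
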
@@ -316,7 +352,7 @@ def init_weights(self, pretrained=None):
                             constant_init(m.bias, 0)
                 elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                     constant_init(m.bias, 0)
-                    constant_init(m.weight, 1)
+                    constant_init(m.weight, 1.0)
         else:
             raise TypeError('pretrained must be a str or None')
 
@@ -340,19 +376,20 @@ def _pos_embeding(self, img, patched_img, pos_embed):
         x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
         if x_len != pos_len:
             if pos_len == (self.img_size[0] // self.patch_size) * (
-                    self.img_size[1] // self.patch_size):
+                    self.img_size[1] // self.patch_size) + 1:
                 pos_h = self.img_size[0] // self.patch_size
                 pos_w = self.img_size[1] // self.patch_size
             else:
                 raise ValueError(
                     'Unexpected shape of pos_embed, got {}.'.format(
                         pos_embed.shape))
             pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
-                                              (pos_h, pos_w), self.patch_size)
-        return patched_img + pos_embed
+                                              (pos_h, pos_w), self.patch_size,
+                                              self.interpolate_mode)
+        return self.pos_drop(patched_img + pos_embed)
 
     @staticmethod
-    def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size):
+    def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
         """Resize pos_embed weights.
 
         Resize pos_embed using bicubic interpolate method.
@@ -367,26 +404,52 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size):
         assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
         input_h, input_w = input_shpae
         pos_h, pos_w = pos_shape
-        pos_embed = pos_embed.reshape(1, pos_h, pos_w,
-                                      pos_embed.shape[2]).permute(0, 3, 1, 2)
-        pos_embed = F.interpolate(
-            pos_embed,
+        cls_token_weight = pos_embed[:, 0]
+        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
+        pos_embed_weight = pos_embed_weight.reshape(
+            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
+        pos_embed_weight = F.interpolate(
+            pos_embed_weight,
             size=[input_h // patch_size, input_w // patch_size],
             align_corners=False,
-            mode='bicubic')
-        pos_embed = torch.flatten(pos_embed, 2).transpose(1, 2)
+            mode=mode)
+        cls_token_weight = cls_token_weight.unsqueeze(1)
+        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
+        pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
         return pos_embed
 
     def forward(self, inputs):
+        B = inputs.shape[0]
+
         x = self.patch_embed(inputs)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
         x = self._pos_embeding(inputs, x, self.pos_embed)
-        x = self.blocks(x)
-        x = self.norm(x)
-        B, _, C = x.shape
-        x = x.reshape(B, inputs.shape[2] // self.patch_size,
-                      inputs.shape[3] // self.patch_size,
-                      C).permute(0, 3, 1, 2)
-        return [x]
+
+        if not self.with_cls_token:
+            # Remove class token for transformer input
+            x = x[:, 1:]
+
+        outs = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i == len(self.blocks) - 1:
+                if self.final_norm:
+                    x = self.norm(x)
+            if i in self.out_indices:
+                if self.with_cls_token:
+                    # Remove class token and reshape token for decoder head
+                    out = x[:, 1:]
+                else:
+                    out = x
+                B, _, C = out.shape
+                out = out.reshape(B, inputs.shape[2] // self.patch_size,
+                                  inputs.shape[3] // self.patch_size,
+                                  C).permute(0, 3, 1, 2)
+                outs.append(out)
+
+        return tuple(outs)
 
     def train(self, mode=True):
         super(VisionTransformer, self).train(mode)
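
With out_indices and the tuple return value, the backbone now plugs into decode heads like the other mmseg backbones. A quick shape check, assuming the class is exposed as below (import path and argument values are illustrative):

    import torch
    from mmseg.models.backbones import VisionTransformer

    model = VisionTransformer(
        img_size=(224, 224), patch_size=16, embed_dim=768, depth=12,
        num_heads=12, out_indices=(8, 11), drop_path_rate=0.1)
    model.init_weights()
    feats = model(torch.randn(2, 3, 224, 224))
    # one (B, embed_dim, H // 16, W // 16) map per index in out_indices
    print([tuple(f.shape) for f in feats])  # [(2, 768, 14, 14), (2, 768, 14, 14)]
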