 import torch.nn.functional as F
 import torch.utils.checkpoint as cp
 from mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
-                      constant_init, kaiming_init, normal_init, xavier_init)
+                      constant_init, kaiming_init, normal_init)
 from mmcv.runner import _load_checkpoint
 from mmcv.utils.parrots_wrapper import _BatchNorm

 from mmseg.utils import get_root_logger
 from ..builder import BACKBONES
+from ..utils import DropPath, trunc_normal_


 class Mlp(nn.Module):
@@ -114,10 +115,14 @@ class Block(nn.Module):
             Default: 0.
         proj_drop (float): Drop rate for attn layer output weights.
             Default: 0.
+        drop_path (float): Drop path rate for stochastic depth.
+            Default: 0.
         act_cfg (dict): Config dict for activation layer.
             Default: dict(type='GELU').
         norm_cfg (dict): Config dict for normalization layer.
             Default: dict(type='LN', requires_grad=True).
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
     """

     def __init__(self,
@@ -129,14 +134,17 @@ def __init__(self,
                  drop=0.,
                  attn_drop=0.,
                  proj_drop=0.,
+                 drop_path=0.,
                  act_cfg=dict(type='GELU'),
-                 norm_cfg=dict(type='LN'),
+                 norm_cfg=dict(type='LN', eps=1e-6),
                  with_cp=False):
         super(Block, self).__init__()
         self.with_cp = with_cp
         _, self.norm1 = build_norm_layer(norm_cfg, dim)
         self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop,
                               proj_drop)
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
         _, self.norm2 = build_norm_layer(norm_cfg, dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(
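Note: `DropPath` is imported from `..utils` above. For readers who have not seen it, the sketch below shows the usual stochastic-depth behaviour such a module implements (drop the whole residual branch for a random subset of samples during training and rescale the survivors). It is an illustrative assumption, not necessarily mmseg's exact implementation.

```python
import torch
import torch.nn as nn


class DropPathSketch(nn.Module):
    """Stochastic depth: zero out the residual branch for a random
    subset of samples and rescale the survivors by 1 / keep_prob."""

    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One mask value per sample, broadcast over the remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        return x * mask / keep_prob
```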
@@ -148,8 +156,8 @@ def __init__(self,
     def forward(self, x):

         def _inner_forward(x):
-            out = x + self.attn(self.norm1(x))
-            out = out + self.mlp(self.norm2(out))
+            out = x + self.drop_path(self.attn(self.norm1(x)))
+            out = out + self.drop_path(self.mlp(self.norm2(out)))
             return out

         if self.with_cp and x.requires_grad:
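The hunk is truncated at the `with_cp` check. For context, the toy block below (an illustration, not code from this PR) shows how such a flag typically gates `torch.utils.checkpoint`, trading recomputation in the backward pass for activation memory:

```python
import torch.nn as nn
import torch.utils.checkpoint as cp


class CheckpointedBlock(nn.Module):
    """Toy residual block showing the with_cp pattern used above."""

    def __init__(self, dim, with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU())

    def forward(self, x):

        def _inner_forward(x):
            return x + self.mlp(x)

        if self.with_cp and x.requires_grad:
            # Recompute this block's activations during backward.
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)
```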
@@ -164,7 +172,7 @@ class PatchEmbed(nn.Module):
     """Image to Patch Embedding.

     Args:
-        img_size (int, tuple): Input image size.
+        img_size (int | tuple): Input image size.
             default: 224.
         patch_size (int): Width and height for a patch.
             default: 16.
@@ -202,24 +210,34 @@ class VisionTransformer(nn.Module):
         Image Recognition at Scale` - https://arxiv.org/abs/2010.11929

     Args:
-        img_size (tuple): input image size. Default: (224, 224) .
+        img_size (tuple): input image size. Default: (224, 224).
         patch_size (int, tuple): patch size. Default: 16.
         in_channels (int): number of input channels. Default: 3.
         embed_dim (int): embedding dimension. Default: 768.
         depth (int): depth of transformer. Default: 12.
         num_heads (int): number of attention heads. Default: 12.
-        mlp_ratio (int): ratio of mlp hidden dim to embedding dim. Default: 4.
+        mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
+            Default: 4.
+        out_indices (list | tuple | int): Output from which stages.
+            Default: 11.
         qkv_bias (bool): enable bias for qkv if True. Default: True.
         qk_scale (float): override default qk scale of head_dim ** -0.5 if set.
         drop_rate (float): dropout rate. Default: 0.
         attn_drop_rate (float): attention dropout rate. Default: 0.
+        drop_path_rate (float): drop path (stochastic depth) rate. Default: 0.
         norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='LN', requires_grad=True).
+            Default: dict(type='LN', eps=1e-6, requires_grad=True).
         act_cfg (dict): Config dict for activation layer.
             Default: dict(type='GELU').
         norm_eval (bool): Whether to set norm layers to eval mode, namely,
             freeze running stats (mean and var). Note: Effect on Batch Norm
             and its variants only. Default: False.
+        final_norm (bool): Whether to add an additional layer to normalize the
+            final feature map. Default: False.
+        interpolate_mode (str): Interpolation mode used to resize the position
+            embedding vector. Default: 'bicubic'.
+        with_cls_token (bool): Whether to concatenate the class token with the
+            image tokens as transformer input. Default: True.
         with_cp (bool): Use checkpoint or not. Using checkpoint
             will save some memory while slowing down the training speed.
             Default: False.
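To make the new arguments concrete, a backbone config along the following lines would exercise them. This is an illustrative sketch in the usual mmseg config style; the values are placeholders chosen for the example, not defaults prescribed by this diff.

```python
# Hypothetical mmseg-style backbone config exercising the new options.
backbone = dict(
    type='VisionTransformer',
    img_size=(512, 512),
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    out_indices=(2, 5, 8, 11),   # return features from several blocks
    drop_path_rate=0.1,          # stochastic depth across the 12 blocks
    norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True),
    final_norm=True,             # extra LN on the last feature map
    with_cls_token=True,         # keep the class token in the token sequence
    interpolate_mode='bicubic')  # how pos_embed is resized for new img_size
```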
@@ -233,13 +251,18 @@ def __init__(self,
                  depth=12,
                  num_heads=12,
                  mlp_ratio=4,
+                 out_indices=11,
                  qkv_bias=True,
                  qk_scale=None,
                  drop_rate=0.,
                  attn_drop_rate=0.,
-                 norm_cfg=dict(type='LN'),
+                 drop_path_rate=0.,
+                 norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True),
                  act_cfg=dict(type='GELU'),
                  norm_eval=False,
+                 final_norm=False,
+                 with_cls_token=True,
+                 interpolate_mode='bicubic',
                  with_cp=False):
         super(VisionTransformer, self).__init__()
         self.img_size = img_size
@@ -251,24 +274,40 @@ def __init__(self,
             in_channels=in_channels,
             embed_dim=embed_dim)

+        self.with_cls_token = with_cls_token
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
         self.pos_embed = nn.Parameter(
-            torch.zeros(1, self.patch_embed.num_patches, embed_dim))
+            torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
         self.pos_drop = nn.Dropout(p=drop_rate)

-        self.blocks = nn.Sequential(*[
+        if isinstance(out_indices, int):
+            self.out_indices = [out_indices]
+        elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
+            self.out_indices = out_indices
+        else:
+            raise TypeError('out_indices must be type of int, list or tuple')
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
+               ]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
             Block(
                 dim=embed_dim,
                 num_heads=num_heads,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 qk_scale=qk_scale,
                 drop=drop_rate,
+                drop_path=dpr[i],
                 attn_drop=attn_drop_rate,
                 act_cfg=act_cfg,
                 norm_cfg=norm_cfg,
                 with_cp=with_cp) for i in range(depth)
         ])
-        _, self.norm = build_norm_layer(norm_cfg, embed_dim)
+
+        self.interpolate_mode = interpolate_mode
+        self.final_norm = final_norm
+        if final_norm:
+            _, self.norm = build_norm_layer(norm_cfg, embed_dim)

         self.norm_eval = norm_eval
         self.with_cp = with_cp
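As a quick sanity check of the decay rule above: the per-block drop-path rates grow linearly from 0 (first block) to `drop_path_rate` (last block), so the deepest blocks are dropped most often. For example, with the default depth of 12 and an assumed `drop_path_rate` of 0.1:

```python
import torch

depth, drop_path_rate = 12, 0.1
dpr = [round(x.item(), 3) for x in torch.linspace(0, drop_path_rate, depth)]
print(dpr)
# Approximately:
# [0.0, 0.009, 0.018, 0.027, 0.036, 0.045, 0.055, 0.064, 0.073, 0.082, 0.091, 0.1]
```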
@@ -283,28 +321,26 @@ def init_weights(self, pretrained=None):
                 state_dict = checkpoint

             if 'pos_embed' in state_dict.keys():
-                state_dict['pos_embed'] = state_dict['pos_embed'][:, 1:, :]
-                logger.info(
-                    msg='Remove the "cls_token" dimension from the checkpoint')
-
                 if self.pos_embed.shape != state_dict['pos_embed'].shape:
                     logger.info(msg=f'Resize the pos_embed shape from \
-                        {state_dict["pos_embed"].shape} to \
-                        {self.pos_embed.shape}')
+                        {state_dict["pos_embed"].shape} to {self.pos_embed.shape}')
                     h, w = self.img_size
-                    pos_size = int(math.sqrt(state_dict['pos_embed'].shape[1]))
+                    pos_size = int(
+                        math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                     state_dict['pos_embed'] = self.resize_pos_embed(
                         state_dict['pos_embed'], (h, w), (pos_size, pos_size),
-                        self.patch_size)
+                        self.patch_size, self.interpolate_mode)
+
             self.load_state_dict(state_dict, False)

         elif pretrained is None:
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
-            normal_init(self.pos_embed)
+            trunc_normal_(self.pos_embed, std=.02)
+            trunc_normal_(self.cls_token, std=.02)
             for n, m in self.named_modules():
                 if isinstance(m, Linear):
-                    xavier_init(m.weight, distribution='uniform')
+                    trunc_normal_(m.weight, std=.02)
                     if m.bias is not None:
                         if 'mlp' in n:
                             normal_init(m.bias, std=1e-6)
@@ -316,7 +352,7 @@ def init_weights(self, pretrained=None):
                         constant_init(m.bias, 0)
                 elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                     constant_init(m.bias, 0)
-                    constant_init(m.weight, 1)
+                    constant_init(m.weight, 1.0)
         else:
             raise TypeError('pretrained must be a str or None')

@@ -340,19 +376,20 @@ def _pos_embeding(self, img, patched_img, pos_embed):
         x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
         if x_len != pos_len:
             if pos_len == (self.img_size[0] // self.patch_size) * (
-                    self.img_size[1] // self.patch_size):
+                    self.img_size[1] // self.patch_size) + 1:
                 pos_h = self.img_size[0] // self.patch_size
                 pos_w = self.img_size[1] // self.patch_size
             else:
                 raise ValueError(
                     'Unexpected shape of pos_embed, got {}.'.format(
                         pos_embed.shape))
             pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
-                                              (pos_h, pos_w), self.patch_size)
-        return patched_img + pos_embed
+                                              (pos_h, pos_w), self.patch_size,
+                                              self.interpolate_mode)
+        return self.pos_drop(patched_img + pos_embed)

     @staticmethod
-    def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size):
+    def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
         """Resize pos_embed weights.

         Resize pos_embed using bicubic interpolate method.
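For intuition about what the resize does, the standalone snippet below mirrors the logic added in the next hunk, assuming the documented defaults (224x224 pretraining size, patch size 16, embed dim 768) and a hypothetical 512x512 target input: only the patch-grid part of `pos_embed` is interpolated; the class-token vector is carried over unchanged.

```python
import torch
import torch.nn.functional as F

pos_embed = torch.zeros(1, 1 + 14 * 14, 768)         # checkpoint pos_embed (224 / 16 = 14)
cls_tok, grid = pos_embed[:, :1], pos_embed[:, 1:]   # split off the class-token slot
grid = grid.reshape(1, 14, 14, 768).permute(0, 3, 1, 2)
grid = F.interpolate(grid, size=(32, 32), mode='bicubic', align_corners=False)
grid = grid.flatten(2).transpose(1, 2)               # back to (1, 32 * 32, 768)
resized = torch.cat((cls_tok, grid), dim=1)
print(resized.shape)  # torch.Size([1, 1025, 768]) -> fits a 512x512 input
```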
@@ -367,26 +404,52 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size):
         assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
         input_h, input_w = input_shpae
         pos_h, pos_w = pos_shape
-        pos_embed = pos_embed.reshape(1, pos_h, pos_w,
-                                      pos_embed.shape[2]).permute(0, 3, 1, 2)
-        pos_embed = F.interpolate(
-            pos_embed,
+        cls_token_weight = pos_embed[:, 0]
+        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
+        pos_embed_weight = pos_embed_weight.reshape(
+            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
+        pos_embed_weight = F.interpolate(
+            pos_embed_weight,
             size=[input_h // patch_size, input_w // patch_size],
             align_corners=False,
-            mode='bicubic')
-        pos_embed = torch.flatten(pos_embed, 2).transpose(1, 2)
+            mode=mode)
+        cls_token_weight = cls_token_weight.unsqueeze(1)
+        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
+        pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
         return pos_embed

     def forward(self, inputs):
+        B = inputs.shape[0]
+
         x = self.patch_embed(inputs)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
         x = self._pos_embeding(inputs, x, self.pos_embed)
-        x = self.blocks(x)
-        x = self.norm(x)
-        B, _, C = x.shape
-        x = x.reshape(B, inputs.shape[2] // self.patch_size,
-                      inputs.shape[3] // self.patch_size,
-                      C).permute(0, 3, 1, 2)
-        return [x]
+
+        if not self.with_cls_token:
+            # Remove class token for transformer input
+            x = x[:, 1:]
+
+        outs = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i == len(self.blocks) - 1:
+                if self.final_norm:
+                    x = self.norm(x)
+            if i in self.out_indices:
+                if self.with_cls_token:
+                    # Remove class token and reshape tokens for the decode head
+                    out = x[:, 1:]
+                else:
+                    out = x
+                B, _, C = out.shape
+                out = out.reshape(B, inputs.shape[2] // self.patch_size,
+                                  inputs.shape[3] // self.patch_size,
+                                  C).permute(0, 3, 1, 2)
+                outs.append(out)
+
+        return tuple(outs)

     def train(self, mode=True):
         super(VisionTransformer, self).train(mode)
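As a smoke test of the new forward path, something along these lines should work (an illustrative sketch; the import path depends on the mmseg version, and the shapes follow from the arguments chosen here rather than from this diff):

```python
import torch
from mmseg.models.backbones import VisionTransformer  # import path assumed

model = VisionTransformer(
    img_size=(224, 224),
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    out_indices=(2, 5, 8, 11),
    drop_path_rate=0.1,
    final_norm=True)
model.init_weights(pretrained=None)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))

# One feature map per entry in out_indices; the class token is stripped and
# the remaining tokens are reshaped back to a 2D grid of size 224 // 16 = 14.
assert isinstance(feats, tuple) and len(feats) == 4
assert feats[0].shape == (1, 768, 14, 14)
```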