Commit 0c4c3b7

[Fix] Fix some vit init bugs (open-mmlab#609)
* [Fix] Fix vit init bug
* Add some vit unit tests
* Modify module import
* Fix pretrain weights bug
* Modify pretrained judge
* Add some unit tests to improve code cov
* Optimize code
* Fix vit unit test
1 parent 458fc78 commit 0c4c3b7

File tree: 3 files changed (+76, -29 lines)

  mmseg/models/backbones/vit.py
  mmseg/models/utils/timm_convert.py
  tests/test_models/test_backbones/test_vit.py

mmseg/models/backbones/vit.py

Lines changed: 43 additions & 26 deletions
@@ -1,13 +1,13 @@
 import math
+import warnings
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
                       kaiming_init, normal_init, trunc_normal_init)
 from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
-from mmcv.runner import _load_checkpoint
-from mmcv.runner.base_module import BaseModule, ModuleList
+from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _pair as to_2tuple

@@ -140,12 +140,6 @@ def __init__(self,
         self.norm = None
 
     def forward(self, x):
-        B, C, H, W = x.shape
-        # FIXME look at relaxing size constraints
-        # assert H == self.img_size[0] and W == self.img_size[1], \
-        #     f"Input image size ({H}*{W}) doesn't " \
-        #     f'match model ({self.img_size[0]}*{self.img_size[1]}).'
-        # The output size is (B, N, D), where N=H*W/P/P, D is embid_dim
         x = self.projection(x).flatten(2).transpose(1, 2)
 
         if self.norm is not None:
@@ -185,8 +179,12 @@ class VisionTransformer(BaseModule):
             Default: dict(type='LN')
         act_cfg (dict): The activation config for FFNs.
             Defalut: dict(type='GELU').
-        final_norm (bool): Whether to add a additional layer to normalize
+        patch_norm (bool): Whether to add a norm in PatchEmbed Block.
+            Default: False.
+        final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Default: False.
+        out_shape (str): Select the output format of feature information.
+            Default: NCHW.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Default: bicubic.
         num_fcs (int): The number of fully-connected layers for FFNs.
@@ -198,6 +196,9 @@ class VisionTransformer(BaseModule):
             some memory while slowing down the training speed. Default: False.
         pretrain_style (str): Choose to use timm or mmcls pretrain weights.
             Default: timm.
+        pretrained (str, optional): model pretrained path. Default: None.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
     """
 
     def __init__(self,
@@ -216,12 +217,16 @@ def __init__(self,
                  with_cls_token=True,
                  norm_cfg=dict(type='LN'),
                  act_cfg=dict(type='GELU'),
+                 patch_norm=False,
                  final_norm=False,
+                 out_shape='NCHW',
                  interpolate_mode='bicubic',
                  num_fcs=2,
                  norm_eval=False,
                  with_cp=False,
-                 pretrain_style='timm'):
+                 pretrain_style='timm',
+                 pretrained=None,
+                 init_cfg=None):
         super(VisionTransformer, self).__init__()
 
         if isinstance(img_size, int):
@@ -235,16 +240,32 @@ def __init__(self,
 
         assert pretrain_style in ['timm', 'mmcls']
 
-        self.pretrain_style = pretrain_style
+        assert out_shape in ['NLC',
+                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
+
+        if isinstance(pretrained, str) or pretrained is None:
+            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+                          'please use "init_cfg" instead')
+        else:
+            raise TypeError('pretrained must be a str or None')
+
         self.img_size = img_size
         self.patch_size = patch_size
+        self.out_shape = out_shape
+        self.interpolate_mode = interpolate_mode
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+        self.pretrain_style = pretrain_style
+        self.pretrained = pretrained
+        self.init_cfg = init_cfg
 
         self.patch_embed = PatchEmbed(
             img_size=img_size,
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=embed_dims,
-            norm_cfg=norm_cfg)
+            norm_cfg=norm_cfg if patch_norm else None)
+
         num_patches = self.patch_embed.num_patches
 
         self.with_cls_token = with_cls_token
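As a quick illustration of the constructor changes above (a sketch, not part of the commit; the checkpoint path is a placeholder), the new arguments behave as follows:

import pytest

from mmseg.models.backbones.vit import VisionTransformer

# Passing `pretrained` still works but now only stores the path and emits a
# deprecation warning; anything other than str/None raises TypeError.
model = VisionTransformer(pretrained='path/to/vit_checkpoint.pth')  # placeholder path

with pytest.raises(TypeError):
    VisionTransformer(pretrained=123)

# `out_shape` is validated up front.
with pytest.raises(AssertionError):
    VisionTransformer(out_shape='NCL')

# `patch_norm=True` forwards `norm_cfg` to PatchEmbed; otherwise PatchEmbed
# receives `norm_cfg=None`.
model = VisionTransformer(patch_norm=True)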
@@ -280,24 +301,20 @@ def __init__(self,
                     norm_cfg=norm_cfg,
                     batch_first=True))
 
-        self.interpolate_mode = interpolate_mode
         self.final_norm = final_norm
         if final_norm:
             self.norm1_name, norm1 = build_norm_layer(
                 norm_cfg, embed_dims, postfix=1)
             self.add_module(self.norm1_name, norm1)
 
-        self.norm_eval = norm_eval
-        self.with_cp = with_cp
-
     @property
     def norm1(self):
         return getattr(self, self.norm1_name)
 
-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str):
+    def init_weights(self):
+        if isinstance(self.pretrained, str):
             logger = get_root_logger()
-            checkpoint = _load_checkpoint(pretrained, logger=logger)
+            checkpoint = _load_checkpoint(self.pretrained, logger=logger)
             if 'state_dict' in checkpoint:
                 state_dict = checkpoint['state_dict']
             elif 'model' in checkpoint:
@@ -325,7 +342,8 @@ def init_weights(self, pretrained=None):
 
             self.load_state_dict(state_dict, False)
 
-        elif pretrained is None:
+        elif self.pretrained is None:
+            super(VisionTransformer, self).init_weights()
             # We only implement the 'jax_impl' initialization implemented at
             # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
             trunc_normal_init(self.pos_embed, std=.02)
@@ -345,8 +363,6 @@ def init_weights(self, pretrained=None):
                 elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                     constant_init(m.bias, 0)
                     constant_init(m.weight, 1.0)
-        else:
-            raise TypeError('pretrained must be a str or None')
 
     def _pos_embeding(self, img, patched_img, pos_embed):
         """Positiong embeding method.
@@ -436,10 +452,11 @@ def forward(self, inputs):
                 out = x[:, 1:]
             else:
                 out = x
-            B, _, C = out.shape
-            out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                              inputs.shape[3] // self.patch_size,
-                              C).permute(0, 3, 1, 2)
+            if self.out_shape == 'NCHW':
+                B, _, C = out.shape
+                out = out.reshape(B, inputs.shape[2] // self.patch_size,
+                                  inputs.shape[3] // self.patch_size,
+                                  C).permute(0, 3, 1, 2)
             outs.append(out)
 
         return tuple(outs)
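A shape check mirroring the new unit tests shows what out_shape controls (a sketch, assuming the default patch_size=16 and embed_dims=768):

import torch

from mmseg.models.backbones.vit import VisionTransformer

imgs = torch.randn(1, 3, 224, 224)

# Default 'NCHW': the token sequence is reshaped back into a feature map.
model = VisionTransformer(out_shape='NCHW')
assert model(imgs)[-1].shape == (1, 768, 14, 14)  # 224 / 16 = 14 patches per side

# 'NLC': the (B, N, C) token sequence is returned as-is.
model = VisionTransformer(out_shape='NLC')
assert model(imgs)[-1].shape == (1, 196, 768)  # 14 * 14 = 196 tokens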

mmseg/models/utils/timm_convert.py

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@ def vit_convert(timm_dict):
             new_k = new_k.replace('attn.proj', 'attn.attn.out_proj')
         else:
             new_k = k
-        new_k = f'backbone.{new_k}'
         mmseg_dict[new_k] = v
 
     return mmseg_dict
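With the 'backbone.' prefix gone, the converted keys match the backbone module directly rather than a full segmentor; a rough usage sketch (the checkpoint path is a placeholder, and strict=False is used because timm classifier weights have no counterpart here):

import torch

from mmseg.models.backbones.vit import VisionTransformer
from mmseg.models.utils.timm_convert import vit_convert

timm_state_dict = torch.load('timm_vit_base_patch16_224.pth')  # placeholder path
mmseg_state_dict = vit_convert(timm_state_dict)

# Keys no longer start with 'backbone.', so they can be loaded straight
# into the backbone.
backbone = VisionTransformer()
backbone.load_state_dict(mmseg_state_dict, strict=False)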

tests/test_models/test_backbones/test_vit.py

Lines changed: 33 additions & 2 deletions
@@ -24,21 +24,35 @@ def test_vit_backbone():
     x = torch.randn(1, 196)
     VisionTransformer.resize_pos_embed(x, 512, 512, 224, 224, 'bilinear')
 
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         # forward inputs must be [N, C, H, W]
         x = torch.randn(3, 30, 30)
         model = VisionTransformer()
         model(x)
 
     with pytest.raises(AssertionError):
+        # The length of img_size tuple must be lower than 3.
         VisionTransformer(img_size=(224, 224, 224))
 
+    with pytest.raises(TypeError):
+        # Pretrained must be None or Str.
+        VisionTransformer(pretrained=123)
+
+    with pytest.raises(AssertionError):
+        # out_shape must be 'NLC' or 'NCHW;'
+        VisionTransformer(out_shape='NCL')
+
     # Test img_size isinstance tuple
     imgs = torch.randn(1, 3, 224, 224)
-    model = VisionTransformer(img_size=(224, 224))
+    model = VisionTransformer(img_size=(224, ))
     model.init_weights()
     model(imgs)
 
+    # Test img_size isinstance tuple
+    imgs = torch.randn(1, 3, 224, 224)
+    model = VisionTransformer(img_size=(224, 224))
+    model(imgs)
+
     # Test norm_eval = True
     model = VisionTransformer(norm_eval=True)
     model.train()
@@ -50,6 +64,11 @@ def test_vit_backbone():
 
     assert check_norm_state(model.modules(), True)
 
+    # Test normal size input image
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 14, 14)
+
     # Test large size input image
     imgs = torch.randn(1, 3, 256, 256)
     feat = model(imgs)
@@ -81,8 +100,20 @@ def test_vit_backbone():
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
 
+    # Test out_shape == 'NLC'
+    model = VisionTransformer(out_shape='NLC')
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 196, 768)
+
     # Test final norm
     model = VisionTransformer(final_norm=True)
     imgs = torch.randn(1, 3, 224, 224)
     feat = model(imgs)
     assert feat[-1].shape == (1, 768, 14, 14)
+
+    # Test patch norm
+    model = VisionTransformer(patch_norm=True)
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 14, 14)
