Skip to content

Commit aa9b609

Browse files
authored
Add option for output shape of ViT (open-mmlab#530)
* Add arg: final_reshape to control if converting output feature information from NLC to NCHW; * Fix the default value of final_reshape; * Modify arg: final_reshape to arg: out_shape; * Fix some unit test bug;
1 parent f884489 commit aa9b609

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

mmseg/models/backbones/vit.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ class VisionTransformer(nn.Module):
234234
and its variants only. Default: False.
235235
final_norm (bool): Whether to add an additional layer to normalize
236236
final feature map. Default: False.
237+
out_shape (str): Select the output format of the feature map, 'NLC' or 'NCHW'.
238+
Default: NCHW.
237239
interpolate_mode (str): Select the interpolate mode for position
238240
embedding vector resize. Default: bicubic.
239241
with_cls_token (bool): If concatenating class token into image tokens
@@ -261,6 +263,7 @@ def __init__(self,
261263
act_cfg=dict(type='GELU'),
262264
norm_eval=False,
263265
final_norm=False,
266+
out_shape='NCHW',
264267
with_cls_token=True,
265268
interpolate_mode='bicubic',
266269
with_cp=False):
@@ -303,6 +306,11 @@ def __init__(self,
303306
with_cp=with_cp) for i in range(depth)
304307
])
305308

309+
assert out_shape in ['NLC',
310+
'NCHW'], 'output shape must be "NLC" or "NCHW".'
311+
312+
self.out_shape = out_shape
313+
306314
self.interpolate_mode = interpolate_mode
307315
self.final_norm = final_norm
308316
if final_norm:
@@ -443,10 +451,11 @@ def forward(self, inputs):
443451
out = x[:, 1:]
444452
else:
445453
out = x
446-
B, _, C = out.shape
447-
out = out.reshape(B, inputs.shape[2] // self.patch_size,
448-
inputs.shape[3] // self.patch_size,
449-
C).permute(0, 3, 1, 2)
454+
if self.out_shape == 'NCHW':
455+
B, _, C = out.shape
456+
out = out.reshape(B, inputs.shape[2] // self.patch_size,
457+
inputs.shape[3] // self.patch_size,
458+
C).permute(0, 3, 1, 2)
450459
outs.append(out)
451460

452461
return tuple(outs)

tests/test_models/test_backbones/test_vit.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ def test_vit_backbone():
3030
model = VisionTransformer()
3131
model(x)
3232

33+
with pytest.raises(AssertionError):
34+
# out_shape must be 'NLC' or 'NCHW'
35+
VisionTransformer(out_shape='NCL')
36+
3337
# Test img_size isinstance int
3438
imgs = torch.randn(1, 3, 224, 224)
3539
model = VisionTransformer(img_size=224)
@@ -72,3 +76,9 @@ def test_vit_backbone():
7276
imgs = torch.randn(1, 3, 224, 224)
7377
feat = model(imgs)
7478
assert feat[-1].shape == (1, 768, 14, 14)
79+
80+
# Test out_shape arg
81+
imgs = torch.randn(1, 3, 224, 224)
82+
model = VisionTransformer(out_shape='NLC')
83+
feat = model(imgs)
84+
assert feat[-1].shape == (1, 196, 768)

0 commit comments

Comments
 (0)