
Commit 85227b4

xvjiarui and Junjun2016 authored
[Improvement] Refactor Swin-Transformer (open-mmlab#800)
* [Improvement] Refactor Swin-Transformer
* fixed swin test
* update patch embed, add more tests
* fixed test
* remove pretrain_style
* fixed padding
* resolve comments
* use mmcv to_2tuple
* refactor init_cfg

Co-authored-by: Junjun2016 <[email protected]>
1 parent ab12009 commit 85227b4

File tree

11 files changed: +937 −246 lines


configs/_base_/models/upernet_swin.py

Lines changed: 1 addition & 2 deletions
@@ -23,8 +23,7 @@
         drop_path_rate=0.3,
         use_abs_pos_embed=False,
         act_cfg=dict(type='GELU'),
-        norm_cfg=backbone_norm_cfg,
-        pretrain_style='official'),
+        norm_cfg=backbone_norm_cfg),
     decode_head=dict(
         type='UPerHead',
         in_channels=[96, 192, 384, 768],
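
With `pretrain_style` gone, the base config no longer encodes where a checkpoint came from; loading official or mmcls weights is left to the usual MMCV initialization path. Below is a minimal sketch of the backbone block after the refactor, assuming MMCV's standard `Pretrained` init_cfg convention; the checkpoint path is a hypothetical placeholder, not part of this commit.

# Sketch only: init_cfg with type='Pretrained' is MMCV's standard checkpoint
# hook; the path below is a placeholder, not taken from this commit.
backbone_norm_cfg = dict(type='LN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='SwinTransformer',
        drop_path_rate=0.3,
        use_abs_pos_embed=False,
        act_cfg=dict(type='GELU'),
        norm_cfg=backbone_norm_cfg,  # the pretrain_style key is simply gone
        init_cfg=dict(
            type='Pretrained',
            checkpoint='pretrain/swin_tiny_patch4_window7_224.pth')),
    decode_head=dict(
        type='UPerHead',
        in_channels=[96, 192, 384, 768]))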

configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py

Lines changed: 1 addition & 2 deletions
@@ -11,8 +11,7 @@
         window_size=7,
         use_abs_pos_embed=False,
         drop_path_rate=0.3,
-        patch_norm=True,
-        pretrain_style='official'),
+        patch_norm=True),
     decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
     auxiliary_head=dict(in_channels=384, num_classes=150))
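
The tiny-Swin config shrinks by the same two lines. To sanity-check that the refactored config still builds and runs, something like the following should work with the mmseg 0.x API; an untested sketch, where `forward_dummy` assumes the stock `EncoderDecoder` dummy-forward hook used for FLOPs counting.

# Untested sketch (mmseg 0.x): build the segmentor from the refactored config
# and push a dummy batch through it.
import torch
from mmcv import Config
from mmseg.models import build_segmentor

cfg = Config.fromfile(
    'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_'
    'pretrain_224x224_1K.py')
model = build_segmentor(cfg.model)
seg_logits = model.forward_dummy(torch.rand(1, 3, 512, 512))
print(seg_logits.shape)  # expected: torch.Size([1, 150, 512, 512])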

mmseg/models/backbones/mit.py

Lines changed: 1 addition & 11 deletions
@@ -278,8 +278,6 @@ class MixVisionTransformer(BaseModule):
             Default: dict(type='LN')
         act_cfg (dict): The activation config for FFNs.
             Defalut: dict(type='GELU').
-        pretrain_style (str): Choose to use official or mmcls pretrain weights.
-            Default: official.
         pretrained (str, optional): model pretrained path. Default: None.
         init_cfg (dict or list[dict], optional): Initialization config dict.
             Default: None.
@@ -302,15 +300,10 @@ def __init__(self,
                  drop_path_rate=0.,
                  act_cfg=dict(type='GELU'),
                  norm_cfg=dict(type='LN', eps=1e-6),
-                 pretrain_style='official',
                  pretrained=None,
                  init_cfg=None):
         super().__init__()
 
-        assert pretrain_style in [
-            'official', 'mmcls'
-        ], 'we only support official weights or mmcls weights.'
-
         if isinstance(pretrained, str) or pretrained is None:
             warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                           'please use "init_cfg" instead')
@@ -330,7 +323,6 @@ def __init__(self,
 
         self.out_indices = out_indices
         assert max(out_indices) < self.num_stages
-        self.pretrain_style = pretrain_style
         self.pretrained = pretrained
         self.init_cfg = init_cfg
 
@@ -350,7 +342,6 @@ def __init__(self,
                 kernel_size=patch_sizes[i],
                 stride=strides[i],
                 padding=patch_sizes[i] // 2,
-                pad_to_patch_size=False,
                 norm_cfg=norm_cfg)
             layer = ModuleList([
                 TransformerEncoderLayer(
@@ -403,8 +394,7 @@ def forward(self, x):
         outs = []
 
         for i, layer in enumerate(self.layers):
-            x, H, W = layer[0](x), layer[0].DH, layer[0].DW
-            hw_shape = (H, W)
+            x, hw_shape = layer[0](x)
             for block in layer[1]:
                 x = block(x, hw_shape)
             x = layer[2](x)
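
The last hunk is the behavioral change: the patch-embedding layer no longer stashes the downsampled resolution in `DH`/`DW` attributes but returns it alongside the token sequence, so `forward` unpacks `(x, hw_shape)` in one step. A toy illustration of that return convention — this is not the actual mmcv `PatchEmbed`, just the shape contract it now follows:

# Toy sketch of the (tokens, hw_shape) contract; not the real mmcv PatchEmbed.
import torch
import torch.nn as nn

class ToyPatchEmbed(nn.Module):
    def __init__(self, in_channels=3, embed_dims=64, patch_size=4):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dims,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                     # (B, C, H', W')
        hw_shape = (x.shape[2], x.shape[3])  # resolution after embedding
        x = x.flatten(2).transpose(1, 2)     # (B, H'*W', C) token sequence
        # Return the resolution with the tokens instead of storing it on the
        # module, as the deleted DH/DW attributes used to do.
        return x, hw_shape

tokens, hw_shape = ToyPatchEmbed()(torch.rand(2, 3, 32, 32))
print(tokens.shape, hw_shape)  # torch.Size([2, 64, 64]) (8, 8)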

0 commit comments