Skip to content

Commit 98067be

Browse files
authored
[Fix] Add setr & vit msg. (open-mmlab#635)
* [Fix] Add setr & vit msg. * Fix init bug * Modify init_cfg arg * Add conv_seg init
1 parent ec91893 commit 98067be

File tree

4 files changed

+21
-12
lines changed

4 files changed

+21
-12
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ Supported backbones:
6363
- [x] [ResNeSt (ArXiv'2020)](configs/resnest/README.md)
6464
- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2/README.md)
6565
- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3/README.md)
66+
- [x] [Vision Transformer (ICLR'2021)]
6667

6768
Supported methods:
6869

@@ -89,6 +90,7 @@ Supported methods:
8990
- [x] [DNLNet (ECCV'2020)](configs/dnlnet)
9091
- [x] [PointRend (CVPR'2020)](configs/point_rend)
9192
- [x] [CGNet (TIP'2020)](configs/cgnet)
93+
- [x] [SETR (CVPR'2021)](configs/setr)
9294

9395
## Installation
9496

README_zh-CN.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
6262
- [x] [ResNeSt (ArXiv'2020)](configs/resnest/README.md)
6363
- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2/README.md)
6464
- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3/README.md)
65+
- [x] [Vision Transformer (ICLR'2021)]
6566

6667
已支持的算法:
6768

@@ -87,6 +88,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
8788
- [x] [DNLNet (ECCV'2020)](configs/dnlnet)
8889
- [x] [PointRend (CVPR'2020)](configs/point_rend)
8990
- [x] [CGNet (TIP'2020)](configs/cgnet)
91+
- [x] [SETR (CVPR'2021)](configs/setr)
9092

9193
## 安装
9294

mmseg/models/decode_heads/setr_up_head.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch.nn as nn
2-
from mmcv.cnn import ConvModule, build_norm_layer, constant_init
2+
from mmcv.cnn import ConvModule, build_norm_layer
33

44
from ..builder import HEADS
55
from .decode_head import BaseDecodeHead
@@ -18,18 +18,28 @@ class SETRUPHead(BaseDecodeHead):
1818
up_scale (int): The scale factor of interpolate. Default:4.
1919
kernel_size (int): The kernel size of convolution when decoding
2020
feature information from backbone. Default: 3.
21+
init_cfg (dict | list[dict] | None): Initialization config dict.
22+
Default: dict(
23+
type='Constant', val=1.0, bias=0, layer='LayerNorm').
2124
"""
2225

2326
def __init__(self,
2427
norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
2528
num_convs=1,
2629
up_scale=4,
2730
kernel_size=3,
31+
init_cfg=[
32+
dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'),
33+
dict(
34+
type='Normal',
35+
std=0.01,
36+
override=dict(name='conv_seg'))
37+
],
2838
**kwargs):
2939

3040
assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.'
3141

32-
super(SETRUPHead, self).__init__(**kwargs)
42+
super(SETRUPHead, self).__init__(init_cfg=init_cfg, **kwargs)
3343

3444
assert isinstance(self.in_channels, int)
3545

@@ -38,7 +48,7 @@ def __init__(self,
3848
self.up_convs = nn.ModuleList()
3949
in_channels = self.in_channels
4050
out_channels = self.channels
41-
for i in range(num_convs):
51+
for _ in range(num_convs):
4252
self.up_convs.append(
4353
nn.Sequential(
4454
ConvModule(
@@ -55,12 +65,6 @@ def __init__(self,
5565
align_corners=self.align_corners)))
5666
in_channels = out_channels
5767

58-
def init_weights(self):
59-
for m in self.modules():
60-
if isinstance(m, nn.LayerNorm):
61-
constant_init(m.bias, 0)
62-
constant_init(m.weight, 1.0)
63-
6468
def forward(self, x):
6569
x = self._transform_inputs(x)
6670

tests/test_models/test_heads/test_setr_up_head.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@ def test_setr_up_head(capsys):
1616
# as embed_dim.
1717
SETRUPHead(in_channels=(32, 32), channels=16, num_classes=19)
1818

19-
# test init_weights of head
19+
# test init_cfg of head
2020
head = SETRUPHead(
2121
in_channels=32,
2222
channels=16,
2323
norm_cfg=dict(type='SyncBN'),
24-
num_classes=19)
25-
head.init_weights()
24+
num_classes=19,
25+
init_cfg=dict(type='Kaiming'))
26+
super(SETRUPHead, head).init_weights()
2627

2728
# test inference of Naive head
2829
# the auxiliary head of Naive head is same as Naive head

0 commit comments

Comments
 (0)