Skip to content

Commit ff8d971

Browse files
[Feature] Support SegNeXt in MMSegmentation 2.0 (open-mmlab#2654)
## Motivation Support SegNeXt in MMSeg 1.x branch. 0.x PR: open-mmlab#2600 --------- Co-authored-by: xiexinch <[email protected]>
1 parent cb2d8fe commit ff8d971

15 files changed

+1176
-2
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
117117
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
118118
- [x] [MAE (CVPR'2022)](configs/mae)
119119
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
120+
- [x] [SegNeXt (NeurIPS'2022)](configs/segnext)
120121

121122
</details>
122123

README_zh-CN.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
9898
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
9999
- [x] [MAE (CVPR'2022)](configs/mae)
100100
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
101+
- [x] [SegNeXt (NeurIPS'2022)](configs/segnext)
101102

102103
</details>
103104

configs/segnext/README.md

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# SegNeXt
2+
3+
> [SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation](https://arxiv.org/abs/2209.08575)
4+
5+
## Introduction
6+
7+
<!-- [ALGORITHM] -->
8+
9+
<a href="https://github.com/visual-attention-network/segnext">Official Repo</a>
10+
11+
<a href="https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/mscan.py#L328">Code Snippet</a>
12+
13+
## Abstract
14+
15+
<!-- [ABSTRACT] -->
16+
17+
We present SegNeXt, a simple convolutional network architecture for semantic segmentation. Recent transformer-based models have dominated the field of semantic segmentation due to the efficiency of self-attention in encoding spatial information. In this paper, we show that convolutional attention is a more efficient and effective way to encode contextual information than the self-attention mechanism in transformers. By re-examining the characteristics owned by successful segmentation models, we discover several key components leading to the performance improvement of segmentation models. This motivates us to design a novel convolutional attention network that uses cheap convolutional operations. Without bells and whistles, our SegNeXt significantly improves the performance of previous state-of-the-art methods on popular benchmarks, including ADE20K, Cityscapes, COCO-Stuff, Pascal VOC, Pascal Context, and iSAID. Notably, SegNeXt outperforms EfficientNet-L2 w/ NAS-FPN and achieves 90.6% mIoU on the Pascal VOC 2012 test leaderboard using only 1/10 parameters of it. On average, SegNeXt achieves about 2.0% mIoU improvements compared to the state-of-the-art methods on the ADE20K datasets with the same or fewer computations. Code is available at [this https URL](https://github.com/uyzhang/JSeg) (Jittor) and [this https URL](https://github.com/Visual-Attention-Network/SegNeXt) (Pytorch).
18+
19+
<!-- [IMAGE] -->
20+
21+
<div align=center>
22+
<img src="https://user-images.githubusercontent.com/24582831/215688018-5d4c8366-7793-4fdf-9397-960a09fac951.png" width="70%"/>
23+
</div>
24+
25+
## Results and models
26+
27+
### ADE20K
28+
29+
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
30+
| ------- | -------- | --------- | ------- | -------- | -------------- | ----- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
31+
| SegNeXt | MSCAN-T | 512x512 | 160000 | 17.88 | 52.38 | 41.50 | 42.59 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244-05bd8466.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244.log.json) |
32+
| SegNeXt | MSCAN-S | 512x512 | 160000 | 21.47 | 42.27 | 44.16 | 45.81 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014-43013668.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014.log.json) |
33+
| SegNeXt | MSCAN-B | 512x512 | 160000 | 31.03 | 35.15 | 48.03 | 49.68 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053-b6f6c70c.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053.log.json) |
34+
| SegNeXt | MSCAN-L | 512x512 | 160000 | 43.32 | 22.91 | 50.99 | 52.10 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055-19b14b63.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055.log.json) |
35+
36+
Note:
37+
38+
- When we integrated SegNeXt into MMSegmentation, we modified some layers' names to make them more precise and concise without changing the model architecture. Therefore, the keys of pre-trained weights are different from the [original weights](https://cloud.tsinghua.edu.cn/d/c15b25a6745946618462/), but don't worry about these changes. we have converted them and uploaded the checkpoints, you might find URL of pre-trained checkpoints in config files and can use them directly for training.
39+
40+
- The total batch size is 16. We trained for SegNeXt with a single GPU as the performance degrades significantly when using`SyncBN` (mainly in `OverlapPatchEmbed` modules of `MSCAN`) of PyTorch 1.9.
41+
42+
- There will be subtle differences when model testing as Non-negative Matrix Factorization (NMF) in `LightHamHead` will be initialized randomly. To control this randomness, please set the random seed when model testing. You can modify [`./tools/test.py`](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/tools/test.py) like:
43+
44+
```python
45+
def main():
46+
from mmengine.runner import seg_random_seed
47+
random_seed = xxx # set random seed recorded in training log
48+
set_random_seed(random_seed, deterministic=False)
49+
...
50+
```
51+
52+
- This model performance is sensitive to the seed values used, please refer to the log file for the specific settings of the seed. If you choose a different seed, the results might differ from the table results. Take SegNeXt Large for example, its results range from 49.60 to 51.0.
53+
54+
## Citation
55+
56+
```bibtex
57+
@article{guo2022segnext,
58+
title={SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation},
59+
author={Guo, Meng-Hao and Lu, Cheng-Ze and Hou, Qibin and Liu, Zhengning and Cheng, Ming-Ming and Hu, Shi-Min},
60+
journal={arXiv preprint arXiv:2209.08575},
61+
year={2022}
62+
}
63+
```

configs/segnext/segnext.yml

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
Collections:
2+
- Name: SegNeXt
3+
Metadata:
4+
Training Data:
5+
- ADE20K
6+
Paper:
7+
URL: https://arxiv.org/abs/2209.08575
8+
Title: 'SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation'
9+
README: configs/segnext/README.md
10+
Code:
11+
URL: https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/mscan.py#L328
12+
Version: dev-1.x
13+
Converted From:
14+
Code: https://github.com/visual-attention-network/segnext
15+
Models:
16+
- Name: segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512
17+
In Collection: SegNeXt
18+
Metadata:
19+
backbone: MSCAN-T
20+
crop size: (512,512)
21+
lr schd: 160000
22+
inference time (ms/im):
23+
- value: 19.09
24+
hardware: V100
25+
backend: PyTorch
26+
batch size: 1
27+
mode: FP32
28+
resolution: (512,512)
29+
Training Memory (GB): 17.88
30+
Results:
31+
- Task: Semantic Segmentation
32+
Dataset: ADE20K
33+
Metrics:
34+
mIoU: 41.5
35+
mIoU(ms+flip): 42.59
36+
Config: configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py
37+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244-05bd8466.pth
38+
- Name: segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512
39+
In Collection: SegNeXt
40+
Metadata:
41+
backbone: MSCAN-S
42+
crop size: (512,512)
43+
lr schd: 160000
44+
inference time (ms/im):
45+
- value: 23.66
46+
hardware: V100
47+
backend: PyTorch
48+
batch size: 1
49+
mode: FP32
50+
resolution: (512,512)
51+
Training Memory (GB): 21.47
52+
Results:
53+
- Task: Semantic Segmentation
54+
Dataset: ADE20K
55+
Metrics:
56+
mIoU: 44.16
57+
mIoU(ms+flip): 45.81
58+
Config: configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py
59+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014-43013668.pth
60+
- Name: segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512
61+
In Collection: SegNeXt
62+
Metadata:
63+
backbone: MSCAN-B
64+
crop size: (512,512)
65+
lr schd: 160000
66+
inference time (ms/im):
67+
- value: 28.45
68+
hardware: V100
69+
backend: PyTorch
70+
batch size: 1
71+
mode: FP32
72+
resolution: (512,512)
73+
Training Memory (GB): 31.03
74+
Results:
75+
- Task: Semantic Segmentation
76+
Dataset: ADE20K
77+
Metrics:
78+
mIoU: 48.03
79+
mIoU(ms+flip): 49.68
80+
Config: configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py
81+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053-b6f6c70c.pth
82+
- Name: segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512
83+
In Collection: SegNeXt
84+
Metadata:
85+
backbone: MSCAN-L
86+
crop size: (512,512)
87+
lr schd: 160000
88+
inference time (ms/im):
89+
- value: 43.65
90+
hardware: V100
91+
backend: PyTorch
92+
batch size: 1
93+
mode: FP32
94+
resolution: (512,512)
95+
Training Memory (GB): 43.32
96+
Results:
97+
- Task: Semantic Segmentation
98+
Dataset: ADE20K
99+
Metrics:
100+
mIoU: 50.99
101+
mIoU(ms+flip): 52.1
102+
Config: configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py
103+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055-19b14b63.pth
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
_base_ = './segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py'
2+
3+
# model settings
4+
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_b_20230227-3ab7d230.pth' # noqa
5+
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
6+
model = dict(
7+
type='EncoderDecoder',
8+
backbone=dict(
9+
embed_dims=[64, 128, 320, 512],
10+
depths=[3, 3, 12, 3],
11+
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
12+
drop_path_rate=0.1,
13+
norm_cfg=dict(type='BN', requires_grad=True)),
14+
decode_head=dict(
15+
type='LightHamHead',
16+
in_channels=[128, 320, 512],
17+
in_index=[1, 2, 3],
18+
channels=512,
19+
ham_channels=512,
20+
dropout_ratio=0.1,
21+
num_classes=150,
22+
norm_cfg=ham_norm_cfg,
23+
align_corners=False,
24+
loss_decode=dict(
25+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
26+
# model training and testing settings
27+
train_cfg=dict(),
28+
test_cfg=dict(mode='whole'))
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
_base_ = './segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py'
2+
# model settings
3+
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_l_20230227-cef260d4.pth' # noqa
4+
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
5+
model = dict(
6+
type='EncoderDecoder',
7+
backbone=dict(
8+
embed_dims=[64, 128, 320, 512],
9+
depths=[3, 5, 27, 3],
10+
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
11+
drop_path_rate=0.3,
12+
norm_cfg=dict(type='BN', requires_grad=True)),
13+
decode_head=dict(
14+
type='LightHamHead',
15+
in_channels=[128, 320, 512],
16+
in_index=[1, 2, 3],
17+
channels=1024,
18+
ham_channels=1024,
19+
dropout_ratio=0.1,
20+
num_classes=150,
21+
norm_cfg=ham_norm_cfg,
22+
align_corners=False,
23+
loss_decode=dict(
24+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
25+
# model training and testing settings
26+
train_cfg=dict(),
27+
test_cfg=dict(mode='whole'))
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
_base_ = './segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py'
2+
# model settings
3+
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_s_20230227-f33ccdf2.pth' # noqa
4+
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
5+
model = dict(
6+
type='EncoderDecoder',
7+
backbone=dict(
8+
embed_dims=[64, 128, 320, 512],
9+
depths=[2, 2, 4, 2],
10+
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
11+
norm_cfg=dict(type='BN', requires_grad=True)),
12+
decode_head=dict(
13+
type='LightHamHead',
14+
in_channels=[128, 320, 512],
15+
in_index=[1, 2, 3],
16+
channels=256,
17+
ham_channels=256,
18+
ham_kwargs=dict(MD_R=16),
19+
dropout_ratio=0.1,
20+
num_classes=150,
21+
norm_cfg=ham_norm_cfg,
22+
align_corners=False,
23+
loss_decode=dict(
24+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
25+
# model training and testing settings
26+
train_cfg=dict(),
27+
test_cfg=dict(mode='whole'))
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
_base_ = [
2+
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py',
3+
'../_base_/datasets/ade20k.py'
4+
]
5+
# model settings
6+
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_t_20230227-119e8c9f.pth' # noqa
7+
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
8+
crop_size = (512, 512)
9+
data_preprocessor = dict(
10+
type='SegDataPreProcessor',
11+
mean=[123.675, 116.28, 103.53],
12+
std=[58.395, 57.12, 57.375],
13+
bgr_to_rgb=True,
14+
pad_val=0,
15+
seg_pad_val=255,
16+
size=(512, 512),
17+
test_cfg=dict(size_divisor=32))
18+
model = dict(
19+
type='EncoderDecoder',
20+
data_preprocessor=data_preprocessor,
21+
pretrained=None,
22+
backbone=dict(
23+
type='MSCAN',
24+
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
25+
embed_dims=[32, 64, 160, 256],
26+
mlp_ratios=[8, 8, 4, 4],
27+
drop_rate=0.0,
28+
drop_path_rate=0.1,
29+
depths=[3, 3, 5, 2],
30+
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
31+
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
32+
act_cfg=dict(type='GELU'),
33+
norm_cfg=dict(type='BN', requires_grad=True)),
34+
decode_head=dict(
35+
type='LightHamHead',
36+
in_channels=[64, 160, 256],
37+
in_index=[1, 2, 3],
38+
channels=256,
39+
ham_channels=256,
40+
dropout_ratio=0.1,
41+
num_classes=150,
42+
norm_cfg=ham_norm_cfg,
43+
align_corners=False,
44+
loss_decode=dict(
45+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
46+
ham_kwargs=dict(
47+
MD_S=1,
48+
MD_R=16,
49+
train_steps=6,
50+
eval_steps=7,
51+
inv_t=100,
52+
rand_init=True)),
53+
# model training and testing settings
54+
train_cfg=dict(),
55+
test_cfg=dict(mode='whole'))
56+
57+
# dataset settings
58+
train_dataloader = dict(batch_size=16)
59+
60+
# optimizer
61+
optim_wrapper = dict(
62+
_delete_=True,
63+
type='OptimWrapper',
64+
optimizer=dict(
65+
type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
66+
paramwise_cfg=dict(
67+
custom_keys={
68+
'pos_block': dict(decay_mult=0.),
69+
'norm': dict(decay_mult=0.),
70+
'head': dict(lr_mult=10.)
71+
}))
72+
73+
param_scheduler = [
74+
dict(
75+
type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
76+
dict(
77+
type='PolyLR',
78+
power=1.0,
79+
begin=1500,
80+
end=160000,
81+
eta_min=0.0,
82+
by_epoch=False,
83+
)
84+
]

mmseg/models/backbones/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .mit import MixVisionTransformer
1212
from .mobilenet_v2 import MobileNetV2
1313
from .mobilenet_v3 import MobileNetV3
14+
from .mscan import MSCAN
1415
from .pidnet import PIDNet
1516
from .resnest import ResNeSt
1617
from .resnet import ResNet, ResNetV1c, ResNetV1d
@@ -27,5 +28,5 @@
2728
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
2829
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
2930
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
30-
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet'
31+
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN'
3132
]

0 commit comments

Comments
 (0)