Commit 5783bc1

MengzhangLI and xiexinch authored
[Feature] Support STDC Network (new) (open-mmlab#995)
* refactor stdc code
* update key
* fix backbone inference
* remove comments
* fixing errors
* fixing version conflict
* fux typo
* use STDCHead
* upload models&logs
* adding model converters script and fix unittest
* fix error
* fix error
* fix error
* delete redundant keys in config
* fix errors in configs and unittest
* fix errors in configs and unittest
* fix errors in configs and unittest
* change Memory name
* refactor stdc2mmseg
* change name to STDC
* refactor stdc
* refactor stdc
* stdc refactor
* stdc refactor
* stdc refactor
* stdc refactor
* stdc refactor
* stdc refactor
* refactor stdc
* stdc refactor

Co-authored-by: xiexinch <[email protected]>
1 parent ff044c5 commit 5783bc1

17 files changed, +1018 -2 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ Supported methods:
 - [x] [PointRend (CVPR'2020)](configs/point_rend)
 - [x] [CGNet (TIP'2020)](configs/cgnet)
 - [x] [BiSeNetV2 (IJCV'2021)](configs/bisenetv2)
+- [x] [STDC (CVPR'2021)](configs/stdc)
 - [x] [SETR (CVPR'2021)](configs/setr)
 - [x] [DPT (ArXiv'2021)](configs/dpt)
 - [x] [SegFormer (NeurIPS'2021)](configs/segformer)

README_zh-CN.md

Lines changed: 1 addition & 0 deletions
@@ -97,6 +97,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
 - [x] [PointRend (CVPR'2020)](configs/point_rend)
 - [x] [CGNet (TIP'2020)](configs/cgnet)
 - [x] [BiSeNetV2 (IJCV'2021)](configs/bisenetv2)
+- [x] [STDC (CVPR'2021)](configs/stdc)
 - [x] [SETR (CVPR'2021)](configs/setr)
 - [x] [DPT (ArXiv'2021)](configs/dpt)
 - [x] [SegFormer (NeurIPS'2021)](configs/segformer)

configs/_base_/models/stdc.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
norm_cfg = dict(type='BN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='STDCContextPathNet',
        backbone_cfg=dict(
            type='STDCNet',
            stdc_type='STDCNet1',
            in_channels=3,
            channels=(32, 64, 256, 512, 1024),
            bottleneck_type='cat',
            num_convs=4,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'),
            with_final_conv=False),
        last_in_channels=(1024, 512),
        out_channels=128,
        ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)),
    decode_head=dict(
        type='FCNHead',
        in_channels=256,
        channels=256,
        num_convs=1,
        num_classes=19,
        in_index=3,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=norm_cfg,
        align_corners=True,
        sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=[
        dict(
            type='FCNHead',
            in_channels=128,
            channels=64,
            num_convs=1,
            num_classes=19,
            in_index=2,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
        dict(
            type='FCNHead',
            in_channels=128,
            channels=64,
            num_convs=1,
            num_classes=19,
            in_index=1,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
        dict(
            type='STDCHead',
            in_channels=256,
            channels=64,
            num_convs=1,
            num_classes=2,
            boundary_threshold=0.1,
            in_index=0,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            loss_decode=[
                dict(
                    type='CrossEntropyLoss',
                    loss_name='loss_ce',
                    use_sigmoid=True,
                    loss_weight=1.0),
                dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0)
            ]),
    ],
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
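To see at a glance how the heads in this config attach to the backbone outputs, the following sketch copies only the `type`, `in_index`, `in_channels`, and `num_classes` values from the config into plain dictionaries and lists them by output index. The `summarize_heads` helper is hypothetical and is not part of MMSegmentation.

```python
# Values copied from configs/_base_/models/stdc.py; all other keys are omitted.
decode_head = dict(type='FCNHead', in_index=3, in_channels=256, num_classes=19)
auxiliary_head = [
    dict(type='FCNHead', in_index=2, in_channels=128, num_classes=19),
    dict(type='FCNHead', in_index=1, in_channels=128, num_classes=19),
    dict(type='STDCHead', in_index=0, in_channels=256, num_classes=2),
]


def summarize_heads(decode, auxiliaries):
    """Return (role, type, in_index, in_channels, num_classes) sorted by in_index."""
    rows = [('decode', decode)] + [('auxiliary', h) for h in auxiliaries]
    rows.sort(key=lambda item: item[1]['in_index'])
    return [(role, h['type'], h['in_index'], h['in_channels'], h['num_classes'])
            for role, h in rows]


summary = summarize_heads(decode_head, auxiliary_head)
for row in summary:
    print(row)
```

Each head reads a different backbone output: index 3 feeds the main 19-class segmentation head, indices 1 and 2 feed 19-class auxiliary heads, and index 0 feeds the two-class `STDCHead`.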

configs/stdc/README.md

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
# Rethinking BiSeNet For Real-time Semantic Segmentation

## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/MichaelFan01/STDC-Seg">Official Repo</a>

<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.20.0/mmseg/models/backbones/stdc.py#L394">Code Snippet</a>

## Abstract

BiSeNet has been proved to be a popular two-stream network for real-time segmentation. However, its principle of adding an extra path to encode spatial information is time-consuming, and the backbones borrowed from pretrained tasks, e.g., image classification, may be inefficient for image segmentation due to the deficiency of task-specific design. To handle these problems, we propose a novel and efficient structure named Short-Term Dense Concatenate network (STDC network) by removing structure redundancy. Specifically, we gradually reduce the dimension of feature maps and use the aggregation of them for image representation, which forms the basic module of STDC network. In the decoder, we propose a Detail Aggregation module by integrating the learning of spatial information into low-level layers in single-stream manner. Finally, the low-level features and deep features are fused to predict the final segmentation results. Extensive experiments on Cityscapes and CamVid dataset demonstrate the effectiveness of our method by achieving promising trade-off between segmentation accuracy and inference speed. On Cityscapes, we achieve 71.9% mIoU on the test set with a speed of 250.4 FPS on NVIDIA GTX 1080Ti, which is 45.2% faster than the latest methods, and achieve 76.8% mIoU with 97.0 FPS while inferring on higher resolution images.

<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/24582831/143640374-d0709587-edb2-4821-bb60-340035f6ad8f.png" width="60%"/>
</div>

<details>
<summary align="right"><a href="https://arxiv.org/abs/2104.13188">STDC (CVPR'2021)</a></summary>

```latex
@inproceedings{fan2021rethinking,
  title={Rethinking BiSeNet For Real-time Semantic Segmentation},
  author={Fan, Mingyuan and Lai, Shenqi and Huang, Junshi and Wei, Xiaoming and Chai, Zhenhua and Luo, Junfeng and Wei, Xiaolin},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={9716--9725},
  year={2021}
}
```

</details>

## Usage

To use the [ImageNet-pretrained STDCNet weights](https://drive.google.com/drive/folders/1wROFwRt8qWHD4jSo8Zu1gp1d6oYJ3ns1) from the original repository, the checkpoint keys must first be converted.

We provide a script [`stdc2mmseg.py`](../../tools/model_converters/stdc2mmseg.py) in the tools directory to convert the keys of models from [the official repo](https://github.com/MichaelFan01/STDC-Seg) to MMSegmentation style.

```shell
python tools/model_converters/stdc2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH} ${STDC_TYPE}
```

E.g.

```shell
python tools/model_converters/stdc2mmseg.py ./STDCNet813M_73.91.tar ./pretrained/stdc1.pth STDC1

python tools/model_converters/stdc2mmseg.py ./STDCNet1446_76.47.tar ./pretrained/stdc2.pth STDC2
```

This script converts the model at `PRETRAIN_PATH` and stores the converted model at `STORE_PATH`.
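
Key conversion of this kind usually amounts to renaming entries in a checkpoint's state dict. The sketch below illustrates the idea on a plain dictionary; the `rename_keys` helper and the `features.` to `stages.` prefix mapping are hypothetical and are not taken from `stdc2mmseg.py`, whose actual renaming rules are not shown here.

```python
from collections import OrderedDict


def rename_keys(state_dict, prefix_map):
    """Return a copy of state_dict with the first matching key prefix replaced."""
    converted = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in prefix_map.items():
            if key.startswith(old_prefix):
                new_key = new_prefix + key[len(old_prefix):]
                break
        converted[new_key] = value
    return converted


# Hypothetical keys; real checkpoints hold tensors rather than strings.
source = OrderedDict([
    ('features.0.conv.weight', 'w0'),
    ('features.0.bn.bias', 'b0'),
    ('fc.weight', 'w_fc'),
])
converted = rename_keys(source, {'features.': 'stages.'})
print(list(converted))
```

Keys that match no prefix are copied unchanged, so the order and the values of the original entries are preserved.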

## Results and models

### Cityscapes

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| --- | --- | --- | ---: | --- | --- | ---: | --- | --- | --- |
| STDC1 (No Pretrain) | STDC1 | 512x1024 | 80000 | 7.15 | 23.06 | 71.52 | 73.35 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/stdc/stdc1_512x1024_80k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/stdc/stdc1_512x1024_80k_cityscapes/stdc1_512x1024_80k_cityscapes_20211125_211245-2c8ba4c5.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/stdc/stdc1_512x1024_80k_cityscapes/stdc1_512x1024_80k_cityscapes_20211125_211245.log.json) |
| STDC1 | STDC1 | 512x1024 | 80000 | - | - | 75.10 | 77.72 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/stdc/stdc1_in1k-pre_512x1024_80k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/stdc/stdc1_in1k-pre_512x1024_80k_cityscapes/stdc1_in1k-pre_512x1024_80k_cityscapes_20211125_213942-880bb7d0.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/stdc/stdc1_in1k-pre_512x1024_80k_cityscapes/stdc1_in1k-pre_512x1024_80k_cityscapes_20211125_213942.log.json) |
| STDC2 (No Pretrain) | STDC2 | 512x1024 | 80000 | 8.27 | 23.71 | 73.20 | 75.55 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/stdc/stdc2_512x1024_80k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/stdc/stdc2_512x1024_80k_cityscapes/stdc2_512x1024_80k_cityscapes_20211125_222450-82333ae0.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/stdc/stdc2_512x1024_80k_cityscapes/stdc2_512x1024_80k_cityscapes_20211125_222450.log.json) |
| STDC2 | STDC2 | 512x1024 | 80000 | - | - | 77.17 | 79.01 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/stdc/stdc2_in1k-pre_512x1024_80k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/stdc/stdc2_in1k-pre_512x1024_80k_cityscapes/stdc2_in1k-pre_512x1024_80k_cityscapes_20211125_220437-d2c469f8.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/stdc/stdc2_in1k-pre_512x1024_80k_cityscapes/stdc2_in1k-pre_512x1024_80k_cityscapes_20211125_220437.log.json) |

Note:

- For STDC on the Cityscapes dataset, the default setting is 4 GPUs with 12 samples per GPU in training.
- `No Pretrain` means the model is trained from scratch.
- The FPS is for reference only. The environment also differs from the paper's setting, which uses input sizes of `512x1024` and `768x1536` (50% and 75% of our input size, respectively) and TensorRT.
- The parameter `fusion_kernel` in `STDCHead` is not learnable. In the official repo, `find_unused_parameters=True` is set [here](https://github.com/MichaelFan01/STDC-Seg/blob/59ff37fbd693b99972c76fcefe97caa14aeb619f/train.py#L220). You can verify this by printing the model parameters of the original repo yourself.
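
The arithmetic behind the training setting and the speed column can be checked in a few lines. This sketch uses only numbers from the table and notes above; the helper name `fps_to_ms` is an assumption made for illustration.

```python
def fps_to_ms(fps):
    """Convert frames per second to milliseconds per image."""
    return 1000.0 / fps


# Default training setting from the notes: 4 GPUs with 12 samples per GPU.
total_batch_size = 4 * 12

# Inference speeds (fps) from the Cityscapes table.
stdc1_ms = round(fps_to_ms(23.06), 2)
stdc2_ms = round(fps_to_ms(23.71), 2)

print(total_batch_size, stdc1_ms, stdc2_ms)
```

The resulting per-image latencies agree with the `inference time (ms/im)` values recorded in `stdc.yml`.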

configs/stdc/stdc.yml

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
Collections:
- Name: stdc
  Metadata:
    Training Data:
    - Cityscapes
  Paper:
    URL: https://arxiv.org/abs/2104.13188
    Title: Rethinking BiSeNet For Real-time Semantic Segmentation
  README: configs/stdc/README.md
  Code:
    URL: https://github.com/open-mmlab/mmsegmentation/blob/v0.20.0/mmseg/models/backbones/stdc.py#L394
    Version: v0.20.0
  Converted From:
    Code: https://github.com/MichaelFan01/STDC-Seg
Models:
- Name: stdc1_512x1024_80k_cityscapes
  In Collection: stdc
  Metadata:
    backbone: STDC1
    crop size: (512,1024)
    lr schd: 80000
    inference time (ms/im):
    - value: 43.37
      hardware: V100
      backend: PyTorch
      batch size: 1
      mode: FP32
      resolution: (512,1024)
    Training Memory (GB): 7.15
  Results:
  - Task: Semantic Segmentation
    Dataset: Cityscapes
    Metrics:
      mIoU: 71.52
      mIoU(ms+flip): 73.35
  Config: configs/stdc/stdc1_512x1024_80k_cityscapes.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/stdc/stdc1_512x1024_80k_cityscapes/stdc1_512x1024_80k_cityscapes_20211125_211245-2c8ba4c5.pth
- Name: stdc1_in1k-pre_512x1024_80k_cityscapes
  In Collection: stdc
  Metadata:
    backbone: STDC1
    crop size: (512,1024)
    lr schd: 80000
  Results:
  - Task: Semantic Segmentation
    Dataset: Cityscapes
    Metrics:
      mIoU: 75.1
      mIoU(ms+flip): 77.72
  Config: configs/stdc/stdc1_in1k-pre_512x1024_80k_cityscapes.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/stdc/stdc1_in1k-pre_512x1024_80k_cityscapes/stdc1_in1k-pre_512x1024_80k_cityscapes_20211125_213942-880bb7d0.pth
- Name: stdc2_512x1024_80k_cityscapes
  In Collection: stdc
  Metadata:
    backbone: STDC2
    crop size: (512,1024)
    lr schd: 80000
    inference time (ms/im):
    - value: 42.18
      hardware: V100
      backend: PyTorch
      batch size: 1
      mode: FP32
      resolution: (512,1024)
    Training Memory (GB): 8.27
  Results:
  - Task: Semantic Segmentation
    Dataset: Cityscapes
    Metrics:
      mIoU: 73.2
      mIoU(ms+flip): 75.55
  Config: configs/stdc/stdc2_512x1024_80k_cityscapes.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/stdc/stdc2_512x1024_80k_cityscapes/stdc2_512x1024_80k_cityscapes_20211125_222450-82333ae0.pth
- Name: stdc2_in1k-pre_512x1024_80k_cityscapes
  In Collection: stdc
  Metadata:
    backbone: STDC2
    crop size: (512,1024)
    lr schd: 80000
  Results:
  - Task: Semantic Segmentation
    Dataset: Cityscapes
    Metrics:
      mIoU: 77.17
      mIoU(ms+flip): 79.01
  Config: configs/stdc/stdc2_in1k-pre_512x1024_80k_cityscapes.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/stdc/stdc2_in1k-pre_512x1024_80k_cityscapes/stdc2_in1k-pre_512x1024_80k_cityscapes_20211125_220437-d2c469f8.pth
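
As a quick reading aid, the metrics recorded in this metafile can be compared to see what ImageNet pretraining contributes. The sketch copies the single-scale mIoU values of the four model entries into a dictionary; the dictionary layout is only for illustration and is not how the metafile is consumed.

```python
# Single-scale mIoU values copied from the four Models entries in stdc.yml.
miou = {
    'stdc1_512x1024_80k_cityscapes': 71.52,
    'stdc1_in1k-pre_512x1024_80k_cityscapes': 75.1,
    'stdc2_512x1024_80k_cityscapes': 73.2,
    'stdc2_in1k-pre_512x1024_80k_cityscapes': 77.17,
}

gains = {}
for backbone in ('stdc1', 'stdc2'):
    scratch = miou[f'{backbone}_512x1024_80k_cityscapes']
    pretrained = miou[f'{backbone}_in1k-pre_512x1024_80k_cityscapes']
    # Round to two decimals to hide floating-point noise in the subtraction.
    gains[backbone] = round(pretrained - scratch, 2)

print(gains)
```

On these numbers, pretraining adds roughly 3.6 mIoU for STDC1 and about 4.0 mIoU for STDC2.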
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
_base_ = [
    '../_base_/models/stdc.py', '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
lr_config = dict(warmup='linear', warmup_iters=1000)
data = dict(
    samples_per_gpu=12,
    workers_per_gpu=4,
)
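
The `lr_config` in this file requests a linear warmup over the first 1000 iterations. A minimal sketch of such a ramp follows, assuming a base learning rate of 0.01 and a starting ratio of 0.1; neither value appears in this diff, and the function is an illustration rather than the scheduler the training framework actually runs.

```python
def linear_warmup_lr(base_lr, cur_iter, warmup_iters=1000, warmup_ratio=0.1):
    """Ramp linearly from base_lr * warmup_ratio at iteration 0 to base_lr."""
    if cur_iter >= warmup_iters:
        return base_lr
    # k shrinks from (1 - warmup_ratio) to 0 as cur_iter approaches warmup_iters.
    k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
    return base_lr * (1 - k)


# base_lr=0.01 is an assumed value for illustration only.
for it in (0, 500, 1000):
    print(it, linear_warmup_lr(0.01, it))
```

Under these assumptions the rate starts at one tenth of the base value, reaches 55% of it halfway through, and equals the base value from iteration 1000 onward.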
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
_base_ = './stdc1_512x1024_80k_cityscapes.py'
model = dict(
    backbone=dict(
        backbone_cfg=dict(
            init_cfg=dict(
                type='Pretrained', checkpoint='./pretrained/stdc1.pth'))))
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
_base_ = './stdc1_512x1024_80k_cityscapes.py'
model = dict(backbone=dict(backbone_cfg=dict(stdc_type='STDCNet2')))
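
This two-line config changes a single nested key while inheriting everything else from its base. The effect can be pictured as a recursive dictionary merge; `merge_cfg` below is a simplified, hypothetical stand-in and not the config loader's actual code.

```python
import copy


def merge_cfg(base, override):
    """Recursively merge override into a deep copy of base; override wins."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_cfg(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Trimmed-down base: only a few keys from the STDC1 model config.
base = dict(model=dict(backbone=dict(
    type='STDCContextPathNet',
    backbone_cfg=dict(type='STDCNet', stdc_type='STDCNet1', num_convs=4))))
override = dict(
    model=dict(backbone=dict(backbone_cfg=dict(stdc_type='STDCNet2'))))

stdc2 = merge_cfg(base, override)
print(stdc2['model']['backbone']['backbone_cfg'])
```

Only `stdc_type` changes in the merged result; sibling keys such as `num_convs` survive, and the base dictionary itself is left untouched.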
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
_base_ = './stdc2_512x1024_80k_cityscapes.py'
model = dict(
    backbone=dict(
        backbone_cfg=dict(
            init_cfg=dict(
                type='Pretrained', checkpoint='./pretrained/stdc2.pth'))))

mmseg/models/backbones/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -12,6 +12,7 @@
 from .resnest import ResNeSt
 from .resnet import ResNet, ResNetV1c, ResNetV1d
 from .resnext import ResNeXt
+from .stdc import STDCContextPathNet, STDCNet
 from .swin import SwinTransformer
 from .timm_backbone import TIMMBackbone
 from .twins import PCPVT, SVT
@@ -22,5 +23,6 @@
     'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
     'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
     'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
-    'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT', 'SVT'
+    'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
+    'SVT', 'STDCNet', 'STDCContextPathNet'
 ]
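
Importing the new classes in this package matters because the configs in this commit refer to them by string, for example `type='STDCNet'`. The sketch below shows the general name-to-class lookup pattern with a hypothetical `Registry` class; it does not reproduce MMSegmentation's own registry, and the placeholder `STDCNet` class only stands in for the real backbone.

```python
class Registry:
    """Minimal name-to-class registry (hypothetical, for illustration only)."""

    def __init__(self):
        self._modules = {}

    def register(self, cls):
        # Usable as a decorator: stores the class under its own name.
        self._modules[cls.__name__] = cls
        return cls

    def build(self, cfg):
        # Pop 'type' from a copy and pass the remaining keys to the constructor.
        args = dict(cfg)
        cls = self._modules[args.pop('type')]
        return cls(**args)


BACKBONES = Registry()


@BACKBONES.register
class STDCNet:  # placeholder, not the real backbone
    def __init__(self, stdc_type, in_channels=3):
        self.stdc_type = stdc_type
        self.in_channels = in_channels


net = BACKBONES.build(dict(type='STDCNet', stdc_type='STDCNet1', in_channels=3))
print(type(net).__name__, net.stdc_type)
```

In this pattern a class that is never imported is never registered, so a config naming it would fail at build time with a missing-key error.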
