Commit dd47cef

[Feature] Support PIDNet (open-mmlab#2609)
## Motivation

Support SOTA real-time semantic segmentation method in [Paper with code](https://paperswithcode.com/task/real-time-semantic-segmentation)

Paper: https://arxiv.org/pdf/2206.02066.pdf
Official repo: https://github.com/XuJiacong/PIDNet

## Current results

**Cityscapes**

|Model|Ref mIoU|mIoU (ours)|
|---|---|---|
|PIDNet-S|78.8|78.74|
|PIDNet-M|79.9|80.22|
|PIDNet-L|80.9|80.89|

## TODO

- [x] Support inference with official weights
- [x] Support training on Cityscapes
- [x] Update docstring
- [x] Add unit test
1 parent 8c89ff3 commit dd47cef

20 files changed: +1646 -4 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
 - [x] [K-Net (NeurIPS'2021)](configs/knet)
 - [x] [MaskFormer (NeurIPS'2021)](configs/maskformer)
 - [x] [Mask2Former (CVPR'2022)](configs/mask2former)
+- [x] [PIDNet (ArXiv'2022)](configs/pidnet)

 </details>

README_zh-CN.md

Lines changed: 1 addition & 0 deletions
@@ -140,6 +140,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
 - [x] [K-Net (NeurIPS'2021)](configs/knet)
 - [x] [MaskFormer (NeurIPS'2021)](configs/maskformer)
 - [x] [Mask2Former (CVPR'2022)](configs/mask2former)
+- [x] [PIDNet (ArXiv'2022)](configs/pidnet)

 </details>

configs/pidnet/README.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# PIDNet
2+
3+
> [PIDNet: A Real-time Semantic Segmentation Network Inspired from PID Controller](https://arxiv.org/pdf/2206.02066.pdf)
4+
5+
## Introduction
6+
7+
<!-- [ALGORITHM] -->
8+
9+
<a href="https://github.com/XuJiacong/PIDNet">Official Repo</a>
10+
11+
<a href="https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/pidnet.py">Code Snippet</a>
12+
13+
## Abstract
14+
15+
<!-- [ABSTRACT] -->
16+
17+
Two-branch network architecture has shown its efficiency and effectiveness for real-time semantic segmentation tasks. However, direct fusion of low-level details and high-level semantics will lead to a phenomenon that the detailed features are easily overwhelmed by surrounding contextual information, namely overshoot in this paper, which limits the improvement of the accuracy of existed two-branch models. In this paper, we bridge a connection between Convolutional Neural Network (CNN) and Proportional-IntegralDerivative (PID) controller and reveal that the two-branch network is nothing but a Proportional-Integral (PI) controller, which inherently suffers from the similar overshoot issue. To alleviate this issue, we propose a novel threebranch network architecture: PIDNet, which possesses three branches to parse the detailed, context and boundary information (derivative of semantics), respectively, and employs boundary attention to guide the fusion of detailed and context branches in final stage. The family of PIDNets achieve the best trade-off between inference speed and accuracy and their test accuracy surpasses all the existed models with similar inference speed on Cityscapes, CamVid and COCO-Stuff datasets. Especially, PIDNet-S achieves 78.6% mIOU with inference speed of 93.2 FPS on Cityscapes test set and 80.1% mIOU with speed of 153.7 FPS on CamVid test set.
18+
19+
<!-- [IMAGE] -->
20+
21+
<div align=center>
22+
<img src="https://raw.githubusercontent.com/XuJiacong/PIDNet/main/figs/pidnet.jpg" width="800"/>
23+
</div>
24+
25+
## Results and models
26+
27+
### Cityscapes
28+
29+
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
30+
| ------ | -------- | --------- | ------- | -------- | -------------- | ----- | ------------- | ----------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
31+
| PIDNet | PIDNet-S | 1024x1024 | 120000 | 3.38 | 80.82 | 78.74 | 80.87 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes/pidnet-s_2xb6-120k_1024x1024-cityscapes_20230302_191700-bb8e3bcc.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes/pidnet-s_2xb6-120k_1024x1024-cityscapes_20230302_191700.json) |
32+
| PIDNet | PIDNet-M | 1024x1024 | 120000 | 5.14 | 71.98 | 80.22 | 82.05 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes/pidnet-m_2xb6-120k_1024x1024-cityscapes_20230301_143452-f9bcdbf3.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes/pidnet-m_2xb6-120k_1024x1024-cityscapes_20230301_143452.json) |
33+
| PIDNet | PIDNet-L | 1024x1024 | 120000 | 5.83 | 60.06 | 80.89 | 82.37 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes/pidnet-l_2xb6-120k_1024x1024-cityscapes_20230303_114514-0783ca6b.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes/pidnet-l_2xb6-120k_1024x1024-cityscapes_20230303_114514.json) |
34+
35+
## Notes
36+
37+
The pretrained weights in config files are converted from [the official repo](https://github.com/XuJiacong/PIDNet#models).
38+
39+
## Citation
40+
41+
```bibtex
42+
@misc{xu2022pidnet,
43+
title={PIDNet: A Real-time Semantic Segmentation Network Inspired from PID Controller},
44+
author={Jiacong Xu and Zixiang Xiong and Shankar P. Bhattacharyya},
45+
year={2022},
46+
eprint={2206.02066},
47+
archivePrefix={arXiv},
48+
primaryClass={cs.CV}
49+
}
50+
```

configs/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
_base_ = './pidnet-s_2xb6-120k_1024x1024-cityscapes.py'
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-l_imagenet1k_20230306-67889109.pth'  # noqa
model = dict(
    backbone=dict(
        channels=64,
        ppm_channels=112,
        num_stem_blocks=3,
        num_branch_blocks=4,
        init_cfg=dict(checkpoint=checkpoint_file)),
    decode_head=dict(in_channels=256, channels=256))

configs/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
_base_ = './pidnet-s_2xb6-120k_1024x1024-cityscapes.py'
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-m_imagenet1k_20230306-39893c52.pth'  # noqa
model = dict(
    backbone=dict(channels=64, init_cfg=dict(checkpoint=checkpoint_file)),
    decode_head=dict(in_channels=256))
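
The two child configs above list only the keys that differ from the PIDNet-S base file named in `_base_`. As a rough illustration of how such overrides could resolve, here is a minimal sketch of a recursive dictionary merge. The `merge` helper is hypothetical and is not MMEngine's implementation, which also handles features such as `_delete_`. The base values are copied from the PIDNet-S config in this commit, with `init_cfg` and other keys left out for brevity.

```python
from copy import deepcopy


def merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a copy of `base`; child keys win."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge(merged[key], value)
        else:
            merged[key] = value
    return merged


# Subset of the PIDNet-S base config from this commit.
base_model = dict(
    backbone=dict(type='PIDNet', channels=32, ppm_channels=96,
                  num_stem_blocks=2, num_branch_blocks=3),
    decode_head=dict(type='PIDHead', in_channels=128, channels=128))

# Overrides taken from the PIDNet-M and PIDNet-L child configs
# (init_cfg omitted in this sketch).
pidnet_m = merge(base_model, dict(
    backbone=dict(channels=64),
    decode_head=dict(in_channels=256)))
pidnet_l = merge(base_model, dict(
    backbone=dict(channels=64, ppm_channels=112,
                  num_stem_blocks=3, num_branch_blocks=4),
    decode_head=dict(in_channels=256, channels=256)))

print(pidnet_m['decode_head'])
print(pidnet_l['backbone'])
```

Under this reading, PIDNet-M keeps the base `ppm_channels=96` and decode-head `channels=128`, while PIDNet-L overrides both.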

configs/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
_base_ = [
    '../_base_/datasets/cityscapes_1024x1024.py',
    '../_base_/default_runtime.py'
]

# The class_weight is borrowed from https://github.com/openseg-group/OCNet.pytorch/issues/14 # noqa
# Licensed under the MIT License
class_weight = [
    0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786,
    1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529,
    1.0507
]
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-s_imagenet1k_20230306-715e6273.pth'  # noqa
crop_size = (1024, 1024)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    size=crop_size)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='PIDNet',
        in_channels=3,
        channels=32,
        ppm_channels=96,
        num_stem_blocks=2,
        num_branch_blocks=3,
        align_corners=False,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU', inplace=True),
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
    decode_head=dict(
        type='PIDHead',
        in_channels=128,
        channels=128,
        num_classes=19,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU', inplace=True),
        align_corners=True,
        loss_decode=[
            dict(
                type='CrossEntropyLoss',
                use_sigmoid=False,
                class_weight=class_weight,
                loss_weight=0.4),
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=1.0),
            dict(type='BoundaryLoss', loss_weight=20.0),
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=1.0)
        ]),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='GenerateEdge', edge_width=4),
    dict(type='PackSegInputs')
]
train_dataloader = dict(batch_size=6, dataset=dict(pipeline=train_pipeline))

iters = 120000
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0,
        power=0.9,
        begin=0,
        end=iters,
        by_epoch=False)
]
# training schedule for 120k
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=iters, val_interval=iters // 10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook', by_epoch=False, interval=iters // 10),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

randomness = dict(seed=304)
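
The learning policy in this config pairs SGD at `lr=0.01` with a `PolyLR` scheduler (`power=0.9`, `eta_min=0`, `end=iters`). The sketch below evaluates the usual closed form of polynomial decay with those values to show roughly how the rate falls over 120k iterations. The `poly_lr` function is a hypothetical helper written for illustration; MMEngine's `PolyLR` may compute the schedule step by step rather than through this formula.

```python
def poly_lr(step: int,
            base_lr: float = 0.01,
            eta_min: float = 0.0,
            power: float = 0.9,
            total: int = 120000) -> float:
    """Closed-form polynomial decay from base_lr down to eta_min."""
    progress = min(step, total) / total
    return (base_lr - eta_min) * (1.0 - progress) ** power + eta_min


iters = 120000
# Validation and checkpointing both run every `iters // 10` iterations.
interval = iters // 10
for step in range(0, iters + 1, interval):
    print(f'iter {step:>6}: lr = {poly_lr(step):.6f}')
```

With `power` below 1, the decay is slightly slower than linear at the start and steeper near the end of training.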

configs/pidnet/pidnet.yml

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
Collections:
2+
- Name: PIDNet
3+
Metadata:
4+
Training Data:
5+
- Cityscapes
6+
Paper:
7+
URL: https://arxiv.org/pdf/2206.02066.pdf
8+
Title: 'PIDNet: A Real-time Semantic Segmentation Network Inspired from PID Controller'
9+
README: configs/pidnet/README.md
10+
Code:
11+
URL: https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/pidnet.py
12+
Version: dev-1.x
13+
Converted From:
14+
Code: https://github.com/XuJiacong/PIDNet
15+
Models:
16+
- Name: pidnet-s_2xb6-120k_1024x1024-cityscapes
17+
In Collection: PIDNet
18+
Metadata:
19+
backbone: PIDNet-S
20+
crop size: (1024,1024)
21+
lr schd: 120000
22+
inference time (ms/im):
23+
- value: 12.37
24+
hardware: V100
25+
backend: PyTorch
26+
batch size: 1
27+
mode: FP32
28+
resolution: (1024,1024)
29+
Training Memory (GB): 3.38
30+
Results:
31+
- Task: Semantic Segmentation
32+
Dataset: Cityscapes
33+
Metrics:
34+
mIoU: 78.74
35+
mIoU(ms+flip): 80.87
36+
Config: configs/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes.py
37+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes/pidnet-s_2xb6-120k_1024x1024-cityscapes_20230302_191700-bb8e3bcc.pth
38+
- Name: pidnet-m_2xb6-120k_1024x1024-cityscapes
39+
In Collection: PIDNet
40+
Metadata:
41+
backbone: PIDNet-M
42+
crop size: (1024,1024)
43+
lr schd: 120000
44+
inference time (ms/im):
45+
- value: 13.89
46+
hardware: V100
47+
backend: PyTorch
48+
batch size: 1
49+
mode: FP32
50+
resolution: (1024,1024)
51+
Training Memory (GB): 5.14
52+
Results:
53+
- Task: Semantic Segmentation
54+
Dataset: Cityscapes
55+
Metrics:
56+
mIoU: 80.22
57+
mIoU(ms+flip): 82.05
58+
Config: configs/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes.py
59+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes/pidnet-m_2xb6-120k_1024x1024-cityscapes_20230301_143452-f9bcdbf3.pth
60+
- Name: pidnet-l_2xb6-120k_1024x1024-cityscapes
61+
In Collection: PIDNet
62+
Metadata:
63+
backbone: PIDNet-L
64+
crop size: (1024,1024)
65+
lr schd: 120000
66+
inference time (ms/im):
67+
- value: 16.65
68+
hardware: V100
69+
backend: PyTorch
70+
batch size: 1
71+
mode: FP32
72+
resolution: (1024,1024)
73+
Training Memory (GB): 5.83
74+
Results:
75+
- Task: Semantic Segmentation
76+
Dataset: Cityscapes
77+
Metrics:
78+
mIoU: 80.89
79+
mIoU(ms+flip): 82.37
80+
Config: configs/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes.py
81+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes/pidnet-l_2xb6-120k_1024x1024-cityscapes_20230303_114514-0783ca6b.pth
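
The metafile above reports latency in ms per image, while the README table in this commit reports `Inf time (fps)`. Assuming the two are simple reciprocals measured at batch size 1, the short check below converts one to the other and compares against the README values. The `ms_to_fps` helper and the 0.05 tolerance are assumptions made for this sketch; the numbers themselves are copied from the two files.

```python
# Latency values (ms/im) from configs/pidnet/pidnet.yml.
latency_ms = {'PIDNet-S': 12.37, 'PIDNet-M': 13.89, 'PIDNet-L': 16.65}
# Throughput values (fps) from the README results table.
readme_fps = {'PIDNet-S': 80.82, 'PIDNet-M': 71.98, 'PIDNet-L': 60.06}


def ms_to_fps(ms_per_image: float) -> float:
    """Convert per-image latency in milliseconds to frames per second."""
    return 1000.0 / ms_per_image


for name, ms in latency_ms.items():
    fps = ms_to_fps(ms)
    # Both files round to two decimals, so allow a small tolerance.
    consistent = abs(fps - readme_fps[name]) < 0.05
    print(f'{name}: {ms} ms/im -> {fps:.2f} fps '
          f'(README: {readme_fps[name]}, consistent: {consistent})')
```

The small residual differences are what one would expect if both figures were rounded independently from the same measurement.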

mmseg/models/backbones/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
 from .mit import MixVisionTransformer
 from .mobilenet_v2 import MobileNetV2
 from .mobilenet_v3 import MobileNetV3
+from .pidnet import PIDNet
 from .resnest import ResNeSt
 from .resnet import ResNet, ResNetV1c, ResNetV1d
 from .resnext import ResNeXt
@@ -26,5 +27,5 @@
     'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
     'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
     'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
-    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE'
+    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet'
 ]
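
The configs in this commit refer to the backbone only by the string `type='PIDNet'`, so the import added to `__init__.py` is presumably what makes that name resolvable. `pidnet.py` itself is not shown in this excerpt, so the following is a toy sketch of the registry pattern under the assumption that the class registers itself with a decorator when its module is imported. The `Registry` class and the stand-in `PIDNet` below are hypothetical and far simpler than the real MMEngine registry or the real backbone.

```python
class Registry:
    """Toy name-to-class registry (illustrative only)."""

    def __init__(self, name: str):
        self.name = name
        self._modules = {}

    def register_module(self):
        """Return a decorator that records the class under its own name."""
        def decorator(cls):
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg: dict):
        """Instantiate the class named by cfg['type'] with the other keys."""
        kwargs = dict(cfg)  # copy so the caller's config is not mutated
        cls = self._modules[kwargs.pop('type')]
        return cls(**kwargs)


MODELS = Registry('model')


# Registration happens as a side effect of executing the class definition,
# which is why the defining module has to be imported somewhere.
@MODELS.register_module()
class PIDNet:
    """Stand-in for the real backbone; it only stores two arguments."""

    def __init__(self, in_channels: int = 3, channels: int = 32):
        self.in_channels = in_channels
        self.channels = channels


backbone = MODELS.build(dict(type='PIDNet', in_channels=3, channels=32))
print(type(backbone).__name__, backbone.channels)
```

If the import were missing in a design like this, the lookup of `'PIDNet'` would fail with a `KeyError` at build time.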
