Skip to content

Commit ff1e2d8

Browse files
authored
[Feature] Support Real-time model ERFNet (open-mmlab#960)
* first commit * Fixing Unittest Error * first refactory of ERFNet * Refactorying NonBottleneck1d Module * uploading models&logs * uploading models&logs * fix partial bugs & typos * ERFNet * add ERFNet with FCNHead * fix typos of ERFNet * add name on README.md cover * chane name to T-ITS'2017 * fix lint error
1 parent 313189d commit ff1e2d8

File tree

10 files changed

+607
-1
lines changed

10 files changed

+607
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ Supported backbones:
7070
Supported methods:
7171

7272
- [x] [FCN (CVPR'2015/TPAMI'2017)](configs/fcn)
73+
- [x] [ERFNet (T-ITS'2017)](configs/erfnet)
7374
- [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet)
7475
- [x] [PSPNet (CVPR'2017)](configs/pspnet)
7576
- [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3)

README_zh-CN.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
6969
已支持的算法:
7070

7171
- [x] [FCN (CVPR'2015/TPAMI'2017)](configs/fcn)
72+
- [x] [ERFNet (T-ITS'2017)](configs/erfnet)
7273
- [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet)
7374
- [x] [PSPNet (CVPR'2017)](configs/pspnet)
7475
- [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# model settings
2+
norm_cfg = dict(type='SyncBN', requires_grad=True)
3+
model = dict(
4+
type='EncoderDecoder',
5+
pretrained=None,
6+
backbone=dict(
7+
type='ERFNet',
8+
in_channels=3,
9+
enc_downsample_channels=(16, 64, 128),
10+
enc_stage_non_bottlenecks=(5, 8),
11+
enc_non_bottleneck_dilations=(2, 4, 8, 16),
12+
enc_non_bottleneck_channels=(64, 128),
13+
dec_upsample_channels=(64, 16),
14+
dec_stages_non_bottleneck=(2, 2),
15+
dec_non_bottleneck_channels=(64, 16),
16+
dropout_ratio=0.1,
17+
init_cfg=None),
18+
decode_head=dict(
19+
type='FCNHead',
20+
in_channels=16,
21+
channels=128,
22+
num_convs=1,
23+
concat_input=False,
24+
dropout_ratio=0.1,
25+
num_classes=19,
26+
norm_cfg=norm_cfg,
27+
align_corners=False,
28+
loss_decode=dict(
29+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
30+
# model training and testing settings
31+
train_cfg=dict(),
32+
test_cfg=dict(mode='whole'))

configs/erfnet/README.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# ERFNet: Efficient Residual Factorized ConvNet for Real-time Semantic Segmentation
2+
3+
## Introduction
4+
5+
<!-- [ALGORITHM] -->
6+
7+
<a href="https://github.com/Eromera/erfnet_pytorch">Official Repo</a>
8+
9+
<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.20.0/mmseg/models/backbones/erfnet.py#L321">Code Snippet</a>
10+
11+
## Abstract
12+
13+
Semantic segmentation is a challenging task that addresses most of the perception needs of intelligent vehicles (IVs) in an unified way. Deep neural networks excel at this task, as they can be trained end-to-end to accurately classify multiple object categories in an image at pixel level. However, a good tradeoff between high quality and computational resources is yet not present in the state-of-the-art semantic segmentation approaches, limiting their application in real vehicles. In this paper, we propose a deep architecture that is able to run in real time while providing accurate semantic segmentation. The core of our architecture is a novel layer that uses residual connections and factorized convolutions in order to remain efficient while retaining remarkable accuracy. Our approach is able to run at over 83 FPS in a single Titan X, and 7 FPS in a Jetson TX1 (embedded device). A comprehensive set of experiments on the publicly available Cityscapes data set demonstrates that our system achieves an accuracy that is similar to the state of the art, while being orders of magnitude faster to compute than other architectures that achieve top precision. The resulting tradeoff makes our model an ideal approach for scene understanding in IV applications. The code is publicly available at: https://github.com/Eromera/erfnet.
14+
15+
<!-- [IMAGE] -->
16+
<div align=center>
17+
<img src="https://user-images.githubusercontent.com/24582831/143479729-ea7951f6-1a3c-47d6-aaee-62c5759c0638.png" width="60%"/>
18+
</div>
19+
20+
<details>
21+
<summary align="right"><a href="http://www.robesafe.uah.es/personal/eduardo.romera/pdfs/Romera17tits.pdf">ERFNet (T-ITS'2017)</a></summary>
22+
23+
```latex
24+
@article{romera2017erfnet,
25+
title={Erfnet: Efficient residual factorized convnet for real-time semantic segmentation},
26+
author={Romera, Eduardo and Alvarez, Jos{\'e} M and Bergasa, Luis M and Arroyo, Roberto},
27+
journal={IEEE Transactions on Intelligent Transportation Systems},
28+
volume={19},
29+
number={1},
30+
pages={263--272},
31+
year={2017},
32+
publisher={IEEE}
33+
}
34+
```
35+
36+
</details>
37+
38+
## Results and models
39+
40+
### Cityscapes
41+
42+
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
43+
| --------- | --------- | --------- | ------: | -------- | -------------- | ----: | ------------- | --------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
44+
| FCN | ERFNet | 512x1024 | 160000 | 6.04 | 15.26 | 71.08 | 72.6 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes/erfnet_fcn_4x4_512x1024_160k_cityscapes_20211126_082056-03d333ed.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes/erfnet_fcn_4x4_512x1024_160k_cityscapes_20211126_082056.log.json) |
45+
46+
Note:
47+
48+
- The model is trained from scratch.
49+
50+
- Last deconvolution layer in the [original paper](https://github.com/Eromera/erfnet_pytorch/blob/master/train/erfnet.py#L123) is replaced by a naive `FCNHead` decoder head and a bilinear upsampling layer, found more effective and efficient.

configs/erfnet/erfnet.yml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
Collections:
2+
- Name: erfnet
3+
Metadata:
4+
Training Data:
5+
- Cityscapes
6+
Paper:
7+
URL: http://www.robesafe.uah.es/personal/eduardo.romera/pdfs/Romera17tits.pdf
8+
Title: 'ERFNet: Efficient Residual Factorized ConvNet for Real-time Semantic Segmentation'
9+
README: configs/erfnet/README.md
10+
Code:
11+
URL: https://github.com/open-mmlab/mmsegmentation/blob/v0.20.0/mmseg/models/backbones/erfnet.py#L321
12+
Version: v0.20.0
13+
Converted From:
14+
Code: https://github.com/Eromera/erfnet_pytorch
15+
Models:
16+
- Name: erfnet_fcn_4x4_512x1024_160k_cityscapes
17+
In Collection: erfnet
18+
Metadata:
19+
backbone: ERFNet
20+
crop size: (512,1024)
21+
lr schd: 160000
22+
inference time (ms/im):
23+
- value: 65.53
24+
hardware: V100
25+
backend: PyTorch
26+
batch size: 1
27+
mode: FP32
28+
resolution: (512,1024)
29+
Training Memory (GB): 6.04
30+
Results:
31+
- Task: Semantic Segmentation
32+
Dataset: Cityscapes
33+
Metrics:
34+
mIoU: 71.08
35+
mIoU(ms+flip): 72.6
36+
Config: configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py
37+
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes/erfnet_fcn_4x4_512x1024_160k_cityscapes_20211126_082056-03d333ed.pth
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
_base_ = [
2+
'../_base_/models/erfnet_fcn.py', '../_base_/datasets/cityscapes.py',
3+
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
4+
]
5+
data = dict(
6+
samples_per_gpu=4,
7+
workers_per_gpu=4,
8+
)

mmseg/models/backbones/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .bisenetv1 import BiSeNetV1
33
from .bisenetv2 import BiSeNetV2
44
from .cgnet import CGNet
5+
from .erfnet import ERFNet
56
from .fast_scnn import FastSCNN
67
from .hrnet import HRNet
78
from .icnet import ICNet
@@ -20,5 +21,5 @@
2021
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
2122
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
2223
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
23-
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone'
24+
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet'
2425
]

0 commit comments

Comments
 (0)