Skip to content

Commit b8f42c7

Browse files
authored
Add Semantic FPN (open-mmlab#94)
* Add Semantic FPN * remove HRFPN
1 parent 597b8a6 commit b8f42c7

File tree

14 files changed

+388
-37
lines changed

14 files changed

+388
-37
lines changed

configs/_base_/models/fpn_r50.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# model settings
2+
norm_cfg = dict(type='SyncBN', requires_grad=True)
3+
model = dict(
4+
type='EncoderDecoder',
5+
pretrained='open-mmlab://resnet50_v1c',
6+
backbone=dict(
7+
type='ResNetV1c',
8+
depth=50,
9+
num_stages=4,
10+
out_indices=(0, 1, 2, 3),
11+
dilations=(1, 1, 1, 1),
12+
strides=(1, 2, 2, 2),
13+
norm_cfg=norm_cfg,
14+
norm_eval=False,
15+
style='pytorch',
16+
contract_dilation=True),
17+
neck=dict(
18+
type='FPN',
19+
in_channels=[256, 512, 1024, 2048],
20+
out_channels=256,
21+
num_outs=4),
22+
decode_head=dict(
23+
type='FPNHead',
24+
in_channels=[256, 256, 256, 256],
25+
in_index=[0, 1, 2, 3],
26+
feature_strides=[4, 8, 16, 32],
27+
channels=128,
28+
dropout_ratio=0.1,
29+
num_classes=19,
30+
norm_cfg=norm_cfg,
31+
align_corners=False,
32+
loss_decode=dict(
33+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
34+
# model training and testing settings
35+
train_cfg = dict()
36+
test_cfg = dict(mode='whole')

configs/sem_fpn/README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Panoptic Feature Pyramid Networks
2+
3+
## Introduction
4+
```
5+
@article{Kirillov_2019,
6+
title={Panoptic Feature Pyramid Networks},
7+
ISBN={9781728132938},
8+
url={http://dx.doi.org/10.1109/CVPR.2019.00656},
9+
DOI={10.1109/cvpr.2019.00656},
10+
journal={2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
11+
publisher={IEEE},
12+
author={Kirillov, Alexander and Girshick, Ross and He, Kaiming and Dollar, Piotr},
13+
year={2019},
14+
month={Jun}
15+
}
16+
```
17+
18+
## Results and models
19+
20+
### Cityscapes
21+
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
22+
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
23+
| FPN | R-50 | 512x1024 | 80000 | 2.8 | 13.54 | 74.52 | 76.08 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x1024_80k_cityscapes/fpn_r50_512x1024_80k_cityscapes_20200717_021437-94018a0d.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x1024_80k_cityscapes/fpn_r50_512x1024_80k_cityscapes-20200717_021437.log.json) |
24+
| FPN | R-101 | 512x1024 | 80000 | 3.9 | 10.29 | 75.80 | 77.40 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x1024_80k_cityscapes/fpn_r101_512x1024_80k_cityscapes_20200717_012416-c5800d4c.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x1024_80k_cityscapes/fpn_r101_512x1024_80k_cityscapes-20200717_012416.log.json) |
25+
26+
### ADE20K
27+
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
28+
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
29+
| FPN | R-50 | 512x512 | 160000 | 4.9 | 55.77 | 37.49 | 39.09 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x512_160k_ade20k/fpn_r50_512x512_160k_ade20k_20200718_131734-5b5a6ab9.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x512_160k_ade20k/fpn_r50_512x512_160k_ade20k-20200718_131734.log.json) |
30+
| FPN | R-101 | 512x512 | 160000 | 5.9 | 40.58 | 39.35 | 40.72 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x512_160k_ade20k/fpn_r101_512x512_160k_ade20k_20200718_131734-306b5004.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x512_160k_ade20k/fpn_r101_512x512_160k_ade20k-20200718_131734.log.json) |
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
_base_ = './fpn_r50_512x1024_80k_cityscapes.py'
2+
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
_base_ = './fpn_r50_512x512_160k_ade20k.py'
2+
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
_base_ = [
2+
'../_base_/models/fpn_r50.py', '../_base_/datasets/cityscapes.py',
3+
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
4+
]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
_base_ = [
2+
'../_base_/models/fpn_r50.py', '../_base_/datasets/ade20k.py',
3+
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
4+
]
5+
model = dict(decode_head=dict(num_classes=150))

mmseg/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
build_head, build_loss, build_segmentor)
44
from .decode_heads import * # noqa: F401,F403
55
from .losses import * # noqa: F401,F403
6+
from .necks import * # noqa: F401,F403
67
from .segmentors import * # noqa: F401,F403
78

89
__all__ = [

mmseg/models/decode_heads/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .da_head import DAHead
55
from .enc_head import EncHead
66
from .fcn_head import FCNHead
7+
from .fpn_head import FPNHead
78
from .gc_head import GCHead
89
from .nl_head import NLHead
910
from .ocr_head import OCRHead
@@ -16,5 +17,5 @@
1617
__all__ = [
1718
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
1819
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
19-
'EncHead', 'DepthwiseSeparableFCNHead'
20+
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead'
2021
]
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import numpy as np
2+
import torch.nn as nn
3+
from mmcv.cnn import ConvModule
4+
5+
from mmseg.ops import resize
6+
from ..builder import HEADS
7+
from .decode_head import BaseDecodeHead
8+
9+
10+
@HEADS.register_module()
11+
class FPNHead(BaseDecodeHead):
12+
"""Panoptic Feature Pyramid Networks.
13+
14+
This head is the implementation of `Semantic FPN
15+
<https://arxiv.org/abs/1901.02446>`_.
16+
17+
Args:
18+
feature_strides (tuple[int]): The strides for input feature maps.
19+
stack_lateral. All strides suppose to be power of 2. The first
20+
one is of largest resolution.
21+
"""
22+
23+
def __init__(self, feature_strides, **kwargs):
24+
super(FPNHead, self).__init__(
25+
input_transform='multiple_select', **kwargs)
26+
assert len(feature_strides) == len(self.in_channels)
27+
assert min(feature_strides) == feature_strides[0]
28+
self.feature_strides = feature_strides
29+
30+
self.scale_heads = nn.ModuleList()
31+
for i in range(len(feature_strides)):
32+
head_length = max(
33+
1,
34+
int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
35+
scale_head = []
36+
for k in range(head_length):
37+
scale_head.append(
38+
ConvModule(
39+
self.in_channels[i] if k == 0 else self.channels,
40+
self.channels,
41+
3,
42+
padding=1,
43+
conv_cfg=self.conv_cfg,
44+
norm_cfg=self.norm_cfg,
45+
act_cfg=self.act_cfg))
46+
if feature_strides[i] != feature_strides[0]:
47+
scale_head.append(
48+
nn.Upsample(
49+
scale_factor=2,
50+
mode='bilinear',
51+
align_corners=self.align_corners))
52+
self.scale_heads.append(nn.Sequential(*scale_head))
53+
54+
def forward(self, inputs):
55+
56+
x = self._transform_inputs(inputs)
57+
58+
output = self.scale_heads[0](x[0])
59+
for i in range(1, len(self.feature_strides)):
60+
# non inplace
61+
output = output + resize(
62+
self.scale_heads[i](x[i]),
63+
size=output.shape[2:],
64+
mode='bilinear',
65+
align_corners=self.align_corners)
66+
67+
output = self.cls_seg(output)
68+
return output

mmseg/models/necks/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .fpn import FPN
2+
3+
__all__ = ['FPN']

0 commit comments

Comments
 (0)