Skip to content

Commit dbca8b4

Browse files
XiaLiPKUxvjiarui
andauthored
[Feature] Support EMANet (open-mmlab#34)
* add emanet * fixed bug and typos * add emanet config * fixed padding * fixed identity * rename * rename * add concat_input * fallback to update last * Fixed concat * update EMANet * Add tests * remove self-implement norm Co-authored-by: Jiarui XU <[email protected]>
1 parent 3c6dd9e commit dbca8b4

File tree

10 files changed

+282
-4
lines changed

10 files changed

+282
-4
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# model settings
2+
norm_cfg = dict(type='SyncBN', requires_grad=True)
3+
model = dict(
4+
type='EncoderDecoder',
5+
pretrained='open-mmlab://resnet50_v1c',
6+
backbone=dict(
7+
type='ResNetV1c',
8+
depth=50,
9+
num_stages=4,
10+
out_indices=(0, 1, 2, 3),
11+
dilations=(1, 1, 2, 4),
12+
strides=(1, 2, 1, 1),
13+
norm_cfg=norm_cfg,
14+
norm_eval=False,
15+
style='pytorch',
16+
contract_dilation=True),
17+
decode_head=dict(
18+
type='EMAHead',
19+
in_channels=2048,
20+
in_index=3,
21+
channels=256,
22+
ema_channels=512,
23+
num_bases=64,
24+
num_stages=3,
25+
momentum=0.1,
26+
dropout_ratio=0.1,
27+
num_classes=19,
28+
norm_cfg=norm_cfg,
29+
align_corners=False,
30+
loss_decode=dict(
31+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
32+
auxiliary_head=dict(
33+
type='FCNHead',
34+
in_channels=1024,
35+
in_index=2,
36+
channels=256,
37+
num_convs=1,
38+
concat_input=False,
39+
dropout_ratio=0.1,
40+
num_classes=19,
41+
norm_cfg=norm_cfg,
42+
align_corners=False,
43+
loss_decode=dict(
44+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)))
45+
# model training and testing settings
46+
train_cfg = dict()
47+
test_cfg = dict(mode='whole')

configs/emanet/README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Expectation-Maximization Attention Networks for Semantic Segmentation
2+
3+
## Introduction
4+
```
5+
@inproceedings{li2019expectation,
6+
title={Expectation-maximization attention networks for semantic segmentation},
7+
author={Li, Xia and Zhong, Zhisheng and Wu, Jianlong and Yang, Yibo and Lin, Zhouchen and Liu, Hong},
8+
booktitle={Proceedings of the IEEE International Conference on Computer Vision},
9+
pages={9167--9176},
10+
year={2019}
11+
}
12+
```
13+
14+
## Results and models
15+
16+
### Cityscapes
17+
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
18+
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
19+
| EMANet | R-50-D8 | 512x1024 | 80000 | 5.4 | 4.58 | 77.59 | 79.44 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/emanet_r50-d8_512x1024_80k_cityscapes_20200901_100301-c43fcef1.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/emanet_r50-d8_512x1024_80k_cityscapes-20200901_100301.log.json) |
20+
| EMANet | R-101-D8 | 512x1024 | 80000 | 6.2 | 2.87 | 79.10 | 81.21 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_512x1024_80k_cityscapes/emanet_r101-d8_512x1024_80k_cityscapes_20200901_100301-2d970745.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_512x1024_80k_cityscapes/emanet_r101-d8_512x1024_80k_cityscapes-20200901_100301.log.json) |
21+
| EMANet | R-50-D8 | 769x769 | 80000 | 8.9 | 1.97 | 79.33 | 80.49 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_769x769_80k_cityscapes/emanet_r50-d8_769x769_80k_cityscapes_20200901_100301-16f8de52.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_769x769_80k_cityscapes/emanet_r50-d8_769x769_80k_cityscapes-20200901_100301.log.json) |
22+
| EMANet | R-101-D8 | 769x769 | 80000 | 10.1 | 1.22 | 79.62 | 81.00 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_769x769_80k_cityscapes/emanet_r101-d8_769x769_80k_cityscapes_20200901_100301-47a324ce.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_769x769_80k_cityscapes/emanet_r101-d8_769x769_80k_cityscapes-20200901_100301.log.json) |
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
_base_ = './emanet_r50-d8_512x1024_80k_cityscapes.py'
2+
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
_base_ = './emanet_r50-d8_769x769_80k_cityscapes.py'
2+
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
_base_ = [
2+
'../_base_/models/emanet_r50-d8.py', '../_base_/datasets/cityscapes.py',
3+
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
4+
]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
_base_ = [
2+
'../_base_/models/emanet_r50-d8.py',
3+
'../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
4+
'../_base_/schedules/schedule_80k.py'
5+
]
6+
model = dict(
7+
decode_head=dict(align_corners=True),
8+
auxiliary_head=dict(align_corners=True))
9+
test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513))

mmseg/models/decode_heads/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .aspp_head import ASPPHead
33
from .cc_head import CCHead
44
from .da_head import DAHead
5+
from .ema_head import EMAHead
56
from .enc_head import EncHead
67
from .fcn_head import FCNHead
78
from .fpn_head import FPNHead
@@ -17,5 +18,5 @@
1718
__all__ = [
1819
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
1920
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
20-
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead'
21+
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead'
2122
]

mmseg/models/decode_heads/ema_head.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import math
2+
3+
import torch
4+
import torch.distributed as dist
5+
import torch.nn as nn
6+
import torch.nn.functional as F
7+
from mmcv.cnn import ConvModule
8+
9+
from ..builder import HEADS
10+
from .decode_head import BaseDecodeHead
11+
12+
13+
def reduce_mean(tensor):
14+
"""Reduce mean when distributed training."""
15+
if not (dist.is_available() and dist.is_initialized()):
16+
return tensor
17+
tensor = tensor.clone()
18+
dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
19+
return tensor
20+
21+
22+
class EMAModule(nn.Module):
23+
"""Expectation Maximization Attention Module used in EMANet.
24+
25+
Args:
26+
channels (int): Channels of the whole module.
27+
num_bases (int): Number of bases.
28+
num_stages (int): Number of the EM iterations.
29+
"""
30+
31+
def __init__(self, channels, num_bases, num_stages, momentum):
32+
super(EMAModule, self).__init__()
33+
assert num_stages >= 1, 'num_stages must be at least 1!'
34+
self.num_bases = num_bases
35+
self.num_stages = num_stages
36+
self.momentum = momentum
37+
38+
bases = torch.zeros(1, channels, self.num_bases)
39+
bases.normal_(0, math.sqrt(2. / self.num_bases))
40+
# [1, channels, num_bases]
41+
bases = F.normalize(bases, dim=1, p=2)
42+
self.register_buffer('bases', bases)
43+
44+
def forward(self, feats):
45+
"""Forward function."""
46+
batch_size, channels, height, width = feats.size()
47+
# [batch_size, channels, height*width]
48+
feats = feats.view(batch_size, channels, height * width)
49+
# [batch_size, channels, num_bases]
50+
bases = self.bases.repeat(batch_size, 1, 1)
51+
52+
with torch.no_grad():
53+
for i in range(self.num_stages):
54+
# [batch_size, height*width, num_bases]
55+
attention = torch.einsum('bcn,bck->bnk', feats, bases)
56+
attention = F.softmax(attention, dim=2)
57+
# l1 norm
58+
attention_normed = F.normalize(attention, dim=1, p=1)
59+
# [batch_size, channels, num_bases]
60+
bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
61+
# l2 norm
62+
bases = F.normalize(bases, dim=1, p=2)
63+
64+
feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
65+
feats_recon = feats_recon.view(batch_size, channels, height, width)
66+
67+
if self.training:
68+
bases = bases.mean(dim=0, keepdim=True)
69+
bases = reduce_mean(bases)
70+
# l2 norm
71+
bases = F.normalize(bases, dim=1, p=2)
72+
self.bases = (1 -
73+
self.momentum) * self.bases + self.momentum * bases
74+
75+
return feats_recon
76+
77+
78+
@HEADS.register_module()
79+
class EMAHead(BaseDecodeHead):
80+
"""Expectation Maximization Attention Networks for Semantic Segmentation.
81+
82+
This head is the implementation of `EMANet
83+
<https://arxiv.org/abs/1907.13426>`_.
84+
85+
Args:
86+
ema_channels (int): EMA module channels
87+
num_bases (int): Number of bases.
88+
num_stages (int): Number of the EM iterations.
89+
concat_input (bool): Whether concat the input and output of convs
90+
before classification layer. Default: True
91+
momentum (float): Momentum to update the base. Default: 0.1.
92+
"""
93+
94+
def __init__(self,
95+
ema_channels,
96+
num_bases,
97+
num_stages,
98+
concat_input=True,
99+
momentum=0.1,
100+
**kwargs):
101+
super(EMAHead, self).__init__(**kwargs)
102+
self.ema_channels = ema_channels
103+
self.num_bases = num_bases
104+
self.num_stages = num_stages
105+
self.concat_input = concat_input
106+
self.momentum = momentum
107+
self.ema_module = EMAModule(self.ema_channels, self.num_bases,
108+
self.num_stages, self.momentum)
109+
110+
self.ema_in_conv = ConvModule(
111+
self.in_channels,
112+
self.ema_channels,
113+
3,
114+
padding=1,
115+
conv_cfg=self.conv_cfg,
116+
norm_cfg=self.norm_cfg,
117+
act_cfg=self.act_cfg)
118+
# project (0, inf) -> (-inf, inf)
119+
self.ema_mid_conv = ConvModule(
120+
self.ema_channels,
121+
self.ema_channels,
122+
1,
123+
conv_cfg=self.conv_cfg,
124+
norm_cfg=None,
125+
act_cfg=None)
126+
for param in self.ema_mid_conv.parameters():
127+
param.requires_grad = False
128+
129+
self.ema_out_conv = ConvModule(
130+
self.ema_channels,
131+
self.ema_channels,
132+
1,
133+
conv_cfg=self.conv_cfg,
134+
norm_cfg=self.norm_cfg,
135+
act_cfg=None)
136+
self.bottleneck = ConvModule(
137+
self.ema_channels,
138+
self.channels,
139+
3,
140+
padding=1,
141+
conv_cfg=self.conv_cfg,
142+
norm_cfg=self.norm_cfg,
143+
act_cfg=self.act_cfg)
144+
if self.concat_input:
145+
self.conv_cat = ConvModule(
146+
self.in_channels + self.channels,
147+
self.channels,
148+
kernel_size=3,
149+
padding=1,
150+
conv_cfg=self.conv_cfg,
151+
norm_cfg=self.norm_cfg,
152+
act_cfg=self.act_cfg)
153+
154+
def forward(self, inputs):
155+
"""Forward function."""
156+
x = self._transform_inputs(inputs)
157+
feats = self.ema_in_conv(x)
158+
identity = feats
159+
feats = self.ema_mid_conv(feats)
160+
recon = self.ema_module(feats)
161+
recon = F.relu(recon, inplace=True)
162+
recon = self.ema_out_conv(recon)
163+
output = F.relu(identity + recon, inplace=True)
164+
output = self.bottleneck(output)
165+
if self.concat_input:
166+
output = self.conv_cat(torch.cat([x, output], dim=1))
167+
output = self.cls_seg(output)
168+
return output

tests/test_models/test_forward.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,11 @@ def test_mobilenet_v2_forward():
162162
'mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py')
163163

164164

165+
def test_emanet_forward():
166+
_test_encoder_decoder_forward(
167+
'emanet/emanet_r50-d8_512x1024_80k_cityscapes.py')
168+
169+
165170
def get_world_size(process_group):
166171

167172
return 1

tests/test_models/test_heads.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from mmcv.utils.parrots_wrapper import SyncBatchNorm
77

88
from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
9-
DepthwiseSeparableASPPHead, EncHead,
10-
FCNHead, GCHead, NLHead, OCRHead,
11-
PSAHead, PSPHead, UPerHead)
9+
DepthwiseSeparableASPPHead, EMAHead,
10+
EncHead, FCNHead, GCHead, NLHead,
11+
OCRHead, PSAHead, PSPHead, UPerHead)
1212
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
1313

1414

@@ -539,3 +539,21 @@ def test_dw_aspp_head():
539539
assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24)
540540
outputs = head(inputs)
541541
assert outputs.shape == (1, head.num_classes, 45, 45)
542+
543+
544+
def test_emanet_head():
545+
head = EMAHead(
546+
in_channels=32,
547+
ema_channels=24,
548+
channels=16,
549+
num_stages=3,
550+
num_bases=16,
551+
num_classes=19)
552+
for param in head.ema_mid_conv.parameters():
553+
assert not param.requires_grad
554+
assert hasattr(head, 'ema_module')
555+
inputs = [torch.randn(1, 32, 45, 45)]
556+
if torch.cuda.is_available():
557+
head, inputs = to_cuda(head, inputs)
558+
outputs = head(inputs)
559+
assert outputs.shape == (1, head.num_classes, 45, 45)

0 commit comments

Comments
 (0)