Commit 8ff866d

Add "disentangled non-local (DNL) neural networks" [ECCV2020] (open-mmlab#37)
* Add DNLHead
* add configs
* add weight decay mult
* add norm back
* Update README.md
* matched inference performance
* Fixed shape
* sep conv_out
* no norm
* add norm back
* complete model zoo
* add tests
* Add test forward
* Add more test

Co-authored-by: Jiarui XU <[email protected]>
1 parent dbca8b4 commit 8ff866d

19 files changed: +324 -4 lines changed

README.md

Lines changed: 3 additions & 0 deletions
@@ -72,6 +72,9 @@ Supported methods:
 - [x] [ANN](configs/ann)
 - [x] [OCRNet](configs/ocrnet)
 - [x] [Fast-SCNN](configs/fastscnn)
+- [x] [Semantic FPN](configs/sem_fpn)
+- [x] [EMANet](configs/emanet)
+- [x] [DNLNet](configs/dnlnet)
 - [x] [Mixed Precision (FP16) Training](configs/fp16/README.md)
 
 ## Installation

configs/_base_/models/dnl_r50-d8.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='DNLHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        dropout_ratio=0.1,
+        reduction=2,
+        use_scale=True,
+        mode='embedded_gaussian',
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)))
+# model training and testing settings
+train_cfg = dict()
+test_cfg = dict(mode='whole')
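
The `d8` in the file name refers to a backbone output stride of 8. The sketch below shows how that number follows from the `strides` tuple in the config above. It assumes the usual ResNet stem (a stride-2 convolution followed by a stride-2 max-pool, so a factor of 4 before the first stage); the helper `output_stride` is a hypothetical name for illustration and is not part of mmsegmentation.

```python
from math import prod

# Assumption: the stem downsamples by 4 (stride-2 conv, then stride-2 max-pool).
STEM_STRIDE = 4


def output_stride(stage_strides, stem_stride=STEM_STRIDE):
    """Return the total downsampling factor of the backbone."""
    return stem_stride * prod(stage_strides)


dilated = output_stride((1, 2, 1, 1))   # strides from the config above
standard = output_stride((1, 2, 2, 2))  # a non-dilated ResNet, for comparison
print(dilated, standard)  # prints "8 32"
```

The last two stages keep stride 1 and use dilations 2 and 4 instead, which is what keeps the feature map at 1/8 of the input resolution.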

configs/dnlnet/README.md

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+# Disentangled Non-Local Neural Networks
+
+## Introduction
+
+This example reproduces ["Disentangled Non-Local Neural Networks"](https://arxiv.org/abs/2006.06668) for semantic segmentation. It is still in progress.
+
+## Citation
+```
+@misc{yin2020disentangled,
+    title={Disentangled Non-Local Neural Networks},
+    author={Minghao Yin and Zhuliang Yao and Yue Cao and Xiu Li and Zheng Zhang and Stephen Lin and Han Hu},
+    year={2020},
+    booktitle={ECCV}
+}
+```
+
+## Results and models (in progress)
+
+### Cityscapes
+
+| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
+|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------|
+| DNL | R-50-D8 | 512x1024 | 40000 | 7.3 | 2.56 | 78.61 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x1024_40k_cityscapes/dnl_r50-d8_512x1024_40k_cityscapes_20200904_233629-53d4ea93.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x1024_40k_cityscapes/dnl_r50-d8_512x1024_40k_cityscapes-20200904_233629.log.json) |
+| DNL | R-101-D8 | 512x1024 | 40000 | 10.9 | 1.96 | 78.31 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x1024_40k_cityscapes/dnl_r101-d8_512x1024_40k_cityscapes_20200904_233629-9928ffef.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x1024_40k_cityscapes/dnl_r101-d8_512x1024_40k_cityscapes-20200904_233629.log.json) |
+| DNL | R-50-D8 | 769x769 | 40000 | 9.2 | 1.50 | 78.44 | 80.27 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_769x769_40k_cityscapes/dnl_r50-d8_769x769_40k_cityscapes_20200820_232206-0f283785.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_769x769_40k_cityscapes/dnl_r50-d8_769x769_40k_cityscapes-20200820_232206.log.json) |
+| DNL | R-101-D8 | 769x769 | 40000 | 12.6 | 1.02 | 76.39 | 77.77 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_769x769_40k_cityscapes/dnl_r101-d8_769x769_40k_cityscapes_20200820_171256-76c596df.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_769x769_40k_cityscapes/dnl_r101-d8_769x769_40k_cityscapes-20200820_171256.log.json) |
+| DNL | R-50-D8 | 512x1024 | 80000 | - | - | 79.33 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x1024_80k_cityscapes/dnl_r50-d8_512x1024_80k_cityscapes_20200904_233629-58b2f778.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x1024_80k_cityscapes/dnl_r50-d8_512x1024_80k_cityscapes-20200904_233629.log.json) |
+| DNL | R-101-D8 | 512x1024 | 80000 | - | - | 80.41 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x1024_80k_cityscapes/dnl_r101-d8_512x1024_80k_cityscapes_20200904_233629-758e2dd4.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x1024_80k_cityscapes/dnl_r101-d8_512x1024_80k_cityscapes-20200904_233629.log.json) |
+| DNL | R-50-D8 | 769x769 | 80000 | - | - | 79.36 | 80.70 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_769x769_80k_cityscapes/dnl_r50-d8_769x769_80k_cityscapes_20200820_011925-366bc4c7.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_769x769_80k_cityscapes/dnl_r50-d8_769x769_80k_cityscapes-20200820_011925.log.json) |
+| DNL | R-101-D8 | 769x769 | 80000 | - | - | 79.41 | 80.68 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_769x769_80k_cityscapes/dnl_r101-d8_769x769_80k_cityscapes_20200821_051111-95ff84ab.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_769x769_80k_cityscapes/dnl_r101-d8_769x769_80k_cityscapes-20200821_051111.log.json) |
+
+
+### ADE20K
+
+| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
+|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------|
+| DNL | R-50-D8 | 512x512 | 80000 | 8.8 | 20.66 | 41.76 | 42.99 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x512_80k_ade20k/dnl_r50-d8_512x512_80k_ade20k_20200826_183354-1cf6e0c1.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x512_80k_ade20k/dnl_r50-d8_512x512_80k_ade20k-20200826_183354.log.json) |
+| DNL | R-101-D8 | 512x512 | 80000 | 12.8 | 12.54 | 43.76 | 44.91 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x512_80k_ade20k/dnl_r101-d8_512x512_80k_ade20k_20200826_183354-d820d6ea.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x512_80k_ade20k/dnl_r101-d8_512x512_80k_ade20k-20200826_183354.log.json) |
+| DNL | R-50-D8 | 512x512 | 160000 | - | - | 41.87 | 43.01 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x512_160k_ade20k/dnl_r50-d8_512x512_160k_ade20k_20200826_183350-37837798.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r50-d8_512x512_160k_ade20k/dnl_r50-d8_512x512_160k_ade20k-20200826_183350.log.json) |
+| DNL | R-101-D8 | 512x512 | 160000 | - | - | 44.25 | 45.78 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x512_160k_ade20k/dnl_r101-d8_512x512_160k_ade20k_20200826_183350-ed522c61.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/dnlnet/dnl_r101-d8_512x512_160k_ade20k/dnl_r101-d8_512x512_160k_ade20k-20200826_183350.log.json) |

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+_base_ = './dnl_r50-d8_512x1024_40k_cityscapes.py'
+model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+_base_ = './dnl_r50-d8_512x1024_80k_cityscapes.py'
+model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+_base_ = './dnl_r50-d8_512x512_160k_ade20k.py'
+model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+_base_ = './dnl_r50-d8_512x512_80k_ade20k.py'
+model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+_base_ = './dnl_r50-d8_769x769_40k_cityscapes.py'
+model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+_base_ = './dnl_r50-d8_769x769_80k_cityscapes.py'
+model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
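
Each R-101 config above sets only two keys and inherits everything else from its R-50 `_base_` file. The following is a minimal sketch of that kind of recursive dictionary merge, written in plain Python. It is a simplified model for illustration, not mmcv's actual config loader, and `merge_config` is a hypothetical helper name.

```python
import copy


def merge_config(base, override):
    """Recursively merge `override` into a copy of `base`.

    Nested dicts are merged key by key; any other value replaces the base value.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# A trimmed-down stand-in for the R-50 base model config.
base_model = dict(
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(type='ResNetV1c', depth=50, num_stages=4),
)
# The override used by the R-101 configs.
child_model = dict(
    pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))

model = merge_config(base_model, child_model)
print(model['backbone'])
# prints "{'type': 'ResNetV1c', 'depth': 101, 'num_stages': 4}"
```

Only `depth` and `pretrained` change; the backbone type, dilations, heads, and loss settings carry over from the base file untouched.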
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+_base_ = [
+    '../_base_/models/dnl_r50-d8.py', '../_base_/datasets/cityscapes.py',
+    '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
+]

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+_base_ = [
+    '../_base_/models/dnl_r50-d8.py', '../_base_/datasets/cityscapes.py',
+    '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
+]

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+_base_ = [
+    '../_base_/models/dnl_r50-d8.py', '../_base_/datasets/ade20k.py',
+    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
+]
+model = dict(
+    decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+_base_ = [
+    '../_base_/models/dnl_r50-d8.py', '../_base_/datasets/ade20k.py',
+    '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
+]
+model = dict(
+    decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+_base_ = [
+    '../_base_/models/dnl_r50-d8.py',
+    '../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
+    '../_base_/schedules/schedule_40k.py'
+]
+model = dict(
+    decode_head=dict(align_corners=True),
+    auxiliary_head=dict(align_corners=True))
+test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513))
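
The `slide` test mode above evaluates overlapping 769x769 crops spaced 513 pixels apart, so neighbouring crops share 256 pixels. As a rough sketch of what that costs at test time, the snippet below counts the crops needed to cover an image. It assumes a full-resolution 1024x2048 Cityscapes frame and that a crop overrunning the border is shifted back inside the image; `num_windows` is a hypothetical helper, not an mmsegmentation function.

```python
def num_windows(image_size, crop_size, stride):
    """Count sliding windows per axis so that the whole image is covered."""
    counts = []
    for img, crop, step in zip(image_size, crop_size, stride):
        # Ceil-divide the extent not covered by the first window by the
        # stride, then add one for the first window itself.
        counts.append(max(img - crop + step - 1, 0) // step + 1)
    return tuple(counts)


# (height, width) of the assumed input, with the crop and stride from test_cfg.
rows, cols = num_windows((1024, 2048), (769, 769), (513, 513))
print(rows, cols, rows * cols)  # prints "2 4 8"
```

Under these assumptions each image needs eight forward passes, which is consistent with the lower fps reported for the 769x769 models in the results table.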
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+_base_ = [
+    '../_base_/models/dnl_r50-d8.py',
+    '../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
+    '../_base_/schedules/schedule_80k.py'
+]
+model = dict(
+    decode_head=dict(align_corners=True),
+    auxiliary_head=dict(align_corners=True))
+test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513))
+optimizer = dict(
+    paramwise_cfg=dict(
+        custom_keys=dict(theta=dict(wd_mult=0.), phi=dict(wd_mult=0.))))
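
The `custom_keys` entry above sets a weight-decay multiplier of 0 for parameters whose names contain `theta` or `phi`. The sketch below illustrates the idea with plain dictionaries. It assumes substring matching on parameter names and a base weight decay of 0.0005; the parameter names and the helper `build_param_groups` are hypothetical, and this is not the optimizer constructor mmcv actually uses.

```python
def build_param_groups(param_names, base_wd, custom_keys):
    """Assign a weight decay to each parameter name.

    A parameter whose name contains a key from `custom_keys` gets
    `base_wd * wd_mult`; all others keep `base_wd`.
    """
    groups = []
    for name in param_names:
        weight_decay = base_wd
        for key, rule in custom_keys.items():
            if key in name:
                weight_decay = base_wd * rule.get('wd_mult', 1.0)
                break
        groups.append(dict(name=name, weight_decay=weight_decay))
    return groups


# Hypothetical parameter names for illustration only.
names = [
    'decode_head.dnl_block.theta.conv.weight',
    'decode_head.dnl_block.phi.conv.weight',
    'decode_head.dnl_block.g.conv.weight',
    'backbone.layer1.0.conv1.weight',
]
custom_keys = dict(theta=dict(wd_mult=0.), phi=dict(wd_mult=0.))

for group in build_param_groups(names, base_wd=0.0005, custom_keys=custom_keys):
    print(group['name'], group['weight_decay'])
```

With these inputs the query and key projections are exempt from weight decay, while the value projection `g` and the backbone keep the base setting.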

mmseg/models/decode_heads/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@
 from .aspp_head import ASPPHead
 from .cc_head import CCHead
 from .da_head import DAHead
+from .dnl_head import DNLHead
 from .ema_head import EMAHead
 from .enc_head import EncHead
 from .fcn_head import FCNHead
@@ -18,5 +19,5 @@
 __all__ = [
     'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
     'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
-    'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead'
+    'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead'
 ]

mmseg/models/decode_heads/dnl_head.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+import torch
+from mmcv.cnn import NonLocal2d
+from torch import nn
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+class DisentangledNonLocal2d(NonLocal2d):
+    """Disentangled Non-Local Blocks.
+
+    Args:
+        temperature (float): Temperature to adjust attention. Required.
+    """
+
+    def __init__(self, *arg, temperature, **kwargs):
+        super().__init__(*arg, **kwargs)
+        self.temperature = temperature
+        self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)
+
+    def embedded_gaussian(self, theta_x, phi_x):
+        """Embedded gaussian with temperature."""
+
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        pairwise_weight = torch.matmul(theta_x, phi_x)
+        if self.use_scale:
+            # theta_x.shape[-1] is `self.inter_channels`
+            pairwise_weight /= theta_x.shape[-1]**0.5
+        pairwise_weight /= self.temperature
+        pairwise_weight = pairwise_weight.softmax(dim=-1)
+        return pairwise_weight
+
+    def forward(self, x):
+        # x: [N, C, H, W]
+        n = x.size(0)
+
+        # g_x: [N, HxW, C]
+        g_x = self.g(x).view(n, self.inter_channels, -1)
+        g_x = g_x.permute(0, 2, 1)
+
+        # theta_x: [N, HxW, C], phi_x: [N, C, HxW]
+        if self.mode == 'gaussian':
+            theta_x = x.view(n, self.in_channels, -1)
+            theta_x = theta_x.permute(0, 2, 1)
+            if self.sub_sample:
+                phi_x = self.phi(x).view(n, self.in_channels, -1)
+            else:
+                phi_x = x.view(n, self.in_channels, -1)
+        elif self.mode == 'concatenation':
+            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
+            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
+        else:
+            theta_x = self.theta(x).view(n, self.inter_channels, -1)
+            theta_x = theta_x.permute(0, 2, 1)
+            phi_x = self.phi(x).view(n, self.inter_channels, -1)
+
+        # subtract mean
+        theta_x -= theta_x.mean(dim=-2, keepdim=True)
+        phi_x -= phi_x.mean(dim=-1, keepdim=True)
+
+        pairwise_func = getattr(self, self.mode)
+        # pairwise_weight: [N, HxW, HxW]
+        pairwise_weight = pairwise_func(theta_x, phi_x)
+
+        # y: [N, HxW, C]
+        y = torch.matmul(pairwise_weight, g_x)
+        # y: [N, C, H, W]
+        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+                                                    *x.size()[2:])
+
+        # unary_mask: [N, 1, HxW]
+        unary_mask = self.conv_mask(x)
+        unary_mask = unary_mask.view(n, 1, -1)
+        unary_mask = unary_mask.softmax(dim=-1)
+        # unary_x: [N, 1, C]
+        unary_x = torch.matmul(unary_mask, g_x)
+        # unary_x: [N, C, 1, 1]
+        unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
+            n, self.inter_channels, 1, 1)
+
+        output = x + self.conv_out(y + unary_x)
+
+        return output
+
+
+@HEADS.register_module()
+class DNLHead(FCNHead):
+    """Disentangled Non-Local Neural Networks.
+
+    This head is the implementation of `DNLNet
+    <https://arxiv.org/abs/2006.06668>`_.
+
+    Args:
+        reduction (int): Reduction factor of projection transform. Default: 2.
+        use_scale (bool): Whether to scale pairwise_weight by
+            sqrt(1/inter_channels). Default: True.
+        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
+            'dot_product'. Default: 'embedded_gaussian'.
+        temperature (float): Temperature to adjust attention. Default: 0.05
+    """
+
+    def __init__(self,
+                 reduction=2,
+                 use_scale=True,
+                 mode='embedded_gaussian',
+                 temperature=0.05,
+                 **kwargs):
+        super(DNLHead, self).__init__(num_convs=2, **kwargs)
+        self.reduction = reduction
+        self.use_scale = use_scale
+        self.mode = mode
+        self.temperature = temperature
+        self.dnl_block = DisentangledNonLocal2d(
+            in_channels=self.channels,
+            reduction=self.reduction,
+            use_scale=self.use_scale,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            mode=self.mode,
+            temperature=self.temperature)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        output = self.convs[0](x)
+        output = self.dnl_block(output)
+        output = self.convs[1](output)
+        if self.concat_input:
+            output = self.conv_cat(torch.cat([x, output], dim=1))
+        output = self.cls_seg(output)
+        return output
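
The `embedded_gaussian` override above divides the attention logits by `temperature` before the softmax. The pure-Python sketch below isolates that single step to show the effect of the default value 0.05. The similarity scores are made up for illustration, and `softmax` here is a local helper rather than the PyTorch call used in the head.

```python
import math


def softmax(scores, temperature=1.0):
    """Softmax over a list of scores after dividing them by `temperature`."""
    scaled = [s / temperature for s in scores]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical similarities between one query position and four key positions.
scores = [0.10, 0.12, 0.08, 0.15]

plain = softmax(scores)                    # temperature 1.0
sharp = softmax(scores, temperature=0.05)  # default used by DNLHead

print([round(p, 3) for p in plain])
print([round(p, 3) for p in sharp])
```

With a temperature below 1, small differences between logits are amplified, so the attention weights concentrate on the best-matching positions instead of staying close to uniform.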

tests/test_models/test_forward.py

Lines changed: 5 additions & 0 deletions
@@ -162,6 +162,11 @@ def test_mobilenet_v2_forward():
         'mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py')
 
 
+def test_dnlnet_forward():
+    _test_encoder_decoder_forward(
+        'dnlnet/dnl_r50-d8_512x1024_40k_cityscapes.py')
+
+
 def test_emanet_forward():
     _test_encoder_decoder_forward(
         'emanet/emanet_r50-d8_512x1024_80k_cityscapes.py')

tests/test_models/test_heads.py

Lines changed: 44 additions & 3 deletions
@@ -6,9 +6,10 @@
 from mmcv.utils.parrots_wrapper import SyncBatchNorm
 
 from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
-                                       DepthwiseSeparableASPPHead, EMAHead,
-                                       EncHead, FCNHead, GCHead, NLHead,
-                                       OCRHead, PSAHead, PSPHead, UPerHead)
+                                       DepthwiseSeparableASPPHead, DNLHead,
+                                       EMAHead, EncHead, FCNHead, GCHead,
+                                       NLHead, OCRHead, PSAHead, PSPHead,
+                                       UPerHead)
 from mmseg.models.decode_heads.decode_head import BaseDecodeHead
 
 
@@ -541,6 +542,46 @@ def test_dw_aspp_head():
     assert outputs.shape == (1, head.num_classes, 45, 45)
 
 
+def test_dnl_head():
+    # DNL with 'embedded_gaussian' mode
+    head = DNLHead(in_channels=32, channels=16, num_classes=19)
+    assert len(head.convs) == 2
+    assert hasattr(head, 'dnl_block')
+    assert head.dnl_block.temperature == 0.05
+    inputs = [torch.randn(1, 32, 45, 45)]
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+    outputs = head(inputs)
+    assert outputs.shape == (1, head.num_classes, 45, 45)
+
+    # DNL with 'dot_product' mode
+    head = DNLHead(
+        in_channels=32, channels=16, num_classes=19, mode='dot_product')
+    inputs = [torch.randn(1, 32, 45, 45)]
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+    outputs = head(inputs)
+    assert outputs.shape == (1, head.num_classes, 45, 45)
+
+    # DNL with 'gaussian' mode
+    head = DNLHead(
+        in_channels=32, channels=16, num_classes=19, mode='gaussian')
+    inputs = [torch.randn(1, 32, 45, 45)]
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+    outputs = head(inputs)
+    assert outputs.shape == (1, head.num_classes, 45, 45)
+
+    # DNL with 'concatenation' mode
+    head = DNLHead(
+        in_channels=32, channels=16, num_classes=19, mode='concatenation')
+    inputs = [torch.randn(1, 32, 45, 45)]
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+    outputs = head(inputs)
+    assert outputs.shape == (1, head.num_classes, 45, 45)
+
+
 def test_emanet_head():
     head = EMAHead(
         in_channels=32,
