Skip to content

Commit f86c24d

Browse files
authored
[Enhance] Refactor inverted residual (open-mmlab#164)
* [Enhance] Unified InvertedResidual in MobileNetV2 and FastSCNN * [Enhance] Unified InvertedResidual in MobileNetV2 and FastSCNN
1 parent 924571e commit f86c24d

File tree

8 files changed

+50
-189
lines changed

8 files changed

+50
-189
lines changed

configs/fastscnn/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@
1515
### Cityscapes
1616
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
1717
|------------|-----------|-----------|--------:|----------|----------------|------:|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
18-
| Fast-SCNN | Fast-SCNN | 512x1024 | 80000 | 8.4 | 63.61 | 69.06 | - | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_4x8_80k_lr0.12_cityscapes-cae6c46a.pth) | [log](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_4x8_80k_lr0.12_cityscapes-20200807_165744.log.json) |
18+
| Fast-SCNN | Fast-SCNN | 512x1024 | 80000 | 8.4 | 63.61 | 69.06 | - | [model](https://download.openmmlab.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_4x8_80k_lr0.12_cityscapes-f5096c79.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_4x8_80k_lr0.12_cityscapes-20200807_165744.log.json) |

configs/fastscnn/fast_scnn_4x8_80k_lr0.12_pascal.py

Lines changed: 0 additions & 70 deletions
This file was deleted.

mmseg/models/backbones/fast_scnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
from mmseg.models.decode_heads.psp_head import PPM
88
from mmseg.ops import resize
9-
from mmseg.utils import InvertedResidual
109
from ..builder import BACKBONES
10+
from ..utils.inverted_residual import InvertedResidual
1111

1212

1313
class LearningToDownsample(nn.Module):

mmseg/models/backbones/mobilenet_v2.py

Lines changed: 1 addition & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,102 +1,12 @@
11
import logging
22

33
import torch.nn as nn
4-
import torch.utils.checkpoint as cp
54
from mmcv.cnn import ConvModule, constant_init, kaiming_init
65
from mmcv.runner import load_checkpoint
76
from torch.nn.modules.batchnorm import _BatchNorm
87

98
from ..builder import BACKBONES
10-
from ..utils import make_divisible
11-
12-
13-
class InvertedResidual(nn.Module):
14-
"""InvertedResidual block for MobileNetV2.
15-
16-
Args:
17-
in_channels (int): The input channels of the InvertedResidual block.
18-
out_channels (int): The output channels of the InvertedResidual block.
19-
stride (int): Stride of the middle (first) 3x3 convolution.
20-
expand_ratio (int): Adjusts number of channels of the hidden layer
21-
in InvertedResidual by this amount.
22-
dilation (int): Dilation rate of depthwise conv. Default: 1
23-
conv_cfg (dict): Config dict for convolution layer.
24-
Default: None, which means using conv2d.
25-
norm_cfg (dict): Config dict for normalization layer.
26-
Default: dict(type='BN').
27-
act_cfg (dict): Config dict for activation layer.
28-
Default: dict(type='ReLU6').
29-
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
30-
memory while slowing down the training speed. Default: False.
31-
32-
Returns:
33-
Tensor: The output tensor
34-
"""
35-
36-
def __init__(self,
37-
in_channels,
38-
out_channels,
39-
stride,
40-
expand_ratio,
41-
dilation=1,
42-
conv_cfg=None,
43-
norm_cfg=dict(type='BN'),
44-
act_cfg=dict(type='ReLU6'),
45-
with_cp=False):
46-
super(InvertedResidual, self).__init__()
47-
self.stride = stride
48-
assert stride in [1, 2], f'stride must in [1, 2]. ' \
49-
f'But received {stride}.'
50-
self.with_cp = with_cp
51-
self.use_res_connect = self.stride == 1 and in_channels == out_channels
52-
hidden_dim = int(round(in_channels * expand_ratio))
53-
54-
layers = []
55-
if expand_ratio != 1:
56-
layers.append(
57-
ConvModule(
58-
in_channels=in_channels,
59-
out_channels=hidden_dim,
60-
kernel_size=1,
61-
conv_cfg=conv_cfg,
62-
norm_cfg=norm_cfg,
63-
act_cfg=act_cfg))
64-
layers.extend([
65-
ConvModule(
66-
in_channels=hidden_dim,
67-
out_channels=hidden_dim,
68-
kernel_size=3,
69-
stride=stride,
70-
padding=dilation,
71-
dilation=dilation,
72-
groups=hidden_dim,
73-
conv_cfg=conv_cfg,
74-
norm_cfg=norm_cfg,
75-
act_cfg=act_cfg),
76-
ConvModule(
77-
in_channels=hidden_dim,
78-
out_channels=out_channels,
79-
kernel_size=1,
80-
conv_cfg=conv_cfg,
81-
norm_cfg=norm_cfg,
82-
act_cfg=None)
83-
])
84-
self.conv = nn.Sequential(*layers)
85-
86-
def forward(self, x):
87-
88-
def _inner_forward(x):
89-
if self.use_res_connect:
90-
return x + self.conv(x)
91-
else:
92-
return self.conv(x)
93-
94-
if self.with_cp and x.requires_grad:
95-
out = cp.checkpoint(_inner_forward, x)
96-
else:
97-
out = _inner_forward(x)
98-
99-
return out
9+
from ..utils import InvertedResidual, make_divisible
10010

10111

10212
@BACKBONES.register_module()

mmseg/models/utils/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from .inverted_residual import InvertedResidual
12
from .make_divisible import make_divisible
23
from .res_layer import ResLayer
34
from .self_attention_block import SelfAttentionBlock
45

5-
__all__ = ['ResLayer', 'SelfAttentionBlock', 'make_divisible']
6+
__all__ = [
7+
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual'
8+
]
mmseg/utils/inverted_residual_module.py → mmseg/models/utils/inverted_residual.py

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,29 @@
1-
from mmcv.cnn import ConvModule, build_norm_layer
2-
from torch import nn
1+
from mmcv.cnn import ConvModule
2+
from torch import nn as nn
3+
from torch.utils import checkpoint as cp
34

45

56
class InvertedResidual(nn.Module):
6-
"""Inverted residual module.
7+
"""InvertedResidual block for MobileNetV2.
78
89
Args:
910
in_channels (int): The input channels of the InvertedResidual block.
1011
out_channels (int): The output channels of the InvertedResidual block.
1112
stride (int): Stride of the middle (first) 3x3 convolution.
12-
expand_ratio (int): adjusts number of channels of the hidden layer
13+
expand_ratio (int): Adjusts number of channels of the hidden layer
1314
in InvertedResidual by this amount.
15+
dilation (int): Dilation rate of depthwise conv. Default: 1
1416
conv_cfg (dict): Config dict for convolution layer.
1517
Default: None, which means using conv2d.
1618
norm_cfg (dict): Config dict for normalization layer.
1719
Default: dict(type='BN').
1820
act_cfg (dict): Config dict for activation layer.
1921
Default: dict(type='ReLU6').
22+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
23+
memory while slowing down the training speed. Default: False.
24+
25+
Returns:
26+
Tensor: The output tensor
2027
"""
2128

2229
def __init__(self,
@@ -27,47 +34,59 @@ def __init__(self,
2734
dilation=1,
2835
conv_cfg=None,
2936
norm_cfg=dict(type='BN'),
30-
act_cfg=dict(type='ReLU6')):
37+
act_cfg=dict(type='ReLU6'),
38+
with_cp=False):
3139
super(InvertedResidual, self).__init__()
3240
self.stride = stride
33-
assert stride in [1, 2]
34-
41+
assert stride in [1, 2], f'stride must in [1, 2]. ' \
42+
f'But received {stride}.'
43+
self.with_cp = with_cp
44+
self.use_res_connect = self.stride == 1 and in_channels == out_channels
3545
hidden_dim = int(round(in_channels * expand_ratio))
36-
self.use_res_connect = self.stride == 1 \
37-
and in_channels == out_channels
3846

3947
layers = []
4048
if expand_ratio != 1:
41-
# pw
4249
layers.append(
4350
ConvModule(
44-
in_channels,
45-
hidden_dim,
51+
in_channels=in_channels,
52+
out_channels=hidden_dim,
4653
kernel_size=1,
4754
conv_cfg=conv_cfg,
4855
norm_cfg=norm_cfg,
4956
act_cfg=act_cfg))
5057
layers.extend([
51-
# dw
5258
ConvModule(
53-
hidden_dim,
54-
hidden_dim,
59+
in_channels=hidden_dim,
60+
out_channels=hidden_dim,
5561
kernel_size=3,
56-
padding=dilation,
5762
stride=stride,
63+
padding=dilation,
5864
dilation=dilation,
5965
groups=hidden_dim,
6066
conv_cfg=conv_cfg,
6167
norm_cfg=norm_cfg,
6268
act_cfg=act_cfg),
63-
# pw-linear
64-
nn.Conv2d(hidden_dim, out_channels, 1, 1, 0, bias=False),
65-
build_norm_layer(norm_cfg, out_channels)[1],
69+
ConvModule(
70+
in_channels=hidden_dim,
71+
out_channels=out_channels,
72+
kernel_size=1,
73+
conv_cfg=conv_cfg,
74+
norm_cfg=norm_cfg,
75+
act_cfg=None)
6676
])
6777
self.conv = nn.Sequential(*layers)
6878

6979
def forward(self, x):
70-
if self.use_res_connect:
71-
return x + self.conv(x)
80+
81+
def _inner_forward(x):
82+
if self.use_res_connect:
83+
return x + self.conv(x)
84+
else:
85+
return self.conv(x)
86+
87+
if self.with_cp and x.requires_grad:
88+
out = cp.checkpoint(_inner_forward, x)
7289
else:
73-
return self.conv(x)
90+
out = _inner_forward(x)
91+
92+
return out

mmseg/utils/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from .collect_env import collect_env
2-
from .inverted_residual_module import InvertedResidual
32
from .logger import get_root_logger
43

5-
__all__ = ['get_root_logger', 'collect_env', 'InvertedResidual']
4+
__all__ = ['get_root_logger', 'collect_env']

tests/test_utils/test_inverted_residual_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import torch
33

4-
from mmseg.utils import InvertedResidual
4+
from mmseg.models.utils import InvertedResidual
55

66

77
def test_inv_residual():

0 commit comments

Comments
 (0)