-from mmcv.cnn import ConvModule, build_norm_layer
-from torch import nn
+from mmcv.cnn import ConvModule
+from torch import nn as nn
+from torch.utils import checkpoint as cp
 
 
 class InvertedResidual(nn.Module):
-    """Inverted residual module.
+    """InvertedResidual block for MobileNetV2.
 
     Args:
         in_channels (int): The input channels of the InvertedResidual block.
         out_channels (int): The output channels of the InvertedResidual block.
         stride (int): Stride of the middle (first) 3x3 convolution.
-        expand_ratio (int): adjusts number of channels of the hidden layer
+        expand_ratio (int): Adjusts number of channels of the hidden layer
             in InvertedResidual by this amount.
+        dilation (int): Dilation rate of depthwise conv. Default: 1.
         conv_cfg (dict): Config dict for convolution layer.
             Default: None, which means using conv2d.
         norm_cfg (dict): Config dict for normalization layer.
             Default: dict(type='BN').
         act_cfg (dict): Config dict for activation layer.
             Default: dict(type='ReLU6').
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+
+    Returns:
+        Tensor: The output tensor.
     """
 
     def __init__(self,
@@ -27,47 +34,59 @@ def __init__(self,
                  dilation=1,
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
-                 act_cfg=dict(type='ReLU6')):
+                 act_cfg=dict(type='ReLU6'),
+                 with_cp=False):
         super(InvertedResidual, self).__init__()
         self.stride = stride
-        assert stride in [1, 2]
-
+        assert stride in [1, 2], f'stride must be in [1, 2]. ' \
+            f'But received {stride}.'
+        self.with_cp = with_cp
+        self.use_res_connect = self.stride == 1 and in_channels == out_channels
         hidden_dim = int(round(in_channels * expand_ratio))
-        self.use_res_connect = self.stride == 1 \
-            and in_channels == out_channels
 
         layers = []
         if expand_ratio != 1:
-            # pw
             layers.append(
                 ConvModule(
-                    in_channels,
-                    hidden_dim,
+                    in_channels=in_channels,
+                    out_channels=hidden_dim,
                     kernel_size=1,
                     conv_cfg=conv_cfg,
                     norm_cfg=norm_cfg,
                     act_cfg=act_cfg))
         layers.extend([
-            # dw
             ConvModule(
-                hidden_dim,
-                hidden_dim,
+                in_channels=hidden_dim,
+                out_channels=hidden_dim,
                 kernel_size=3,
-                padding=dilation,
                 stride=stride,
+                padding=dilation,
                 dilation=dilation,
                 groups=hidden_dim,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
                 act_cfg=act_cfg),
-            # pw-linear
-            nn.Conv2d(hidden_dim, out_channels, 1, 1, 0, bias=False),
-            build_norm_layer(norm_cfg, out_channels)[1],
+            ConvModule(
+                in_channels=hidden_dim,
+                out_channels=out_channels,
+                kernel_size=1,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=None)
         ])
         self.conv = nn.Sequential(*layers)
 
     def forward(self, x):
-        if self.use_res_connect:
-            return x + self.conv(x)
+
+        def _inner_forward(x):
+            if self.use_res_connect:
+                return x + self.conv(x)
+            else:
+                return self.conv(x)
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
         else:
-            return self.conv(x)
+            out = _inner_forward(x)
+
+        return out
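A minimal usage sketch of the block after this change, assuming mmcv and torch are installed; the import path and the parameter values below are illustrative, not part of the diff:

    import torch
    from mmseg.models.utils import InvertedResidual  # illustrative import path

    # Stride-1 block with matching channels, so the residual connection is used.
    block = InvertedResidual(
        in_channels=32,
        out_channels=32,
        stride=1,
        expand_ratio=6,
        dilation=1,
        norm_cfg=dict(type='BN'),
        act_cfg=dict(type='ReLU6'),
        with_cp=True)  # trade extra compute for memory via torch.utils.checkpoint

    x = torch.randn(2, 32, 56, 56, requires_grad=True)
    out = block(x)    # checkpointing runs because with_cp=True and x requires grad
    print(out.shape)  # torch.Size([2, 32, 56, 56])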