The training script of the model is similar to that of MobileNetV3
Original model: https://github.com/huawei-noah/CV-backbones/tree/master/ghostnet_pytorch
"""
+import math
+from functools import partial
+
import torch
import torch.nn as nn
import torch.nn.functional as F
-import math
+

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .layers import SelectAdaptivePool2d
+from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid
+from .efficientnet_blocks import SqueezeExcite, ConvBnAct, make_divisible
from .helpers import build_model_with_cfg
from .registry import register_model

@@ -36,70 +40,15 @@ def _cfg(url='', **kwargs):
}


-def _make_divisible(v, divisor, min_value=None):
-    """
-    This function is taken from the original tf repo.
-    It ensures that all layers have a channel number that is divisible by 8
-    It can be seen here:
-    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
-    """
-    if min_value is None:
-        min_value = divisor
-    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-    # Make sure that round down does not go down by more than 10%.
-    if new_v < 0.9 * v:
-        new_v += divisor
-    return new_v
-
-
-def hard_sigmoid(x, inplace: bool = False):
-    if inplace:
-        return x.add_(3.).clamp_(0., 6.).div_(6.)
-    else:
-        return F.relu6(x + 3.) / 6.
-
-
-class SqueezeExcite(nn.Module):
-    def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
-                 act_layer=nn.ReLU, gate_fn=hard_sigmoid, divisor=4, **_):
-        super(SqueezeExcite, self).__init__()
-        self.gate_fn = gate_fn
-        reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
-        self.act1 = act_layer(inplace=True)
-        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
-
-    def forward(self, x):
-        x_se = self.avg_pool(x)
-        x_se = self.conv_reduce(x_se)
-        x_se = self.act1(x_se)
-        x_se = self.conv_expand(x_se)
-        x = x * self.gate_fn(x_se)
-        return x
-
-
-class ConvBnAct(nn.Module):
-    def __init__(self, in_chs, out_chs, kernel_size,
-                 stride=1, act_layer=nn.ReLU):
-        super(ConvBnAct, self).__init__()
-        self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False)
-        self.bn1 = nn.BatchNorm2d(out_chs)
-        self.act1 = act_layer(inplace=True)
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = self.bn1(x)
-        x = self.act1(x)
-        return x
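+# Note: the local SE implementation above is replaced by the shared SqueezeExcite from
+# efficientnet_blocks; partial() pins GhostNet's hard_sigmoid gate and divisor=4, so call
+# sites only pass channel count and se_ratio.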
+_SE_LAYER = partial(SqueezeExcite, gate_fn=hard_sigmoid, divisor=4)


class GhostModule(nn.Module):
    def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True):
        super(GhostModule, self).__init__()
        self.oup = oup
        init_channels = math.ceil(oup / ratio)
-        new_channels = init_channels*(ratio-1)
+        new_channels = init_channels * (ratio - 1)

        self.primary_conv = nn.Sequential(
            nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False),
@@ -116,8 +65,8 @@ def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=T
    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
-        out = torch.cat([x1,x2], dim=1)
-        return out[:,:self.oup,:,:]
+        out = torch.cat([x1, x2], dim=1)
+        return out[:, :self.oup, :, :]


class GhostBottleneck(nn.Module):
@@ -134,27 +83,28 @@ def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3,

        # Depth-wise convolution
        if self.stride > 1:
-            self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride=stride,
-                                     padding=(dw_kernel_size-1)//2,
-                                     groups=mid_chs, bias=False)
+            self.conv_dw = nn.Conv2d(
+                mid_chs, mid_chs, dw_kernel_size, stride=stride,
+                padding=(dw_kernel_size-1)//2, groups=mid_chs, bias=False)
            self.bn_dw = nn.BatchNorm2d(mid_chs)
+        else:
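+            # no depth-wise conv/BN when stride == 1; forward() checks conv_dw for None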
+            self.conv_dw = None
+            self.bn_dw = None

        # Squeeze-and-excitation
-        if has_se:
-            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio)
-        else:
-            self.se = None
+        self.se = _SE_LAYER(mid_chs, se_ratio=se_ratio) if has_se else None

        # Point-wise linear projection
        self.ghost2 = GhostModule(mid_chs, out_chs, relu=False)

        # shortcut
-        if (in_chs == out_chs and self.stride == 1):
+        if in_chs == out_chs and self.stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
-                nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride=stride,
-                          padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False),
+                nn.Conv2d(
+                    in_chs, in_chs, dw_kernel_size, stride=stride,
+                    padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False),
                nn.BatchNorm2d(in_chs),
                nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_chs),
@@ -168,7 +118,7 @@ def forward(self, x):
        x = self.ghost1(x)

        # Depth-wise convolution
-        if self.stride > 1:
+        if self.conv_dw is not None:
            x = self.conv_dw(x)
            x = self.bn_dw(x)

@@ -184,52 +134,55 @@ def forward(self, x):


class GhostNet(nn.Module):
-    def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3):
+    def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32):
        super(GhostNet, self).__init__()
        # setting of inverted residual blocks
+        assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
        self.cfgs = cfgs
        self.num_classes = num_classes
        self.dropout = dropout
        self.feature_info = []

        # building first layer
-        output_channel = _make_divisible(16 * width, 4)
-        self.conv_stem = nn.Conv2d(in_chans, output_channel, 3, 2, 1, bias=False)
-        self.feature_info.append(dict(num_chs=output_channel, reduction=2, module=f'conv_stem'))
-        self.bn1 = nn.BatchNorm2d(output_channel)
+        stem_chs = make_divisible(16 * width, 4)
+        self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False)
+        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=f'conv_stem'))
+        self.bn1 = nn.BatchNorm2d(stem_chs)
        self.act1 = nn.ReLU(inplace=True)
-        input_channel = output_channel
+        prev_chs = stem_chs

        # building inverted residual blocks
        stages = nn.ModuleList([])
        block = GhostBottleneck
        stage_idx = 0
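+        # running stride of the network (stem is stride 2); doubled at each striding stage
+        # and recorded as the reduction value in feature_info below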
+        net_stride = 2
        for cfg in self.cfgs:
            layers = []
+            s = 1
            for k, exp_size, c, se_ratio, s in cfg:
-                output_channel = _make_divisible(c * width, 4)
-                hidden_channel = _make_divisible(exp_size * width, 4)
-                layers.append(block(input_channel, hidden_channel, output_channel, k, s,
-                                    se_ratio=se_ratio))
-                input_channel = output_channel
+                out_chs = make_divisible(c * width, 4)
+                mid_chs = make_divisible(exp_size * width, 4)
+                layers.append(block(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio))
+                prev_chs = out_chs
            if s > 1:
-                self.feature_info.append(dict(num_chs=output_channel, reduction=2**(stage_idx+2),
-                                              module=f'blocks.{stage_idx}'))
+                net_stride *= 2
+                self.feature_info.append(dict(
+                    num_chs=prev_chs, reduction=net_stride, module=f'blocks.{stage_idx}'))
            stages.append(nn.Sequential(*layers))
            stage_idx += 1

-        output_channel = _make_divisible(exp_size * width, 4)
-        stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1)))
-        self.pool_dim = input_channel = output_channel
+        out_chs = make_divisible(exp_size * width, 4)
+        stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1)))
+        self.pool_dim = prev_chs = out_chs

        self.blocks = nn.Sequential(*stages)

        # building last several layers
-        self.num_features = output_channel = 1280
+        self.num_features = out_chs = 1280
        self.global_pool = SelectAdaptivePool2d(pool_type='avg')
-        self.conv_head = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=True)
+        self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
        self.act2 = nn.ReLU(inplace=True)
-        self.classifier = nn.Linear(output_channel, num_classes)
+        self.classifier = Linear(out_chs, num_classes)

    def get_classifier(self):
        return self.classifier