
Commit 1b0c8e7

Merge branch 'iamhankai-master'
2 parents: e685618 + 2df77ee

File tree

3 files changed: +280, -3 lines

tests/test_models.py

Lines changed: 3 additions & 3 deletions

@@ -116,8 +116,8 @@ def test_model_default_cfgs(model_name, batch_size):
     model.reset_classifier(0, '')  # reset classifier and set global pooling to pass-through
     outputs = model.forward(input_tensor)
     assert len(outputs.shape) == 4
-    if not isinstance(model, timm.models.MobileNetV3):
-        # FIXME mobilenetv3 forward_features vs removed pooling differ
+    if not isinstance(model, timm.models.MobileNetV3) and not isinstance(model, timm.models.GhostNet):
+        # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
         assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]

     # check classifier name matches default_cfg
@@ -150,7 +150,7 @@ def test_model_features_pretrained(model_name, batch_size):

 EXCLUDE_JIT_FILTERS = [
     '*iabn*', 'tresnet*',  # models using inplace abn unlikely to ever be scriptable
-    'dla*', 'hrnet*',  # hopefully fix at some point
+    'dla*', 'hrnet*', 'ghostnet*',  # hopefully fix at some point
 ]
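For context: GhostNet joins MobileNetV3 in this exclusion because both use an "efficient head" whose forward_features already applies global pooling before conv_head, so their default_cfg declares pool_size (1, 1). With the classifier reset and pooling set to pass-through, the feature map keeps its full spatial extent and the pool_size comparison no longer holds. A minimal sketch of what the test observes (illustrative, assuming a build containing this commit):

    import torch
    import timm

    model = timm.create_model('ghostnet_100')
    model.reset_classifier(0, '')  # drop classifier, set pooling to pass-through
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # (1, 1280, 7, 7) here, while default_cfg pool_size is (1, 1)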

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@
 from .dla import *
 from .dpn import *
 from .efficientnet import *
+from .ghostnet import *
 from .gluon_resnet import *
 from .gluon_xception import *
 from .hardcorenas import *
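Because each ghostnet_* entry point below is decorated with @register_model, this import is what makes the new variants discoverable through the factory API. A quick check (illustrative):

    import timm

    print(timm.list_models('ghostnet*'))
    # expected: ['ghostnet_050', 'ghostnet_100', 'ghostnet_130']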

timm/models/ghostnet.py

Lines changed: 276 additions & 0 deletions (new file)

@@ -0,0 +1,276 @@
"""
An implementation of the GhostNet model as defined in:
GhostNet: More Features from Cheap Operations. https://arxiv.org/abs/1911.11907
The training script for this model is similar to that of MobileNetV3.
Original model: https://github.com/huawei-noah/CV-backbones/tree/master/ghostnet_pytorch
"""
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid
from .efficientnet_blocks import SqueezeExcite, ConvBnAct, make_divisible
from .helpers import build_model_with_cfg
from .registry import register_model


__all__ = ['GhostNet']


def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv_stem', 'classifier': 'classifier',
        **kwargs
    }


default_cfgs = {
    'ghostnet_050': _cfg(url=''),
    'ghostnet_100': _cfg(
        url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'),
    'ghostnet_130': _cfg(url=''),
}


_SE_LAYER = partial(SqueezeExcite, gate_fn=hard_sigmoid, divisor=4)


class GhostModule(nn.Module):
    def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True):
        super(GhostModule, self).__init__()
        self.oup = oup
        init_channels = math.ceil(oup / ratio)
        new_channels = init_channels * (ratio - 1)

        # dense conv producing the "intrinsic" feature maps
        self.primary_conv = nn.Sequential(
            nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm2d(init_channels),
            nn.ReLU(inplace=True) if relu else nn.Sequential(),
        )

        # cheap depthwise op deriving the remaining "ghost" maps from them
        self.cheap_operation = nn.Sequential(
            nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size // 2, groups=init_channels, bias=False),
            nn.BatchNorm2d(new_channels),
            nn.ReLU(inplace=True) if relu else nn.Sequential(),
        )

    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        out = torch.cat([x1, x2], dim=1)
        return out[:, :self.oup, :, :]
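
# Channel arithmetic (illustrative example): GhostModule(inp=16, oup=64) with
# the default ratio=2 builds init_channels = ceil(64 / 2) = 32 intrinsic maps
# via primary_conv, then new_channels = 32 ghost maps via the cheap depthwise
# conv, concatenates to 64 channels and slices back to self.oup, e.g.:
#   m = GhostModule(16, 64)
#   m(torch.randn(1, 16, 56, 56)).shape  # torch.Size([1, 64, 56, 56])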


class GhostBottleneck(nn.Module):
    """ Ghost bottleneck w/ optional SE"""

    def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3,
                 stride=1, act_layer=nn.ReLU, se_ratio=0.):
        super(GhostBottleneck, self).__init__()
        has_se = se_ratio is not None and se_ratio > 0.
        self.stride = stride

        # Point-wise expansion
        self.ghost1 = GhostModule(in_chs, mid_chs, relu=True)

        # Depth-wise convolution (only present for strided blocks)
        if self.stride > 1:
            self.conv_dw = nn.Conv2d(
                mid_chs, mid_chs, dw_kernel_size, stride=stride,
                padding=(dw_kernel_size - 1) // 2, groups=mid_chs, bias=False)
            self.bn_dw = nn.BatchNorm2d(mid_chs)
        else:
            self.conv_dw = None
            self.bn_dw = None

        # Squeeze-and-excitation
        self.se = _SE_LAYER(mid_chs, se_ratio=se_ratio) if has_se else None

        # Point-wise linear projection
        self.ghost2 = GhostModule(mid_chs, out_chs, relu=False)

        # shortcut
        if in_chs == out_chs and self.stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_chs, in_chs, dw_kernel_size, stride=stride,
                    padding=(dw_kernel_size - 1) // 2, groups=in_chs, bias=False),
                nn.BatchNorm2d(in_chs),
                nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_chs),
            )

    def forward(self, x):
        residual = x

        # 1st ghost module (expansion)
        x = self.ghost1(x)

        # Depth-wise convolution
        if self.conv_dw is not None:
            x = self.conv_dw(x)
            x = self.bn_dw(x)

        # Squeeze-and-excitation
        if self.se is not None:
            x = self.se(x)

        # 2nd ghost module (linear projection)
        x = self.ghost2(x)

        x += self.shortcut(residual)
        return x
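
# Illustrative shapes (assuming the class above): GhostBottleneck(16, 48, 24)
# keeps HxW and projects 16 -> 24 channels, with a DW-conv + 1x1 shortcut since
# in_chs != out_chs; GhostBottleneck(24, 72, 24, stride=2) halves HxW via the
# mid-block depthwise conv, paired with a matching strided shortcut.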
class GhostNet(nn.Module):
    def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32):
        super(GhostNet, self).__init__()
        # setting of inverted residual blocks
        assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
        self.cfgs = cfgs
        self.num_classes = num_classes
        self.dropout = dropout
        self.feature_info = []

        # building first layer
        stem_chs = make_divisible(16 * width, 4)
        self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False)
        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='conv_stem'))
        self.bn1 = nn.BatchNorm2d(stem_chs)
        self.act1 = nn.ReLU(inplace=True)
        prev_chs = stem_chs

        # building inverted residual blocks
        stages = nn.ModuleList([])
        block = GhostBottleneck
        stage_idx = 0
        net_stride = 2
        for cfg in self.cfgs:
            layers = []
            s = 1
            for k, exp_size, c, se_ratio, s in cfg:
                out_chs = make_divisible(c * width, 4)
                mid_chs = make_divisible(exp_size * width, 4)
                layers.append(block(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio))
                prev_chs = out_chs
            if s > 1:
                net_stride *= 2
                self.feature_info.append(dict(
                    num_chs=prev_chs, reduction=net_stride, module=f'blocks.{stage_idx}'))
            stages.append(nn.Sequential(*layers))
            stage_idx += 1

        out_chs = make_divisible(exp_size * width, 4)
        stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1)))
        self.pool_dim = prev_chs = out_chs

        self.blocks = nn.Sequential(*stages)

        # building last several layers
        self.num_features = out_chs = 1280
        self.global_pool = SelectAdaptivePool2d(pool_type='avg')
        self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
        self.act2 = nn.ReLU(inplace=True)
        self.classifier = Linear(out_chs, num_classes)

    def get_classifier(self):
        return self.classifier
    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        # cannot meaningfully change pooling of efficient head after creation
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        # the classifier consumes the conv_head output (num_features channels)
        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.blocks(x)
        x = self.global_pool(x)
        x = self.conv_head(x)
        x = self.act2(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        if not self.global_pool.is_identity():
            x = x.view(x.size(0), -1)
        if self.dropout > 0.:
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.classifier(x)
        return x
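
# Note on the "efficient head": global_pool runs inside forward_features, ahead
# of conv_head, so a 224x224 input yields (N, 1280, 1, 1) features by default;
# only a pass-through pool ('') preserves the 7x7 spatial layout.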


def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs):
    """
    Constructs a GhostNet model
    """
    cfgs = [
        # k: dw kernel size, t: expansion channels, c: output channels,
        # SE: se_ratio, s: stride
        # stage1
        [[3, 16, 16, 0, 1]],
        # stage2
        [[3, 48, 24, 0, 2]],
        [[3, 72, 24, 0, 1]],
        # stage3
        [[5, 72, 40, 0.25, 2]],
        [[5, 120, 40, 0.25, 1]],
        # stage4
        [[3, 240, 80, 0, 2]],
        [[3, 200, 80, 0, 1],
         [3, 184, 80, 0, 1],
         [3, 184, 80, 0, 1],
         [3, 480, 112, 0.25, 1],
         [3, 672, 112, 0.25, 1]
        ],
        # stage5
        [[5, 672, 160, 0.25, 2]],
        [[5, 960, 160, 0, 1],
         [5, 960, 160, 0.25, 1],
         [5, 960, 160, 0, 1],
         [5, 960, 160, 0.25, 1]
        ]
    ]
    model_kwargs = dict(
        cfgs=cfgs,
        width=width,
        **kwargs,
    )
    return build_model_with_cfg(
        GhostNet, variant, pretrained,
        default_cfg=default_cfgs[variant],
        feature_cfg=dict(flatten_sequential=True),
        **model_kwargs)


@register_model
def ghostnet_050(pretrained=False, **kwargs):
    """ GhostNet-0.5x """
    model = _create_ghostnet('ghostnet_050', width=0.5, pretrained=pretrained, **kwargs)
    return model


@register_model
def ghostnet_100(pretrained=False, **kwargs):
    """ GhostNet-1.0x """
    model = _create_ghostnet('ghostnet_100', width=1.0, pretrained=pretrained, **kwargs)
    return model


@register_model
def ghostnet_130(pretrained=False, **kwargs):
    """ GhostNet-1.3x """
    model = _create_ghostnet('ghostnet_130', width=1.3, pretrained=pretrained, **kwargs)
    return model
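
As a usage sketch (illustrative; of the three variants only ghostnet_100 has a pretrained URL in default_cfgs above):

    import torch
    import timm

    # classification
    model = timm.create_model('ghostnet_100', pretrained=True).eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])

    # feature extraction, wired via feature_cfg=dict(flatten_sequential=True)
    feat_model = timm.create_model('ghostnet_100', features_only=True)
    for f in feat_model(torch.randn(1, 3, 224, 224)):
        print(f.shape)  # one map per feature_info entry: strides 2, 4, 8, 16, 32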
