Skip to content

Commit b7fc9d3

Browse files
committed
Add GhostNetV2
1 parent c6f6de3 commit b7fc9d3

File tree

6 files changed

+962
-0
lines changed

6 files changed

+962
-0
lines changed

ghostnetv2_pytorch/README.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# GhostNetV2: Enhance Cheap Operation with Long-Range Attention
2+
3+
Code for our NeurIPS 2022 (Spotlight) paper, [GhostNetV2: Enhance Cheap Operation with Long-Range Attention](https://openreview.net/pdf/6db544c65bbd0fa7d7349508454a433c112470e2.pdf). Light-weight convolutional neural networks (CNNs) are specially designed for applications on mobile devices with faster inference speed. The convolutional operation can only capture local information in a window region, which prevents performance from being further improved. Introducing self-attention into convolution can capture global information well, but it will largely encumber the actual speed. In this paper, we propose a hardware-friendly attention mechanism (dubbed DFC attention) and then present a new GhostNetV2 architecture for mobile applications. The proposed DFC attention is constructed based on fully-connected layers, which can not only execute fast on common hardware but also capture the dependence between long-range pixels. We further revisit the expressiveness bottleneck in previous GhostNet and propose to enhance expanded features produced by cheap operations with DFC attention, so that a GhostNetV2 block can aggregate local and long-range information simultaneously. Extensive experiments demonstrate the superiority of GhostNetV2 over existing architectures. For example, it achieves 75.3% top-1 accuracy on ImageNet with 167M FLOPs, significantly suppressing GhostNetV1 (74.5%) with a similar computational cost.
4+
5+
The information flow of DFC attention:
6+
7+
<p align="center">
8+
<img src="fig/dfc.PNG" width="800">
9+
</p>
10+
11+
12+
The diagrams of blocks in GhostNetV1 and GhostNetV2:
13+
14+
<p align="center">
15+
<img src="fig/ghostnetv2.PNG" width="800">
16+
</p>
17+
18+
19+
20+
## Requirements
21+
22+
- python 3
23+
- pytorch == 1.7.1
24+
- torchvision == 0.8.2
25+
- timm==0.3.2
26+
27+
## Usage
28+
29+
30+
Run ghostnetv2/train.py` to train models. For example, you can run the following code to train GhostNetV2 on ImageNet dataset.
31+
32+
```shell
33+
python -m torch.distributed.launch --nproc_per_node=8 train.py path_to_imagenet/ --output /cache/models/ --model ghostnetv2 -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --lr .064 --lr-noise 0.42 0.9 --width 1.0
34+
```
35+
## Results
36+
37+
<p align="center">
38+
<img src="fig/imagenet.PNG" width="900">
39+
</p>

ghostnetv2_pytorch/fig/dfc.PNG

261 KB
Loading

ghostnetv2_pytorch/fig/ghostnetv2.PNG

129 KB
Loading

ghostnetv2_pytorch/fig/imagenet.PNG

95.5 KB
Loading
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
# 2020.11.06-Changed for building GhostNetV2
2+
# Huawei Technologies Co., Ltd. <[email protected]>
3+
"""
4+
Creates a GhostNet Model as defined in:
5+
GhostNet: More Features from Cheap Operations By Kai Han, Yunhe Wang, Qi Tian, Jianyuan Guo, Chunjing Xu, Chang Xu.
6+
https://arxiv.org/abs/1911.11907
7+
Modified from https://github.com/d-li14/mobilenetv3.pytorch and https://github.com/rwightman/pytorch-image-models
8+
"""
9+
import torch
10+
import torch.nn as nn
11+
import torch.nn.functional as F
12+
import math
13+
14+
from timm.models.registry import register_model
15+
16+
def _make_divisible(v, divisor, min_value=None):
17+
"""
18+
This function is taken from the original tf repo.
19+
It ensures that all layers have a channel number that is divisible by 8
20+
It can be seen here:
21+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
22+
"""
23+
if min_value is None:
24+
min_value = divisor
25+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
26+
# Make sure that round down does not go down by more than 10%.
27+
if new_v < 0.9 * v:
28+
new_v += divisor
29+
return new_v
30+
31+
def hard_sigmoid(x, inplace: bool = False):
32+
if inplace:
33+
return x.add_(3.).clamp_(0., 6.).div_(6.)
34+
else:
35+
return F.relu6(x + 3.) / 6.
36+
37+
class SqueezeExcite(nn.Module):
38+
def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
39+
act_layer=nn.ReLU, gate_fn=hard_sigmoid, divisor=4, **_):
40+
super(SqueezeExcite, self).__init__()
41+
self.gate_fn = gate_fn
42+
reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
43+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
44+
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
45+
self.act1 = act_layer(inplace=True)
46+
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
47+
48+
def forward(self, x):
49+
x_se = self.avg_pool(x)
50+
x_se = self.conv_reduce(x_se)
51+
x_se = self.act1(x_se)
52+
x_se = self.conv_expand(x_se)
53+
x = x * self.gate_fn(x_se)
54+
return x
55+
56+
class ConvBnAct(nn.Module):
57+
def __init__(self, in_chs, out_chs, kernel_size,
58+
stride=1, act_layer=nn.ReLU):
59+
super(ConvBnAct, self).__init__()
60+
self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False)
61+
self.bn1 = nn.BatchNorm2d(out_chs)
62+
self.act1 = act_layer(inplace=True)
63+
64+
def forward(self, x):
65+
x = self.conv(x)
66+
x = self.bn1(x)
67+
x = self.act1(x)
68+
return x
69+
70+
class GhostModuleV2(nn.Module):
71+
def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True,mode=None,args=None):
72+
super(GhostModuleV2, self).__init__()
73+
self.mode=mode
74+
self.gate_fn=nn.Sigmoid()
75+
76+
if self.mode in ['original']:
77+
self.oup = oup
78+
init_channels = math.ceil(oup / ratio)
79+
new_channels = init_channels*(ratio-1)
80+
self.primary_conv = nn.Sequential(
81+
nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False),
82+
nn.BatchNorm2d(init_channels),
83+
nn.ReLU(inplace=True) if relu else nn.Sequential(),
84+
)
85+
self.cheap_operation = nn.Sequential(
86+
nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False),
87+
nn.BatchNorm2d(new_channels),
88+
nn.ReLU(inplace=True) if relu else nn.Sequential(),
89+
)
90+
elif self.mode in ['attn']:
91+
self.oup = oup
92+
init_channels = math.ceil(oup / ratio)
93+
new_channels = init_channels*(ratio-1)
94+
self.primary_conv = nn.Sequential(
95+
nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False),
96+
nn.BatchNorm2d(init_channels),
97+
nn.ReLU(inplace=True) if relu else nn.Sequential(),
98+
)
99+
self.cheap_operation = nn.Sequential(
100+
nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False),
101+
nn.BatchNorm2d(new_channels),
102+
nn.ReLU(inplace=True) if relu else nn.Sequential(),
103+
)
104+
self.short_conv = nn.Sequential(
105+
nn.Conv2d(inp, oup, kernel_size, stride, kernel_size//2, bias=False),
106+
nn.BatchNorm2d(oup),
107+
nn.Conv2d(oup, oup, kernel_size=(1,5), stride=1, padding=(0,2), groups=oup,bias=False),
108+
nn.BatchNorm2d(oup),
109+
nn.Conv2d(oup, oup, kernel_size=(5,1), stride=1, padding=(2,0), groups=oup,bias=False),
110+
nn.BatchNorm2d(oup),
111+
)
112+
113+
def forward(self, x):
114+
if self.mode in ['original']:
115+
x1 = self.primary_conv(x)
116+
x2 = self.cheap_operation(x1)
117+
out = torch.cat([x1,x2], dim=1)
118+
return out[:,:self.oup,:,:]
119+
elif self.mode in ['attn']:
120+
res=self.short_conv(F.avg_pool2d(x,kernel_size=2,stride=2))
121+
x1 = self.primary_conv(x)
122+
x2 = self.cheap_operation(x1)
123+
out = torch.cat([x1,x2], dim=1)
124+
return out[:,:self.oup,:,:]*F.interpolate(self.gate_fn(res),size=out.shape[-1],mode='nearest')
125+
126+
127+
class GhostBottleneckV2(nn.Module):
128+
129+
def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3,
130+
stride=1, act_layer=nn.ReLU, se_ratio=0.,layer_id=None,args=None):
131+
super(GhostBottleneckV2, self).__init__()
132+
has_se = se_ratio is not None and se_ratio > 0.
133+
self.stride = stride
134+
135+
# Point-wise expansion
136+
if layer_id<=1:
137+
self.ghost1 = GhostModuleV2(in_chs, mid_chs, relu=True,mode='original',args=args)
138+
else:
139+
self.ghost1 = GhostModuleV2(in_chs, mid_chs, relu=True,mode='attn',args=args)
140+
141+
# Depth-wise convolution
142+
if self.stride > 1:
143+
self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride=stride,
144+
padding=(dw_kernel_size-1)//2,groups=mid_chs, bias=False)
145+
self.bn_dw = nn.BatchNorm2d(mid_chs)
146+
147+
# Squeeze-and-excitation
148+
if has_se:
149+
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio)
150+
else:
151+
self.se = None
152+
153+
self.ghost2 = GhostModuleV2(mid_chs, out_chs, relu=False,mode='original',args=args)
154+
155+
# shortcut
156+
if (in_chs == out_chs and self.stride == 1):
157+
self.shortcut = nn.Sequential()
158+
else:
159+
self.shortcut = nn.Sequential(
160+
nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride=stride,
161+
padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False),
162+
nn.BatchNorm2d(in_chs),
163+
nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
164+
nn.BatchNorm2d(out_chs),
165+
)
166+
def forward(self, x):
167+
residual = x
168+
x = self.ghost1(x)
169+
if self.stride > 1:
170+
x = self.conv_dw(x)
171+
x = self.bn_dw(x)
172+
if self.se is not None:
173+
x = self.se(x)
174+
x = self.ghost2(x)
175+
x += self.shortcut(residual)
176+
return x
177+
178+
179+
class GhostNetV2(nn.Module):
180+
def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2,block=GhostBottleneckV2,args=None):
181+
super(GhostNetV2, self).__init__()
182+
self.cfgs = cfgs
183+
self.dropout = dropout
184+
185+
# building first layer
186+
output_channel = _make_divisible(16 * width, 4)
187+
self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False)
188+
self.bn1 = nn.BatchNorm2d(output_channel)
189+
self.act1 = nn.ReLU(inplace=True)
190+
input_channel = output_channel
191+
192+
# building inverted residual blocks
193+
stages = []
194+
#block = block
195+
layer_id=0
196+
for cfg in self.cfgs:
197+
layers = []
198+
for k, exp_size, c, se_ratio, s in cfg:
199+
output_channel = _make_divisible(c * width, 4)
200+
hidden_channel = _make_divisible(exp_size * width, 4)
201+
if block==GhostBottleneckV2:
202+
layers.append(block(input_channel, hidden_channel, output_channel, k, s,
203+
se_ratio=se_ratio,layer_id=layer_id,args=args))
204+
input_channel = output_channel
205+
layer_id+=1
206+
stages.append(nn.Sequential(*layers))
207+
208+
output_channel = _make_divisible(exp_size * width, 4)
209+
stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1)))
210+
input_channel = output_channel
211+
212+
self.blocks = nn.Sequential(*stages)
213+
214+
# building last several layers
215+
output_channel = 1280
216+
self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
217+
self.conv_head = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=True)
218+
self.act2 = nn.ReLU(inplace=True)
219+
self.classifier = nn.Linear(output_channel, num_classes)
220+
221+
def forward(self, x):
222+
x = self.conv_stem(x)
223+
x = self.bn1(x)
224+
x = self.act1(x)
225+
x = self.blocks(x)
226+
x = self.global_pool(x)
227+
x = self.conv_head(x)
228+
x = self.act2(x)
229+
x = x.view(x.size(0), -1)
230+
if self.dropout > 0.:
231+
x = F.dropout(x, p=self.dropout, training=self.training)
232+
x = self.classifier(x)
233+
return x
234+
235+
@register_model
236+
def ghostnetv2(**kwargs):
237+
cfgs = [
238+
# k, t, c, SE, s
239+
[[3, 16, 16, 0, 1]],
240+
[[3, 48, 24, 0, 2]],
241+
[[3, 72, 24, 0, 1]],
242+
[[5, 72, 40, 0.25, 2]],
243+
[[5, 120, 40, 0.25, 1]],
244+
[[3, 240, 80, 0, 2]],
245+
[[3, 200, 80, 0, 1],
246+
[3, 184, 80, 0, 1],
247+
[3, 184, 80, 0, 1],
248+
[3, 480, 112, 0.25, 1],
249+
[3, 672, 112, 0.25, 1]
250+
],
251+
[[5, 672, 160, 0.25, 2]],
252+
[[5, 960, 160, 0, 1],
253+
[5, 960, 160, 0.25, 1],
254+
[5, 960, 160, 0, 1],
255+
[5, 960, 160, 0.25, 1]
256+
]
257+
]
258+
return GhostNetV2(cfgs, num_classes=kwargs['num_classes'],
259+
width=kwargs['width'],
260+
dropout=kwargs['dropout'],
261+
args=kwargs['args'])

0 commit comments

Comments
 (0)