Skip to content

Commit eaffc5a

Browse files
committed
Post-training quantization.
1 parent bf7217e commit eaffc5a

File tree

9 files changed

+445
-0
lines changed

9 files changed

+445
-0
lines changed

post_quant/README.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Usage
2+
3+
```python
4+
import torch
5+
import torchvision
6+
import torchvision.datasets as datasets
7+
import torchvision.transforms.transforms as transforms
8+
from post_quant.fake_quantization import fake_quant, load_fake_quant_model
9+
10+
model = torchvision.models.resnet50(True)
11+
model.eval()
12+
13+
db = datasets.ImageFolder(
14+
"ILSVRC2012_img_val",
15+
transforms.Compose([
16+
transforms.Resize(256),
17+
transforms.CenterCrop(224),
18+
transforms.ToTensor(),
19+
transforms.Normalize(
20+
mean=[0.485, 0.456, 0.406],
21+
std=[0.229, 0.224, 0.225]
22+
),
23+
]))
24+
dataset = torch.utils.data.DataLoader(
25+
db,
26+
batch_size=128,
27+
num_workers=8,
28+
shuffle=False,
29+
pin_memory=True)
30+
31+
# Quantize model
32+
q_model = fake_quant(model, dataset)
33+
34+
# Save model with scale & zero point:
35+
torch.save(model.state_dict(), 'model.quant')
36+
37+
# Reload model:
38+
m = load_fake_quant_model(torchvision.models.resnet50(), 'model.quant')
39+
```
40+
41+
42+
# TODO
43+
- [ ] Symmetric quantization
44+
- [ ] Channel-wise weight quantization
45+
- [ ] More sophisticated activation range calibration

post_quant/__init__.py

Whitespace-only changes.

post_quant/accuracy_test.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Borrowed from examples
2+
import torch
3+
import time
4+
5+
6+
class AverageMeter(object):
    """Keep the latest value plus a count-weighted running mean of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest observation, weighted by ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
22+
23+
24+
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: score tensor; the top classes are searched along dim 1 and
            dim 0 is matched element-wise against ``target``.
        target: tensor of ground-truth class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list holding one single-element float tensor per entry of ``topk``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` keeps the transposed (non-contiguous) layout of
            # ``pred``; ``view(-1)`` raises on such a slice for k > 1, while
            # ``reshape`` falls back to a copy when a view is impossible.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))

        return res
40+
41+
42+
def validate(val_loader, model,
             shut_up=False,
             criterion=None, half=False):
    """Run ``model`` over ``val_loader`` and report top-1 / top-5 accuracy.

    Args:
        val_loader: iterable yielding ``(input, target)`` batches; ``len()``
            of it is used in the progress line.
        model: module to evaluate; it is switched to eval mode.
        shut_up: when True, suppress all printed progress and the summary.
        criterion: optional callable ``criterion(output, target)`` whose
            result is averaged as the loss; the loss is recorded as 0 when
            it is omitted.
        half: when True, cast each input batch to half precision.

    Returns:
        The running top-1 accuracy (``top1.avg``); this is 0 when the loader
        yields no batches.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # Feed batches to the device that actually holds the model. Moving them
    # to CUDA merely because CUDA exists breaks a model left on the CPU.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        # Parameter-free model: keep the historical "CUDA if present" choice.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            if half:
                images = images.half()

            # compute output
            output = model(images)
            loss = criterion(output, target) if criterion is not None else None

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item() if loss is not None else 0,
                          images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0 and not shut_up:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, top1=top1, top5=top5))

        if not shut_up:
            print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
                  .format(top1=top1, top5=top5))

    return top1.avg
89+

post_quant/activation.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import torch
2+
3+
from post_quant.accuracy_test import validate
4+
from post_quant.common import _weight_quantize_range, dequantize, quantize
5+
6+
7+
class ActivationMonitor(object):
    """Forward hook that tracks a module's output range and derives its
    quantization scale / zero point.

    The hooked module must already carry ``output_max`` / ``output_min``
    attributes (``None`` before the first batch); the hook updates them and
    writes ``output_scale`` / ``output_zero_point`` after every call.

    Args:
        bits: bit width passed to the range-to-(scale, zero point) helper.
        smooth: when True, blend each new batch extreme into the stored one
            with weights 0.9 (old) / 0.1 (new); when False, keep the overall
            maximum and minimum seen so far.
    """

    def __init__(self, bits=8, smooth=True):
        self.bits = bits
        self.smooth = smooth

    def __call__(self, m, _, output_):
        o_max = output_.max().item()
        o_min = output_.min().item()
        if m.output_max is None:
            # First batch: seed the tracked range.
            m.output_max = torch.tensor(o_max)
            m.output_min = torch.tensor(o_min)
        elif self.smooth:
            m.output_max = m.output_max * 0.9 + o_max * 0.1
            m.output_min = m.output_min * 0.9 + o_min * 0.1
        else:
            # Store tensors, not bare floats: the ``.item()`` calls below
            # (and the next comparison) rely on the tensor API.
            if m.output_max < o_max:
                m.output_max = torch.tensor(o_max)
            if m.output_min > o_min:
                m.output_min = torch.tensor(o_min)
        low = m.output_min.item()
        high = m.output_max.item()
        s, z = _weight_quantize_range(low, high, bits=self.bits)
        m.output_scale = torch.tensor(s)
        m.output_zero_point = torch.tensor(z)
32+
33+
34+
def register_activation_monitor(
        net,
        func):
    """Hook ``func`` onto every sub-module of ``net`` accepted by
    ``need_monitor`` and return the resulting hook handles as a list."""
    return [
        hook_monitor(sub, func)
        for _, sub in net.named_modules()
        if need_monitor(sub)
    ]
43+
44+
45+
def fake_quant_activation_module(net):
    """Apply ``replace_forward_op`` to every sub-module of ``net`` that
    ``need_monitor`` accepts."""
    monitored = (sub for _, sub in net.named_modules() if need_monitor(sub))
    for sub in monitored:
        replace_forward_op(sub)
49+
50+
51+
def need_monitor(module):
    """Return True when ``module`` is a Conv2d, BatchNorm2d or Linear layer."""
    monitored_types = (
        torch.nn.Conv2d,
        torch.nn.BatchNorm2d,
        torch.nn.Linear,
    )
    return isinstance(module, monitored_types)
57+
58+
59+
def hook_monitor(m, func):
    """Prepare ``m`` for range monitoring and attach ``func`` as a forward hook.

    Registers empty ``output_scale`` / ``output_zero_point`` buffers, resets
    the ``output_max`` / ``output_min`` attributes to ``None`` and returns
    the handle produced by ``register_forward_hook``.
    """
    for buffer_name in ('output_scale', 'output_zero_point'):
        m.register_buffer(buffer_name, None)
    m.output_max = m.output_min = None
    return m.register_forward_hook(func)
65+
66+
67+
# Wrap the module's forward so that its output is quantized and immediately
# dequantized with the module's stored output scale / zero point.
def replace_forward_op(module):
    # The scale and zero point are read once, here; later changes to the
    # buffers do not affect the wrapped forward.
    scale = module.output_scale.item()
    zero_point = module.output_zero_point.item()
    original_forward = module.forward

    def quant_forward(*args):
        raw = original_forward(*args)
        # NOTE(review): quantize() runs with its default bit width here.
        codes = quantize(raw, scale, zero_point)
        return dequantize(codes, scale, zero_point)

    module.forward = quant_forward
78+
79+
80+
def calibrate_activation_range(m, db, bits):
    """Record activation ranges of ``m`` by running it over ``db``.

    An ``ActivationMonitor`` configured with ``bits`` is hooked onto the
    monitored sub-modules, ``validate(db, m)`` drives the forward passes,
    and the hooks are removed afterwards.

    Args:
        m: model to calibrate.
        db: data loader handed to ``validate``.
        bits: bit width used by the monitor.
    """
    hooks = register_activation_monitor(m, ActivationMonitor(bits=bits))
    try:
        validate(db, m)
    finally:
        # Always detach the monitors, even if the calibration pass fails,
        # so a later forward pass does not keep updating the ranges.
        for h in hooks:
            h.remove()

post_quant/common.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import torch
2+
3+
4+
def _weight_quantize_range(min_w, max_w, bits):
5+
level = 2 ** bits - 1
6+
scale = (max_w - min_w) / level
7+
zero_point = round((0.0 - min_w) / scale)
8+
if max_w < 0:
9+
zero_point = level
10+
if min_w > 0:
11+
zero_point = 0
12+
return scale, zero_point
13+
14+
15+
def dequantize(weight, S, Z):
    """Map quantized values back to real values as ``S * (weight - Z)``."""
    shifted = weight - Z
    return S * shifted
17+
18+
19+
def quantize(weight, S, Z, bits=8):
    """Quantize ``weight``: divide by ``S``, round, add ``Z`` and clamp the
    result to ``[0, 2**bits - 1]``."""
    top_level = 2 ** bits - 1
    codes = (weight / S).round() + Z
    return torch.clamp(codes, 0, top_level)
21+
22+
23+
def quantize_tensor(tensor, bits):
    """Fake-quantize ``tensor`` with parameters derived from its own range.

    Returns a ``(dequantized_tensor, scale, zero_point)`` triple.
    """
    scale, zero_point = _weight_quantize_parameter(tensor, bits)
    codes = quantize(tensor, scale, zero_point, bits)
    restored = dequantize(codes, scale, zero_point)
    return restored, scale, zero_point
26+
27+
28+
def _weight_quantize_parameter(weight, bits=8):
    """Return the (scale, zero_point) computed from the min and max of
    ``weight``."""
    lowest = weight.min().item()
    highest = weight.max().item()
    return _weight_quantize_range(lowest, highest, bits)
30+
31+
32+
def register_quant_params(m):
    """Register zero-valued quantization buffers on each Conv2d / Linear in ``m``.

    For each of the ``weight``, ``bias`` and ``output`` groups a float
    ``<group>_scale`` buffer and an integer ``<group>_zero_point`` buffer
    are added, in that order.
    """
    groups = ('weight', 'bias', 'output')
    with torch.no_grad():
        for _, sub in m.named_modules():
            if not isinstance(sub, (torch.nn.Conv2d, torch.nn.Linear)):
                continue
            for group in groups:
                sub.register_buffer(group + '_scale', torch.tensor(0.0))
                sub.register_buffer(group + '_zero_point', torch.tensor(0))

post_quant/fake_quantization.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import torch
2+
from .common import register_quant_params
3+
from .fusion import fuse_module
4+
from .weights import quantize_module
5+
from .activation import fake_quant_activation_module, calibrate_activation_range
6+
7+
8+
def load_fake_quant_model(m, f):
    """Rebuild a fake-quantized model from a saved state dict.

    ``m`` is put in eval mode, fused, given quantization buffers, populated
    from the checkpoint and finally has its monitored forwards wrapped.

    Args:
        m: freshly constructed model with the same architecture that was
            quantized; it is modified in place.
        f: path or file-like object accepted by ``torch.load``.

    Returns:
        The same ``m`` instance.
    """
    # NOTE(security): torch.load unpickles its input; only load checkpoints
    # from trusted sources.
    # map_location='cpu' lets a checkpoint saved from a CUDA model load on a
    # CPU-only host; load_state_dict then copies onto m's own devices.
    state_dict = torch.load(f, map_location='cpu')
    m.eval()
    fuse_module(m)
    # Buffers must exist before load_state_dict can fill them.
    register_quant_params(m)
    m.load_state_dict(state_dict)
    fake_quant_activation_module(m)
    return m
16+
17+
18+
def fake_quant(m, db, bits=8):
    """Fake-quantize ``m`` in place, calibrating on ``db``, and return it.

    Steps, in order: switch to eval mode, fuse the model, calibrate the
    activation ranges over ``db``, quantize the module with ``bits`` and
    wrap the monitored forwards.
    """
    m.eval()
    fuse_module(m)

    # Calibration comes before the forwards are wrapped: the wrappers read
    # the scale / zero point that calibration stores on each module.
    calibrate_activation_range(m, db, bits)
    quantize_module(m, bits)

    fake_quant_activation_module(m)
    return m

post_quant/fusion.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import torch
2+
import torch.nn as nn
3+
from utils.modules import DummyModule
4+
5+
6+
def fuse(conv, bn):
    """Fold a BatchNorm2d into the Conv2d that feeds it.

    The batch norm's running statistics (and affine parameters, if present)
    are folded into a new convolution's weight and bias, so the result
    matches ``bn(conv(x))`` when ``bn`` is in eval mode.

    Args:
        conv: the ``nn.Conv2d`` whose output feeds ``bn``.
        bn: the ``nn.BatchNorm2d`` to absorb.

    Returns:
        A new ``nn.Conv2d`` (always with a bias) with the same geometry as
        ``conv``.

    Raises:
        ValueError: if ``bn`` has no running statistics.
    """
    if bn.running_mean is None or bn.running_var is None:
        raise ValueError(
            'cannot fuse a BatchNorm2d without running statistics')

    with torch.no_grad():
        mean = bn.running_mean
        var_sqrt = torch.sqrt(bn.running_var + bn.eps)

        # With affine=False the batch norm has no weight/bias; that is the
        # same as a scale of one and a shift of zero.
        gamma = bn.weight if bn.weight is not None else torch.ones_like(mean)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(mean)

        if conv.bias is not None:
            b = conv.bias
        else:
            b = torch.zeros_like(mean)

        # Per-output-channel rescale of the kernel, then the matching bias.
        w = conv.weight * (gamma / var_sqrt).reshape(
            [conv.out_channels, 1, 1, 1])
        b = (b - mean) / var_sqrt * gamma + beta

    fused_conv = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        conv.stride,
        conv.padding,
        conv.dilation,
        conv.groups,
        bias=True,
        padding_mode=conv.padding_mode
    )
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv
36+
37+
38+
def fuse_module(m):
    """Recursively fuse each Conv2d with the BatchNorm2d registered right
    after it among the children of ``m`` (in place).

    The convolution slot is replaced by the fused convolution and the batch
    norm slot by a ``DummyModule``. Children that are neither a Conv2d nor
    a fusable BatchNorm2d are processed recursively.

    Args:
        m: module whose children are rewritten in place.
    """
    # Snapshot the children: ``m._modules`` is modified during the walk.
    children = list(m.named_children())
    pending_conv = None
    pending_name = None

    for name, child in children:
        if isinstance(child, nn.BatchNorm2d) and pending_conv is not None:
            m._modules[pending_name] = fuse(pending_conv, child)
            m._modules[name] = DummyModule()
            pending_conv = None
        elif isinstance(child, nn.Conv2d):
            pending_conv = child
            pending_name = name
        else:
            # Another module sits between the convolution and any later
            # batch norm, so they are no longer adjacent; fusing across it
            # would silently change the network's output.
            pending_conv = None
            fuse_module(child)
54+
55+
56+
def validate(net, input_, cuda=True):
    """Return the largest absolute difference between the output of ``net``
    before and after ``fuse_module`` is applied to it.

    ``net`` is switched to eval mode and fused in place; with ``cuda`` set,
    both ``net`` and ``input_`` are moved to the GPU first.
    """
    net.eval()
    if cuda:
        input_ = input_.cuda()
        net.cuda()

    def run_once():
        # One forward pass, waiting for the GPU to finish when it is used.
        result = net(input_)
        if cuda:
            torch.cuda.synchronize()
        return result

    before = run_once()
    fuse_module(net)
    after = run_once()
    return (before - after).abs().max().item()
75+
76+
77+
if __name__ == '__main__':
    # Script entry: print the maximum output deviation that fusing
    # introduces on a pretrained MobileNetV2 for one random batch (on GPU).
    import torchvision

    mbnet = torchvision.models.mobilenet_v2(True)
    mbnet.eval()
    sample_batch = torch.randn(32, 3, 224, 224)
    print(validate(mbnet, sample_batch, True))

0 commit comments

Comments
 (0)