
Commit cc39b08

andrewor14 and jcaip authored
Fix quantization tutorials (imports, syntax, and style) (#1772)
* Fix quantization tutorials (imports, syntax, and style)

  Summary: This commit fixes the quantization tutorials such that they can be run smoothly by the user.

  Test Plan: Ran the updated tutorials without problem.

  Reviewers: jerryzh168
  Subscribers: jerryzh168, supriyar
  ghstack-source-id: 196719d
  Pull Request resolved: #1763

* revert paths for wikitext version

* Fix broken url

Co-authored-by: Jesse Cai <[email protected]>
1 parent 9fb9f47 commit cc39b08

File tree

2 files changed: +77 -84 lines changed


advanced_source/static_quantization_tutorial.rst

+42 -46
@@ -20,19 +20,20 @@ We'll start by doing the necessary imports:

 .. code:: python

-    import numpy as np
-    import torch
-    import torch.nn as nn
-    import torchvision
-    from torch.utils.data import DataLoader
-    from torchvision import datasets
-    import torchvision.transforms as transforms
-    import os
-    import time
-    import sys
-    import torch.quantization
-
-    # # Setup warnings
+    import os
+    import sys
+    import time
+    import numpy as np
+
+    import torch
+    import torch.nn as nn
+    from torch.utils.data import DataLoader
+
+    import torchvision
+    from torchvision import datasets
+    import torchvision.transforms as transforms
+
+    # Set up warnings
     import warnings
     warnings.filterwarnings(
         action='ignore',
@@ -41,7 +42,7 @@ We'll start by doing the necessary imports:
     )
     warnings.filterwarnings(
         action='default',
-        module=r'torch.quantization'
+        module=r'torch.ao.quantization'
     )

     # Specify random seed for repeatable results
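
Both hunks above belong to the same rename: the eager-mode quantization APIs moved from ``torch.quantization`` to ``torch.ao.quantization``, with the old path kept as a backward-compatible alias. For code that must run on both older and newer PyTorch builds, a minimal compatibility sketch (hypothetical, not part of the tutorial):

.. code:: python

    try:
        # PyTorch 1.10+: canonical location under the `ao` namespace
        from torch.ao.quantization import QuantStub, DeQuantStub
    except ImportError:
        # older releases only expose the torch.quantization path
        from torch.quantization import QuantStub, DeQuantStub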
@@ -62,7 +63,7 @@ Note: this code is taken from

 .. code:: python

-    from torch.quantization import QuantStub, DeQuantStub
+    from torch.ao.quantization import QuantStub, DeQuantStub
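
``QuantStub`` and ``DeQuantStub`` mark where tensors cross the float/quantized boundary in eager mode. A minimal placement sketch, using a hypothetical ``TinyModel`` (the tutorial's MobileNetV2 follows the same pattern):

.. code:: python

    import torch
    import torch.nn as nn
    from torch.ao.quantization import QuantStub, DeQuantStub

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = QuantStub()      # observes/quantizes the float input
            self.conv = nn.Conv2d(3, 8, 3)
            self.relu = nn.ReLU()
            self.dequant = DeQuantStub()  # returns a float tensor to the caller

        def forward(self, x):
            x = self.quant(x)
            x = self.relu(self.conv(x))
            return self.dequant(x)

    # before conversion the stubs are pass-through, so the model still runs in float
    out = TinyModel()(torch.randn(1, 3, 32, 32))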

     def _make_divisible(v, divisor, min_value=None):
         """
@@ -196,9 +197,7 @@ Note: this code is taken from
             nn.init.zeros_(m.bias)

     def forward(self, x):
-
         x = self.quant(x)
-
         x = self.features(x)
         x = x.mean([2, 3])
         x = self.classifier(x)
@@ -210,11 +209,11 @@ Note: this code is taken from
     def fuse_model(self):
         for m in self.modules():
             if type(m) == ConvBNReLU:
-                torch.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
+                torch.ao.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
             if type(m) == InvertedResidual:
                 for idx in range(len(m.conv)):
                     if type(m.conv[idx]) == nn.Conv2d:
-                        torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
+                        torch.ao.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)

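``fuse_modules`` folds adjacent Conv/BN/ReLU children into single fused modules, which both speeds up inference and lets quantization absorb the BN parameters. A standalone sketch of the call pattern used above (the ``nn.Sequential`` is a stand-in for the tutorial's ``ConvBNReLU`` block):

.. code:: python

    import torch.nn as nn
    from torch.ao.quantization import fuse_modules

    # indices '0', '1', '2' name the child modules to fuse
    block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    block.eval()                      # post-training fusion requires eval mode
    fused = fuse_modules(block, ['0', '1', '2'])
    print(fused)                      # ConvReLU2d at index 0, Identity at 1 and 2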
 2. Helper functions
 -------------------
@@ -314,25 +313,22 @@ in this data. These functions mostly come from
 .. code:: python

     def prepare_data_loaders(data_path):
-
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
         dataset = torchvision.datasets.ImageNet(
-            data_path, split="train",
-            transforms.Compose([
-                transforms.RandomResizedCrop(224),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]))
+            data_path, split="train", transform=transforms.Compose([
+                transforms.RandomResizedCrop(224),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                normalize,
+            ]))
         dataset_test = torchvision.datasets.ImageNet(
-            data_path, split="val",
-            transforms.Compose([
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ]))
+            data_path, split="val", transform=transforms.Compose([
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                normalize,
+            ]))

         train_sampler = torch.utils.data.RandomSampler(dataset)
         test_sampler = torch.utils.data.SequentialSampler(dataset_test)
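
The substantive fix in this hunk is that ``transforms.Compose(...)`` was being passed positionally to ``torchvision.datasets.ImageNet``, where it does not bind to the ``transform`` parameter; it must be passed by keyword. A hypothetical illustration (the dataset instantiation is commented out because it needs the ImageNet data on disk):

.. code:: python

    from torchvision import datasets, transforms

    t = transforms.Compose([transforms.ToTensor()])
    # broken: the Compose does not land on `transform` when passed positionally
    # dataset = datasets.ImageNet(data_path, "train", t)
    # fixed: bind it explicitly by keyword
    # dataset = datasets.ImageNet(data_path, split="train", transform=t)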
@@ -348,8 +344,8 @@ in this data. These functions mostly come from
         return data_loader, data_loader_test


-Next, we'll load in the pre-trained MobileNetV2 model. We provide the URL to download the data from in ``torchvision``
-`here <https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenet.py#L9>`_.
+Next, we'll load in the pre-trained MobileNetV2 model. We provide the URL to download the model
+`here <https://download.pytorch.org/models/mobilenet_v2-b0353104.pth>`_.

 .. code:: python

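Instead of downloading the checkpoint by hand, the file at the new URL can be fetched programmatically. A sketch, assuming network access and a pre-existing ``data`` directory (the target filename here is hypothetical):

.. code:: python

    import torch

    url = 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'
    state_dict = torch.hub.load_state_dict_from_url(url)   # cached under ~/.cache/torch
    torch.save(state_dict, 'data/mobilenet_pretrained_float.pth')  # hypothetical local path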
@@ -424,9 +420,9 @@ values to floats - and then back to ints - between every operation, resulting in

     # Specify quantization configuration
     # Start with simple min/max range estimation and per-tensor quantization of weights
-    myModel.qconfig = torch.quantization.default_qconfig
+    myModel.qconfig = torch.ao.quantization.default_qconfig
     print(myModel.qconfig)
-    torch.quantization.prepare(myModel, inplace=True)
+    torch.ao.quantization.prepare(myModel, inplace=True)

     # Calibrate first
     print('Post Training Quantization Prepare: Inserting Observers')
@@ -437,7 +433,7 @@ values to floats - and then back to ints - between every operation, resulting in
     print('Post Training Quantization: Calibration done')

     # Convert to quantized model
-    torch.quantization.convert(myModel, inplace=True)
+    torch.ao.quantization.convert(myModel, inplace=True)
     print('Post Training Quantization: Convert done')
     print('\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n',myModel.features[1].conv)
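
The prepare/calibrate/convert sequence above is the entire eager-mode post-training quantization workflow. A self-contained sketch on a toy model (``M`` is hypothetical; the tutorial applies the same steps to MobileNetV2):

.. code:: python

    import torch
    import torch.nn as nn
    from torch.ao.quantization import (QuantStub, DeQuantStub,
                                       default_qconfig, prepare, convert)

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.quant, self.fc, self.dequant = QuantStub(), nn.Linear(4, 2), DeQuantStub()

        def forward(self, x):
            return self.dequant(self.fc(self.quant(x)))

    m = M().eval()
    m.qconfig = default_qconfig          # min/max observers, per-tensor weights
    prepare(m, inplace=True)             # insert observers
    m(torch.randn(8, 4))                 # calibration: observers record ranges
    convert(m, inplace=True)             # swap in quantized modules
    print(m.fc)                          # quantized Linear with scale and zero_point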
@@ -462,12 +458,12 @@ quantizing for x86 architectures. This configuration does the following:
     per_channel_quantized_model = load_model(saved_model_dir + float_model_file)
     per_channel_quantized_model.eval()
     per_channel_quantized_model.fuse_model()
-    per_channel_quantized_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
+    per_channel_quantized_model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
     print(per_channel_quantized_model.qconfig)

-    torch.quantization.prepare(per_channel_quantized_model, inplace=True)
+    torch.ao.quantization.prepare(per_channel_quantized_model, inplace=True)
     evaluate(per_channel_quantized_model, criterion, data_loader, num_calibration_batches)
-    torch.quantization.convert(per_channel_quantized_model, inplace=True)
+    torch.ao.quantization.convert(per_channel_quantized_model, inplace=True)
     top1, top5 = evaluate(per_channel_quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
     print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
     torch.jit.save(torch.jit.script(per_channel_quantized_model), saved_model_dir + scripted_quantized_model_file)
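
To see concretely what the per-channel configuration changes, the two qconfigs can simply be printed side by side; a quick sketch:

.. code:: python

    from torch.ao.quantization import default_qconfig, get_default_qconfig

    print(default_qconfig)                # MinMax observers, per-tensor weight quantization
    print(get_default_qconfig('fbgemm'))  # histogram activation observer,
                                          # per-channel symmetric weight observer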
@@ -539,13 +535,13 @@ We fuse modules as before
     qat_model.fuse_model()

     optimizer = torch.optim.SGD(qat_model.parameters(), lr = 0.0001)
-    qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
+    qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')

 Finally, ``prepare_qat`` performs the "fake quantization", preparing the model for quantization-aware training

 .. code:: python

-    torch.quantization.prepare_qat(qat_model, inplace=True)
+    torch.ao.quantization.prepare_qat(qat_model, inplace=True)
     print('Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n',qat_model.features[1].conv)
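
``prepare_qat`` swaps modules for variants with ``FakeQuantize`` attached, so training sees quantization error. A minimal sketch on a hypothetical toy model (note that ``prepare_qat`` expects the model in training mode):

.. code:: python

    import torch
    import torch.nn as nn
    from torch.ao.quantization import (QuantStub, DeQuantStub,
                                       get_default_qat_qconfig, prepare_qat)

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.quant, self.fc, self.dequant = QuantStub(), nn.Linear(4, 2), DeQuantStub()

        def forward(self, x):
            return self.dequant(self.fc(self.quant(x)))

    qat_model = Tiny().train()               # prepare_qat expects train mode
    qat_model.qconfig = get_default_qat_qconfig('fbgemm')
    prepare_qat(qat_model, inplace=True)
    print(qat_model.fc)                      # QAT Linear with weight_fake_quant attached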

 Training a quantized model with high accuracy requires accurate modeling of numerics at
@@ -565,13 +561,13 @@ inference. For quantization aware training, therefore, we modify the training lo
     train_one_epoch(qat_model, criterion, optimizer, data_loader, torch.device('cpu'), num_train_batches)
     if nepoch > 3:
         # Freeze quantizer parameters
-        qat_model.apply(torch.quantization.disable_observer)
+        qat_model.apply(torch.ao.quantization.disable_observer)
     if nepoch > 2:
         # Freeze batch norm mean and variance estimates
         qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

     # Check the accuracy after each epoch
-    quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)
+    quantized_model = torch.ao.quantization.convert(qat_model.eval(), inplace=False)
     quantized_model.eval()
     top1, top5 = evaluate(quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
     print('Epoch %d :Evaluation accuracy on %d images, %2.2f'%(nepoch, num_eval_batches * eval_batch_size, top1.avg))
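
Continuing the ``Tiny`` sketch above, the epoch-gated freezing and the per-epoch conversion can be exercised end to end. The tutorial additionally freezes batch-norm statistics via ``torch.nn.intrinsic.qat.freeze_bn_stats``; the toy model has no BN, so that step is omitted here:

.. code:: python

    import torch
    from torch.ao.quantization import convert, disable_observer

    opt = torch.optim.SGD(qat_model.parameters(), lr=1e-4)
    for nepoch in range(6):
        for _ in range(10):                      # stand-in for one training epoch
            loss = qat_model(torch.randn(8, 4)).sum()
            opt.zero_grad()
            loss.backward()
            opt.step()
        if nepoch > 3:
            qat_model.apply(disable_observer)    # stop updating quantization ranges
        quantized = convert(qat_model.eval(), inplace=False)  # snapshot an int8 model
        qat_model.train()                        # resume QAT for the next epoch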

prototype_source/fx_graph_mode_ptq_static.rst

+35 -38
@@ -13,9 +13,8 @@ tldr; The FX Graph Mode API looks like the following:
 .. code:: python

     import torch
-    from torch.quantization import get_default_qconfig
-    # Note that this is temporary, we'll expose these functions to torch.quantization after official releasee
-    from torch.quantization.quantize_fx import prepare_fx, convert_fx
+    from torch.ao.quantization import get_default_qconfig
+    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
     float_model.eval()
     qconfig = get_default_qconfig("fbgemm")
     qconfig_dict = {"": qconfig}
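
The tldr block above is runnable once ``float_model`` exists. A self-contained sketch using the same ``qconfig_dict`` API as this tutorial (note: newer PyTorch releases also require an ``example_inputs`` argument to ``prepare_fx``):

.. code:: python

    import torch
    import torch.nn as nn
    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

    float_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
    qconfig_dict = {"": get_default_qconfig("fbgemm")}   # "" = global qconfig
    prepared = prepare_fx(float_model, qconfig_dict)     # traces and inserts observers
    prepared(torch.randn(1, 3, 32, 32))                  # calibration pass
    quantized = convert_fx(prepared)                     # lower to quantized ops
    print(quantized.code)                                # GraphModule with quantized calls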
@@ -58,24 +57,28 @@ These steps are identitcal to `Static Quantization with Eager Mode in PyTorch <h

 To run the code in this tutorial using the entire ImageNet dataset, first download imagenet by following the instructions at here `ImageNet Data <http://www.image-net.org/download>`_. Unzip the downloaded file into the 'data_path' folder.

-Download the `torchvision resnet18 model <https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L12>`_ and rename it to
+Download the `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_ and rename it to
 ``data/resnet18_pretrained_float.pth``.

 .. code:: python

-    import numpy as np
-    import torch
-    import torch.nn as nn
-    import torchvision
-    from torch.utils.data import DataLoader
-    from torchvision import datasets
-    import torchvision.transforms as transforms
-    import os
-    import time
-    import sys
-    import torch.quantization
-
-    # Setup warnings
+    import os
+    import sys
+    import time
+    import numpy as np
+
+    import torch
+    from torch.ao.quantization import get_default_qconfig
+    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, fuse_fx
+    import torch.nn as nn
+    from torch.utils.data import DataLoader
+
+    import torchvision
+    from torchvision import datasets
+    from torchvision.models.resnet import resnet18
+    import torchvision.transforms as transforms
+
+    # Set up warnings
     import warnings
     warnings.filterwarnings(
         action='ignore',
@@ -84,16 +87,13 @@ Download the `torchvision resnet18 model <https://github.com/pytorch/vision/blob
     )
     warnings.filterwarnings(
         action='default',
-        module=r'torch.quantization'
+        module=r'torch.ao.quantization'
     )

     # Specify random seed for repeatable results
     _ = torch.manual_seed(191009)


-    from torchvision.models.resnet import resnet18
-    from torch.quantization import get_default_qconfig, quantize_jit
-
     class AverageMeter(object):
         """Computes and stores the average and current value"""
         def __init__(self, name, fmt=':f'):
@@ -168,25 +168,22 @@ Download the `torchvision resnet18 model <https://github.com/pytorch/vision/blob
     os.remove("temp.p")

     def prepare_data_loaders(data_path):
-
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
         dataset = torchvision.datasets.ImageNet(
-            data_path, split="train",
-            transforms.Compose([
-                transforms.RandomResizedCrop(224),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]))
+            data_path, split="train", transform=transforms.Compose([
+                transforms.RandomResizedCrop(224),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                normalize,
+            ]))
         dataset_test = torchvision.datasets.ImageNet(
-            data_path, split="val",
-            transforms.Compose([
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ]))
+            data_path, split="val", transform=transforms.Compose([
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                normalize,
+            ]))

         train_sampler = torch.utils.data.RandomSampler(dataset)
         test_sampler = torch.utils.data.SequentialSampler(dataset_test)
@@ -239,7 +236,7 @@ of the observers for activation and weight. ``qconfig_dict`` is a dictionary wit
 .. code:: python

     qconfig = {
-        " : qconfig_global,
+        "" : qconfig_global,
         "sub" : qconfig_sub,
         "sub.fc" : qconfig_fc,
         "sub.conv": None
@@ -282,7 +279,7 @@ of the observers for activation and weight. ``qconfig_dict`` is a dictionary wit
         ]
     }

-Utility functions related to ``qconfig`` can be found in the `qconfig <https://github.com/pytorch/pytorch/blob/master/torch/quantization/qconfig.py>`_ file.
+Utility functions related to ``qconfig`` can be found in the `qconfig <https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/qconfig.py>`_ file.

 .. code:: python