
Commit b4cb576

raghuramank100 authored and fmassa committed
Quantizable resnet and mobilenet models (pytorch#1471)
* Add quantized models
* Modify mobilenet.py documentation and clean up comments
* Move the fuse_model method to QuantizableInvertedResidual and clean up args documentation
* Restore relu settings to their defaults in resnet.py
* Fix missing return in forward
* Fix missing returns in forwards
* Change pretrained -> pretrained_float_models; replace InvertedResidual with block
* Update tests to follow a structure similar to test_models.py, allowing modular testing
* Replace the forward method with a simple function assignment
* Fix error in arguments for resnet18
* Add the missing pretrained_float_model argument for mobilenet
* Add reference scripts for quantization aware training and post training quantization
* Set pretrained_float_model to False and explicitly provide the float model
* Address review comments: replace forward with _forward, use pretrained models in the reference train/eval script, and skip tests if fbgemm is not supported
* Fix lint errors; use _forward for code shared between the float and quantized models; clean up linting in the reference train scripts; test all quantizable models
* Update default values for args in quantization/train.py
* Update models to conform to the new API with a quantize argument; remove apex from the training script; add post training quantization as an option; add support for a separate calibration data set
* Fix minor errors in train_quantization.py
* Remove duplicate file
* Bugfix
* Minor improvements to the models
* Expose print_freq to evaluate
* Minor improvements to train_quantization.py
* Ensure that quantized models are created and run on the specified backends; fix errors in test-only mode
* Add model urls
* Fix errors in quantized model tests; speed up creation of random quantized models by removing histogram observers
* Move setting the qengine prior to convert
* Fix lint error
* Add README.md
* Fix lint
1 parent e79cadd commit b4cb576
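As a rough sketch of the usage this commit describes (assuming the torchvision.models.quantization namespace and the pretrained/quantize flags mentioned in the commit message; not verified against this exact revision), a quantized model can be created and run much like its float counterpart:

```python
import torch
from torchvision.models.quantization import resnet18

# Select the quantized kernel backend; per the README below, the post
# training quantized models are generated with 'fbgemm' (x86).
torch.backends.quantized.engine = 'fbgemm'

# quantize=True returns an int8 model ready for inference;
# quantize=False returns the quantizable float model, which can be
# fused, calibrated and converted manually.
model = resnet18(pretrained=True, quantize=True)
model.eval()

with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```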

11 files changed: +738 additions, -31 deletions


references/classification/README.md

Lines changed: 29 additions & 0 deletions
````diff
@@ -28,3 +28,32 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
     --model mobilenet_v2 --epochs 300 --lr 0.045 --wd 0.00004\
     --lr-step-size 1 --lr-gamma 0.98
 ```
+
+## Quantized
+
+### Parameters used for generating quantized models:
+
+For all post training quantized models (all quantized models except mobilenet-v2), the settings are:
+
+1. num_calibration_batches: 32
+2. num_workers: 16
+3. batch_size: 32
+4. eval_batch_size: 128
+5. backend: 'fbgemm'
+
+For mobilenet-v2, the model was trained with quantization aware training; the settings used are:
+1. num_workers: 16
+2. batch_size: 32
+3. eval_batch_size: 128
+4. backend: 'qnnpack'
+5. learning-rate: 0.0001
+6. num_epochs: 90
+7. num_observer_update_epochs: 4
+8. num_batch_norm_update_epochs: 3
+9. momentum: 0.9
+10. lr_step_size: 30
+11. lr_gamma: 0.1
+
+Training converges at about 10 epochs.
+
+For post training quantization, the device is set to CPU. For quantization aware training, the device is set to CUDA.
````
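The post training settings above are consumed by the reference script added in this commit; a minimal sketch of that flow under the same settings (assuming a quantizable model exposing fuse_model(), the eager-mode torch.quantization workflow of that era, and a calibration `data_loader` defined elsewhere) looks like:

```python
import torch
from torchvision.models.quantization import resnet50

num_calibration_batches = 32  # as listed above

# Start from the quantizable float model and prepare it for calibration.
model = resnet50(pretrained=True, quantize=False)
model.eval()
model.fuse_model()  # fuse Conv+BN(+ReLU) modules before quantization
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)  # insert observers

# Calibrate on a few batches (on CPU, per the note above); `data_loader`
# is a placeholder for the calibration loader.
with torch.no_grad():
    for i, (image, _) in enumerate(data_loader):
        if i >= num_calibration_batches:
            break
        model(image)

torch.quantization.convert(model, inplace=True)  # produce the int8 model
```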

references/classification/train.py

Lines changed: 33 additions & 27 deletions
```diff
@@ -47,12 +47,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri
         metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
 
 
-def evaluate(model, criterion, data_loader, device):
+def evaluate(model, criterion, data_loader, device, print_freq=100):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = 'Test:'
     with torch.no_grad():
-        for image, target in metric_logger.log_every(data_loader, 100, header):
+        for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
             target = target.to(device, non_blocking=True)
             output = model(image)
@@ -81,35 +81,16 @@ def _get_cache_path(filepath):
     return cache_path
 
 
-def main(args):
-    if args.apex:
-        if sys.version_info < (3, 0):
-            raise RuntimeError("Apex currently only supports Python 3. Aborting.")
-        if amp is None:
-            raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-                               "to enable mixed-precision training.")
-
-    if args.output_dir:
-        utils.mkdir(args.output_dir)
-
-    utils.init_distributed_mode(args)
-    print(args)
-
-    device = torch.device(args.device)
-
-    torch.backends.cudnn.benchmark = True
-
+def load_data(traindir, valdir, cache_dataset, distributed):
     # Data loading code
     print("Loading data")
-    traindir = os.path.join(args.data_path, 'train')
-    valdir = os.path.join(args.data_path, 'val')
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
 
     print("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir)
-    if args.cache_dataset and os.path.exists(cache_path):
+    if cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print("Loading dataset_train from {}".format(cache_path))
         dataset, _ = torch.load(cache_path)
@@ -122,15 +103,15 @@ def main(args):
                 transforms.ToTensor(),
                 normalize,
             ]))
-        if args.cache_dataset:
+        if cache_dataset:
             print("Saving dataset_train to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset, traindir), cache_path)
     print("Took", time.time() - st)
 
     print("Loading validation data")
     cache_path = _get_cache_path(valdir)
-    if args.cache_dataset and os.path.exists(cache_path):
+    if cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print("Loading dataset_test from {}".format(cache_path))
         dataset_test, _ = torch.load(cache_path)
@@ -143,19 +124,44 @@ def main(args):
                 transforms.ToTensor(),
                 normalize,
             ]))
-        if args.cache_dataset:
+        if cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset_test, valdir), cache_path)
 
     print("Creating data loaders")
-    if args.distributed:
+    if distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
     else:
         train_sampler = torch.utils.data.RandomSampler(dataset)
         test_sampler = torch.utils.data.SequentialSampler(dataset_test)
 
+    return dataset, dataset_test, train_sampler, test_sampler
+
+
+def main(args):
+    if args.apex:
+        if sys.version_info < (3, 0):
+            raise RuntimeError("Apex currently only supports Python 3. Aborting.")
+        if amp is None:
+            raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
+                               "to enable mixed-precision training.")
+
+    if args.output_dir:
+        utils.mkdir(args.output_dir)
+
+    utils.init_distributed_mode(args)
+    print(args)
+
+    device = torch.device(args.device)
+
+    torch.backends.cudnn.benchmark = True
+
+    train_dir = os.path.join(args.data_path, 'train')
+    val_dir = os.path.join(args.data_path, 'val')
+    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
+                                                                   args.cache_dataset, args.distributed)
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
         sampler=train_sampler, num_workers=args.workers, pin_memory=True)
```
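For reference, a sketch of how the refactored helpers from this diff fit together; load_data() and evaluate() are the functions shown above, while the paths, batch sizes, and model choice below are illustrative stand-ins for the corresponding args fields:

```python
import torch
import torchvision

# Build datasets and samplers with the new standalone helper.
dataset, dataset_test, train_sampler, test_sampler = load_data(
    '/datasets/imagenet/train', '/datasets/imagenet/val',
    cache_dataset=False, distributed=False)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=128,
    sampler=test_sampler, num_workers=16, pin_memory=True)

model = torchvision.models.mobilenet_v2(pretrained=True)
criterion = torch.nn.CrossEntropyLoss()

# evaluate() now accepts print_freq instead of hard-coding it to 100.
evaluate(model, criterion, data_loader_test,
         device=torch.device('cpu'), print_freq=50)
```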
