
Commit ed57f80

jerryzh168 authored and facebook-github-bot committed
[quant][refactor] Move some util functions from torch/quantization/fx/utils.py to torch/quantization/utils.py (pytorch#48107)
Summary: Pull Request resolved: pytorch#48107
Test Plan: Imported from OSS
Reviewed By: supriyar
Differential Revision: D25026495
fbshipit-source-id: 3634b6b95a18670232600874b1e593180ea9f44c
1 parent 4316bf9 commit ed57f80

File tree: 4 files changed (+93, -87 lines)

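In practical terms, this refactor moves the qconfig/qscheme helpers that are useful to both eager mode and FX graph mode out of the FX-only torch/quantization/fx/utils.py and into the shared torch/quantization/utils.py, renaming weight_is_quantized to weight_is_statically_quantized along the way. A minimal sketch of what downstream code looks like once this commit is applied (module paths as in the diff; get_default_qconfig is the stock qconfig helper and is used here only for illustration):

# Previously these names lived in torch.quantization.fx.utils;
# after this change they are importable from the shared utils module.
from torch.quantization import get_default_qconfig
from torch.quantization.utils import (
    activation_is_statically_quantized,
    weight_is_statically_quantized,  # formerly weight_is_quantized in fx/utils.py
    get_qconfig_dtypes,
)

qconfig = get_default_qconfig("fbgemm")
print(activation_is_statically_quantized(qconfig))  # True: activation observer dtype is torch.quint8
print(weight_is_statically_quantized(qconfig))      # True: weight observer dtype is torch.qint8
print(get_qconfig_dtypes(qconfig))                  # (activation_dtype, weight_dtype, activation_compute_dtype)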

torch/quantization/fx/quantization_patterns.py

Lines changed: 11 additions & 7 deletions
@@ -13,20 +13,24 @@
     get_static_quant_module_class,
     get_quantized_operator,
 )
+from ..utils import (
+    get_swapped_custom_module_class,
+    activation_is_statically_quantized,
+    weight_is_statically_quantized,
+    weight_dtype,
+    get_qconfig_dtypes,
+)
+
 from .pattern_utils import (
     register_quant_pattern,
     mark_input_output_not_observed,
 )
+
 from .utils import (
     _parent_name,
     quantize_node,
     get_per_tensor_qparams,
-    get_swapped_custom_module_class,
-    activation_is_statically_quantized,
-    weight_is_quantized,
-    weight_dtype,
     get_linear_prepack_op_for_dtype,
-    get_qconfig_dtypes,
 )

 from abc import ABC, abstractmethod

@@ -339,7 +343,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
             quantized_input_idxs = []
             if activation_statically_quantized:
                 quantized_input_idxs.append(0)
-            if weight_is_quantized(qconfig):
+            if weight_is_statically_quantized(qconfig):
                 quantized_input_idxs.append(1)
             args = load_arg(quantized=quantized_input_idxs)(self.linear_node.args)
             args = load_arg(quantized=False)(self.linear_node.args)

@@ -360,7 +364,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
         else: # non-debug option
             # linear args
             # (x, weight, bias, ...)
-            weight_quantized = weight_is_quantized(qconfig)
+            weight_quantized = weight_is_statically_quantized(qconfig)
             linear_weight = load_arg(quantized=weight_quantized)(self.linear_node.args[1])

             # get other arguments
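The only name change in this file is weight_is_quantized becoming weight_is_statically_quantized; the check itself, weight observer dtype in (torch.quint8, torch.qint8), is untouched, so the linear pattern still loads input index 1 in quantized form exactly when the weight observer produces an int8 tensor. A small sketch of that predicate against the stock static and dynamic qconfigs (assuming the shared-utils import path introduced by this commit):

import torch
from torch.quantization import default_qconfig, default_dynamic_qconfig
from torch.quantization.utils import weight_dtype, weight_is_statically_quantized

# Static post-training qconfig: qint8 weight observer, so the weight input
# of a quantized linear is loaded in quantized form.
assert weight_dtype(default_qconfig) == torch.qint8
assert weight_is_statically_quantized(default_qconfig)

# The dynamic qconfig also carries a qint8 weight observer, so it reports
# True as well; the predicate is about the weight dtype, not the activation.
assert weight_is_statically_quantized(default_dynamic_qconfig)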

torch/quantization/fx/quantize.py

Lines changed: 3 additions & 3 deletions
@@ -25,7 +25,9 @@
 )

 from ..utils import (
-    get_combined_dict
+    get_combined_dict,
+    get_swapped_custom_module_class,
+    activation_is_statically_quantized,
 )

 from .pattern_utils import (

@@ -48,8 +50,6 @@
     _parent_name,
     quantize_node,
     get_custom_module_class_keys,
-    get_swapped_custom_module_class,
-    activation_is_statically_quantized,
 )

 from collections import OrderedDict
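quantize.py now pulls get_swapped_custom_module_class and activation_is_statically_quantized from the same ..utils module that already provided get_combined_dict, the small helper it uses to overlay user-supplied mappings on the defaults. A quick sketch of that helper with made-up dictionaries (the keys and values below are purely illustrative):

from torch.quantization.utils import get_combined_dict

# Hypothetical default and user-provided mappings, just to show the merge rule.
defaults = {"linear": "quantized_linear", "conv2d": "quantized_conv2d"}
overrides = {"conv2d": "my_custom_quantized_conv2d", "gru": "my_custom_quantized_gru"}

combined = get_combined_dict(defaults, overrides)
# Copy of the defaults, with user entries overriding on conflict and new keys added:
# {'linear': 'quantized_linear', 'conv2d': 'my_custom_quantized_conv2d', 'gru': 'my_custom_quantized_gru'}
print(combined)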

torch/quantization/fx/utils.py

Lines changed: 1 addition & 77 deletions
@@ -1,6 +1,6 @@
 import re
 import torch
-from ..quant_type import QuantType, quant_type_to_str
+from ..utils import is_per_tensor, is_per_channel

 # turn foo.bar -> ['foo', 'bar']
 def _parent_name(target):

@@ -76,15 +76,6 @@ def graph_pretty_str(g, shorten=True) -> str:
         res_str += "*obs_{n} = activation_post_process_{n}\n"
     return res_str

-def is_per_tensor(qscheme):
-    return qscheme == torch.per_tensor_affine or \
-        qscheme == torch.per_tensor_symmetric
-
-def is_per_channel(qscheme):
-    return qscheme in [torch.per_channel_affine,
-                       torch.per_channel_affine_float_qparams,
-                       torch.per_channel_symmetric]
-
 def get_per_tensor_qparams(activation_post_process):
     assert is_per_tensor(activation_post_process.qscheme), 'Only per tensor quantization is supported'
     scale, zero_point = activation_post_process.calculate_qparams()

@@ -171,73 +162,6 @@ def get_custom_module_class_keys(custom_config_dict, custom_config_dict_key):
         float_custom_module_classes |= quant_mode_custom_module_classes
     return list(float_custom_module_classes)

-def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
-    """ Get the observed/quantized custom module class that we need
-    to swap `custom_module` to
-    Input:
-        custom_module: input, can be an instance of either a float or observed custom module
-        custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping
-        qconfig: qconfig configured for the custom module
-
-    Output:
-        corresponding observed/quantized custom module class for input custom module instance
-    """
-    quant_type = get_quant_type(qconfig)
-    quant_type_str = quant_type_to_str(quant_type)
-    class_mapping = custom_module_class_mapping.get(quant_type_str, {})
-    assert type(custom_module) in class_mapping, "did not found corresponding observed " \
-        "module class for {} in mapping: {}".format(type(custom_module), class_mapping)
-    return class_mapping[type(custom_module)]
-
-def activation_is_statically_quantized(qconfig):
-    """ Given a qconfig, decide if the activation needs to be
-    statically quantized or not
-    """
-    assert qconfig is not None
-    activation = qconfig.activation()
-    return activation.dtype in [torch.quint8, torch.qint8]
-
-def weight_dtype(qconfig):
-    assert qconfig is not None
-    weight = qconfig.weight()
-    return weight.dtype
-
-def weight_is_quantized(qconfig):
-    """ Given a qconfig, decide if the activation needs to be
-    quantized or not
-    """
-    return weight_dtype(qconfig) in [torch.quint8, torch.qint8]
-
-def get_qconfig_dtypes(qconfig):
-    r""" returns the qconfig tuple for qconfig:
-    (activation_dtype, weight_dtype, activation_compute_dtype)
-    """
-    assert qconfig is not None
-    activation = qconfig.activation()
-    weight = qconfig.weight()
-    compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None
-    return (activation.dtype, weight.dtype, compute_dtype)
-
-def get_quant_type(qconfig):
-    assert qconfig is not None
-    activation = qconfig.activation()
-    weight = qconfig.weight()
-    static_dtypes = [torch.quint8, torch.qint8]
-    if weight.dtype in static_dtypes:
-        if activation.dtype in static_dtypes:
-            return QuantType.STATIC
-        elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes:
-            return QuantType.DYNAMIC
-        else:
-            return QuantType.WEIGHT_ONLY
-
-    if weight.dtype == torch.float16:
-        if activation.dtype == torch.float:
-            return QuantType.DYNAMIC
-
-    raise Exception("Unrecognized dtype combination in get_quant_type: activation({}),"
-                    "weight({})".format(activation.dtype, weight.dtype))
-
 def get_linear_prepack_op_for_dtype(dtype):
     if dtype == torch.float16:
         return torch.ops.quantized.linear_prepack_fp16
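After this change fx/utils.py keeps only the FX-specific helpers and re-imports the qscheme predicates from ..utils, so is_per_tensor and is_per_channel behave exactly as before for callers such as get_per_tensor_qparams. A quick sanity sketch of the two predicates (shared-module import path assumed per this commit):

import torch
from torch.quantization.utils import is_per_channel, is_per_tensor

assert is_per_tensor(torch.per_tensor_affine)
assert is_per_tensor(torch.per_tensor_symmetric)
assert not is_per_tensor(torch.per_channel_affine)

assert is_per_channel(torch.per_channel_symmetric)
assert is_per_channel(torch.per_channel_affine_float_qparams)
assert not is_per_channel(torch.per_tensor_affine)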

torch/quantization/utils.py

Lines changed: 78 additions & 0 deletions
@@ -1,9 +1,87 @@
 """
 Utils shared by different modes of quantization (eager/graph)
 """
+import torch
+from .quant_type import QuantType, quant_type_to_str

 def get_combined_dict(default_dict, additional_dict):
     d = default_dict.copy()
     for k, v in additional_dict.items():
         d[k] = v
     return d
+
+def is_per_tensor(qscheme):
+    return qscheme == torch.per_tensor_affine or \
+        qscheme == torch.per_tensor_symmetric
+
+def is_per_channel(qscheme):
+    return qscheme in [torch.per_channel_affine,
+                       torch.per_channel_affine_float_qparams,
+                       torch.per_channel_symmetric]
+
+def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
+    """ Get the observed/quantized custom module class that we need
+    to swap `custom_module` to
+    Input:
+        custom_module: input, can be an instance of either a float or observed custom module
+        custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping
+        qconfig: qconfig configured for the custom module
+
+    Output:
+        corresponding observed/quantized custom module class for input custom module instance
+    """
+    quant_type = get_quant_type(qconfig)
+    quant_type_str = quant_type_to_str(quant_type)
+    class_mapping = custom_module_class_mapping.get(quant_type_str, {})
+    assert type(custom_module) in class_mapping, "did not found corresponding observed " \
+        "module class for {} in mapping: {}".format(type(custom_module), class_mapping)
+    return class_mapping[type(custom_module)]
+
+def activation_is_statically_quantized(qconfig):
+    """ Given a qconfig, decide if the activation needs to be
+    statically quantized or not
+    """
+    assert qconfig is not None
+    activation = qconfig.activation()
+    return activation.dtype in [torch.quint8, torch.qint8]
+
+def weight_dtype(qconfig):
+    assert qconfig is not None
+    weight = qconfig.weight()
+    return weight.dtype
+
+def weight_is_statically_quantized(qconfig):
+    """ Given a qconfig, decide if the weight needs to be
+    quantized or not
+    """
+    return weight_dtype(qconfig) in [torch.quint8, torch.qint8]
+
+def get_qconfig_dtypes(qconfig):
+    r""" returns the qconfig tuple for qconfig:
+    (activation_dtype, weight_dtype, activation_compute_dtype)
+    """
+    assert qconfig is not None
+    activation = qconfig.activation()
+    weight = qconfig.weight()
+    compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None
+    return (activation.dtype, weight.dtype, compute_dtype)
+
+def get_quant_type(qconfig):
+    assert qconfig is not None
+    activation = qconfig.activation()
+    weight = qconfig.weight()
+    static_dtypes = [torch.quint8, torch.qint8]
+    if weight.dtype in static_dtypes:
+        if activation.dtype in static_dtypes:
+            return QuantType.STATIC
+        elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes:
+            return QuantType.DYNAMIC
+        else:
+            return QuantType.WEIGHT_ONLY
+
+    if weight.dtype == torch.float16:
+        if activation.dtype == torch.float:
+            return QuantType.DYNAMIC
+
+    raise Exception("Unrecognized dtype combination in get_quant_type: activation({}),"
+                    "weight({})".format(activation.dtype, weight.dtype))

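With get_quant_type now in the shared module, eager and FX mode classify a qconfig the same way: an int8 weight plus an int8 activation is STATIC, an int8 weight plus a float activation carrying an int8 compute_dtype is DYNAMIC, and an int8 weight with a plain float activation is WEIGHT_ONLY (an fp16 weight with a float activation also maps to DYNAMIC). A sketch of those three branches using hand-built qconfigs; MinMaxObserver and PlaceholderObserver are the stock observers, and passing compute_dtype to PlaceholderObserver is assumed to be supported at this commit, as the hasattr check in the diff suggests:

import torch
from torch.quantization import QConfig
from torch.quantization.observer import MinMaxObserver, PlaceholderObserver
from torch.quantization.quant_type import QuantType
from torch.quantization.utils import get_quant_type

int8_weight = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)

# int8 activation + int8 weight -> STATIC
static_qconfig = QConfig(activation=MinMaxObserver.with_args(dtype=torch.quint8), weight=int8_weight)
assert get_quant_type(static_qconfig) == QuantType.STATIC

# float activation with quint8 compute_dtype + int8 weight -> DYNAMIC
# (compute_dtype kwarg on PlaceholderObserver is assumed available here)
dynamic_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
    weight=int8_weight)
assert get_quant_type(dynamic_qconfig) == QuantType.DYNAMIC

# plain float activation + int8 weight -> WEIGHT_ONLY
weight_only_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float), weight=int8_weight)
assert get_quant_type(weight_only_qconfig) == QuantType.WEIGHT_ONLY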