
Commit fa8e66a

Authored by yiliu30 and chensuyue

Add default config set for tuning (#1562)

Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: chen, suyue <[email protected]>
1 parent 8ea2fd3 commit fa8e66a

File tree

10 files changed: 210 additions & 24 deletions


neural_compressor/common/base_config.py

Lines changed: 33 additions & 5 deletions
```diff
@@ -21,9 +21,8 @@
 import re
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from copy import deepcopy
 from itertools import product
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 from neural_compressor.common import Logger
 from neural_compressor.common.utils import (
@@ -44,13 +43,12 @@
     "register_config",
     "BaseConfig",
     "ComposableConfig",
-    "Options",
+    "get_all_config_set_from_config_registry",
     "options",
 ]
 
-# Dictionary to store registered configurations
-
 
+# Config registry to store all registered configs.
 class ConfigRegistry:
     registered_configs = {}
 
@@ -104,6 +102,13 @@ def get_cls_configs(cls) -> Dict[str, Dict[str, object]]:
                 cls_configs[framework_name][algo_name] = config_data["cls"]
         return cls_configs
 
+    @classmethod
+    def get_all_config_cls_by_fwk_name(cls, fwk_name: str) -> List[Type[BaseConfig]]:
+        configs_cls = []
+        for algo_name, config_pairs in cls.registered_configs.get(fwk_name, {}).items():
+            configs_cls.append(config_pairs["cls"])
+        return configs_cls
+
 
 config_registry = ConfigRegistry()
 
@@ -373,6 +378,11 @@ def _is_op_type(name: str) -> bool:
         # TODO (Yi), ort and tf need override it
         return not isinstance(name, str)
 
+    @classmethod
+    @abstractmethod
+    def get_config_set_for_tuning(cls):
+        raise NotImplementedError
+
 
 class ComposableConfig(BaseConfig):
     name = COMPOSABLE_CONFIG
@@ -420,6 +430,24 @@ def register_supported_configs(cls):
         """Add all supported configs."""
         raise NotImplementedError
 
+    @classmethod
+    def get_config_set_for_tuning(cls) -> None:
+        # TODO (Yi) handle the composable config in `tuning_config`
+        return None
+
+
+def get_all_config_set_from_config_registry(fwk_name: str) -> Union[BaseConfig, List[BaseConfig]]:
+    all_registered_config_cls: List[BaseConfig] = config_registry.get_all_config_cls_by_fwk_name(fwk_name)
+    config_set = []
+    for config_cls in all_registered_config_cls:
+        config_set.append(config_cls.get_config_set_for_tuning())
+    return config_set
+
+
+#######################################################
+#### Options
+#######################################################
+
 
 def _check_value(name, src, supported_type, supported_value=[]):
     """Check if the given object is the given supported type and in the given supported value.
```

neural_compressor/common/base_tuning.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -129,8 +129,8 @@ class Sampler:
 
 
 class ConfigLoader:
-    def __init__(self, quant_configs, sampler: Sampler) -> None:
-        self.quant_configs = quant_configs
+    def __init__(self, config_set, sampler: Sampler) -> None:
+        self.config_set = config_set
         self.sampler = sampler
 
     @staticmethod
@@ -146,7 +146,7 @@ def parse_quant_config(quant_config: BaseConfig) -> List[BaseConfig]:
     def parse_quant_configs(self) -> List[BaseConfig]:
         # TODO (Yi) separate this functionality into `Sampler` in the next PR
         quant_config_list = []
-        for quant_config in self.quant_configs:
+        for quant_config in self.config_set:
             quant_config_list.extend(ConfigLoader.parse_quant_config(quant_config))
         return quant_config_list
 
@@ -210,14 +210,14 @@ class TuningConfig:
     """Base Class for Tuning Criterion.
 
     Args:
-        quant_configs: quantization configs. Default value is empty.
+        config_set: quantization configs. Default value is empty.
         timeout: Tuning timeout (seconds). Default value is 0 which means early stop.
         max_trials: Max tuning times. Default value is 100. Combine with timeout field to decide when to exit.
     """
 
-    def __init__(self, quant_configs=None, timeout=0, max_trials=100, sampler: Sampler = None) -> None:
+    def __init__(self, config_set=None, timeout=0, max_trials=100, sampler: Sampler = None) -> None:
         """Init a TuneCriterion object."""
-        self.quant_configs = quant_configs
+        self.config_set = config_set
         self.timeout = timeout
         self.max_trials = max_trials
         self.sampler = sampler
@@ -265,7 +265,7 @@ def need_stop(self) -> bool:
 
 
 def init_tuning(tuning_config: TuningConfig) -> Tuple[ConfigLoader, TuningLogger, TuningMonitor]:
-    config_loader = ConfigLoader(quant_configs=tuning_config.quant_configs, sampler=tuning_config.sampler)
+    config_loader = ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler)
     tuning_logger = TuningLogger()
     tuning_monitor = TuningMonitor(tuning_config)
     return config_loader, tuning_logger, tuning_monitor
```
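After the rename, callers pass `config_set` instead of `quant_configs`; the field may hold a single config or a list. A short sketch, using `RTNConfig` as a stand-in for any registered config class:

```python
from neural_compressor.common.base_tuning import TuningConfig, init_tuning
from neural_compressor.torch.quantization.config import RTNConfig

# A list-valued field such as weight_bits=[4, 8] is unfolded into one concrete
# config per value by ConfigLoader.parse_quant_configs during tuning.
tuning_config = TuningConfig(config_set=[RTNConfig(weight_bits=[4, 8])], max_trials=10)
config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config)
```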

neural_compressor/onnxrt/quantization/config.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -144,6 +144,11 @@ def get_model_info(model: Union[onnx.ModelProto, Path, str]) -> List[Tuple[str,
         logger.debug(f"Get model info: {filter_result}")
         return filter_result
 
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "RTNConfig", List["RTNConfig"]]:  # pragma: no cover
+        # TODO fwk owner needs to update it.
+        return RTNConfig(weight_bits=[4, 6])
+
 
 # TODO(Yi) run `register_supported_configs` for all registered config.
 RTNConfig.register_supported_configs()
```
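The body is a deliberate placeholder ("fwk owner needs to update it"), but it is already callable. A quick check of the default it returns:

```python
from neural_compressor.onnxrt.quantization.config import RTNConfig

# The placeholder default from this commit: one list-valued config.
default_set = RTNConfig.get_config_set_for_tuning()
print(default_set.weight_bits)  # [4, 6]
```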

neural_compressor/tensorflow/quantization/config.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -103,6 +103,13 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
         supported_configs.append(OperatorConfig(config=static_quant_config, operators=operators))
         cls.supported_configs = supported_configs
 
+    @classmethod
+    def get_config_set_for_tuning(
+        cls,
+    ) -> Union[None, "StaticQuantConfig", List["StaticQuantConfig"]]:  # pragma: no cover
+        # TODO fwk owner needs to update it.
+        return StaticQuantConfig(weight_sym=[True, False])
+
 
 # TODO(Yi) run `register_supported_configs` for all registered config.
 StaticQuantConfig.register_supported_configs()
```
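List-valued fields such as `weight_sym=[True, False]` are not tuned as-is: the tuner unfolds them into concrete single-valued configs. A sketch of that expansion, assuming `BaseConfig.expand()` behaves as exercised by the `test_expand_config` cases in the test files below:

```python
from neural_compressor.tensorflow.quantization.config import StaticQuantConfig

# weight_sym=[True, False] expands into one concrete config per value,
# which is what the tuner actually iterates over.
cfg = StaticQuantConfig(weight_sym=[True, False])
for concrete in cfg.expand():
    print(concrete.weight_sym)  # True, then False
```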

neural_compressor/torch/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -28,4 +28,4 @@
 )
 
 from neural_compressor.common.base_tuning import TuningConfig
-from neural_compressor.torch.quantization.autotune import autotune, get_default_tune_config
+from neural_compressor.torch.quantization.autotune import autotune, get_all_config_set
```
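The package-level import now mirrors the new API, so both tuning entry points come from one place:

```python
from neural_compressor.torch import TuningConfig, autotune, get_all_config_set
```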

neural_compressor/torch/quantization/autotune.py

Lines changed: 11 additions & 8 deletions
```diff
@@ -12,28 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import deepcopy
 from typing import Dict, List, Optional, Union
 
 import torch
 
 from neural_compressor.common import Logger
-from neural_compressor.common.base_config import BaseConfig
+from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
 from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning
 from neural_compressor.torch import quantize
-from neural_compressor.torch.quantization.config import GPTQConfig, RTNConfig
+from neural_compressor.torch.quantization.config import FRAMEWORK_NAME
 
 logger = Logger().get_logger()
 
 
 __all__ = [
-    "get_default_tune_config",
     "autotune",
+    "get_all_config_set",
 ]
 
 
-def get_default_tune_config() -> TuningConfig:
-    # TODO use the registered default tuning config in the next PR
-    return TuningConfig(quant_configs=[GPTQConfig(weight_bits=[4, 8]), RTNConfig(weight_bits=[4, 8])])
+def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
+    return get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
 
 
 def autotune(
@@ -52,15 +52,18 @@ def autotune(
     for trial_index, quant_config in enumerate(config_loader):
         tuning_logger.trial_start(trial_index=trial_index)
         tuning_logger.quantization_start()
-        q_model = quantize(model, quant_config=quant_config, run_fn=run_fn, run_args=run_args)
+        logger.info(f"quant config: {quant_config}")
+        # !!! Make sure to use deepcopy only when inplace is set to `True`.
+        q_model = quantize(deepcopy(model), quant_config=quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
         tuning_logger.quantization_end()
         tuning_logger.evaluation_start()
         eval_result: float = evaluator.evaluate(q_model)
         tuning_logger.evaluation_end()
         tuning_monitor.add_trial_result(trial_index, eval_result, quant_config)
         if tuning_monitor.need_stop():
             best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
-            quantize(model, quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
+            # !!! Make sure to use deepcopy only when inplace is set to `True`.
+            quantize(deepcopy(model), quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
             best_quant_model = model  # quantize model inplace
         tuning_logger.trial_end(trial_index)
     tuning_logger.tuning_end()
```
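With `get_all_config_set` replacing the hand-written default, end-to-end tuning looks roughly like this. A sketch only: the `tune_config` and `eval_fns` parameter names of `autotune` are assumptions, since this diff does not show its signature:

```python
import torch

from neural_compressor.torch import TuningConfig, autotune, get_all_config_set

model = torch.nn.Sequential(torch.nn.Linear(32, 32))  # toy stand-in model

def eval_fn(q_model) -> float:
    # Illustrative metric only; a real eval_fn returns task accuracy.
    return 1.0

# Tune over the default config set of every registered torch algorithm.
tuning_config = TuningConfig(config_set=get_all_config_set(), max_trials=5)
best_model = autotune(model=model, tune_config=tuning_config, eval_fns=eval_fn)
```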

neural_compressor/torch/quantization/config.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -36,6 +36,14 @@
 from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN
 from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger
 
+__all__ = [
+    "RTNConfig",
+    "get_default_rtn_config",
+    "GPTQConfig",
+    "get_default_gptq_config",
+]
+
+
 FRAMEWORK_NAME = "torch"
 DTYPE_RANGE = Union[torch.dtype, List[torch.dtype]]
 
@@ -153,6 +161,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
         logger.debug(f"Get model info: {filter_result}")
         return filter_result
 
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "RTNConfig", List["RTNConfig"]]:
+        # TODO fwk owner needs to update it.
+        return RTNConfig(weight_bits=[4, 6])
+
 
 # TODO(Yi) run `register_supported_configs` for all registered config.
 RTNConfig.register_supported_configs()
@@ -276,6 +289,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
         logger.debug(f"Get model info: {filter_result}")
         return filter_result
 
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "GPTQConfig", List["GPTQConfig"]]:
+        # TODO fwk owner needs to update it.
+        return GPTQConfig(weight_bits=[4, 6])
+
 
 # TODO(Yi) run `register_supported_configs` for all registered config.
 GPTQConfig.register_supported_configs()
@@ -352,6 +370,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
         logger.debug(f"Get model info: {filter_result}")
         return filter_result
 
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "StaticQuantConfig", List["StaticQuantConfig"]]:
+        # TODO fwk owner needs to update it.
+        return StaticQuantConfig(w_sym=[True, False])
+
 
 # TODO(Yi) run `register_supported_configs` for all registered config.
 StaticQuantConfig.register_supported_configs()
@@ -461,6 +484,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
         logger.debug(f"Get model info: {filter_result}")
         return filter_result
 
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "SmoothQuantConfig", List["SmoothQuantConfig"]]:
+        # TODO fwk owner needs to update it.
+        return SmoothQuantConfig(alpha=[0.1, 0.5])
+
 
 # TODO(Yi) run `register_supported_configs` for all registered config.
 SmoothQuantConfig.register_supported_configs()
@@ -541,6 +569,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
         logger.debug(f"Get model info: {filter_result}")
         return filter_result
 
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "FP8QConfig", List["FP8QConfig"]]:
+        # TODO fwk owner needs to update it.
+        return FP8QConfig(act_dtype=[torch.float8_e4m3fn])
+
 # TODO(Yi) run `register_supported_configs` for all registered config.
 FP8QConfig.register_supported_configs()
 
```
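Every placeholder carries the same "fwk owner needs to update it" TODO. A hypothetical refinement, shown as a subclass so nothing in the registered class is touched; the candidate values below are invented for illustration:

```python
from typing import List

from neural_compressor.torch.quantization.config import RTNConfig

class TunedRTNConfig(RTNConfig):
    """Hypothetical override of the placeholder default (illustration only)."""

    @classmethod
    def get_config_set_for_tuning(cls) -> List[RTNConfig]:
        # Return several concrete candidates rather than one list-valued config.
        return [RTNConfig(weight_bits=4), RTNConfig(weight_bits=8)]
```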

test/3x/onnxrt/test_config.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -328,6 +328,14 @@ def test_expand_config(self):
         self.assertEqual(expand_config_list[0].weight_bits, 4)
         self.assertEqual(expand_config_list[1].weight_bits, 8)
 
+    def test_config_set_api(self):
+        # *Note: this test is only for improving the code coverage and can be removed once the test_common is enabled.
+        from neural_compressor.common.base_config import config_registry, get_all_config_set_from_config_registry
+        from neural_compressor.onnxrt.quantization.config import FRAMEWORK_NAME
+
+        config_set = get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
+        self.assertEqual(len(config_set), len(config_registry.registered_configs[FRAMEWORK_NAME]))
+
 
 if __name__ == "__main__":
     unittest.main()
```

test/3x/tensorflow/test_config.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -315,6 +315,14 @@ def test_expand_config(self):
         self.assertEqual(expand_config_list[0].weight_granularity, "per_channel")
         self.assertEqual(expand_config_list[1].weight_granularity, "per_tensor")
 
+    def test_config_set_api(self):
+        # *Note: this test is only for improving the code coverage and can be removed once the test_common is enabled.
+        from neural_compressor.common.base_config import config_registry, get_all_config_set_from_config_registry
+        from neural_compressor.tensorflow.quantization.config import FRAMEWORK_NAME
+
+        config_set = get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
+        self.assertEqual(len(config_set), len(config_registry.registered_configs[FRAMEWORK_NAME]))
+
 
 if __name__ == "__main__":
     unittest.main()
```
