|
36 | 36 | from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN |
37 | 37 | from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger |
38 | 38 |
|
| 39 | +__all__ = [ |
| 40 | + "RTNConfig", |
| 41 | + "get_default_rtn_config", |
| 42 | + "GPTQConfig", |
| 43 | + "get_default_gptq_config", |
| 44 | +] |
| 45 | + |
| 46 | + |
39 | 47 | FRAMEWORK_NAME = "torch" |
40 | 48 | DTYPE_RANGE = Union[torch.dtype, List[torch.dtype]] |
41 | 49 |
|
@@ -153,6 +161,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: |
153 | 161 | logger.debug(f"Get model info: {filter_result}") |
154 | 162 | return filter_result |
155 | 163 |
|
| 164 | + @classmethod |
| 165 | + def get_config_set_for_tuning(cls) -> Union[None, "RTNConfig", List["RTNConfig"]]: |
| 166 | + # TODO fwk owner needs to update it. |
| 167 | + return RTNConfig(weight_bits=[4, 6]) |
| 168 | + |
156 | 169 |
|
157 | 170 | # TODO(Yi) run `register_supported_configs` for all registered config. |
158 | 171 | RTNConfig.register_supported_configs() |
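
The RTN hunk above establishes the pattern repeated below for the other algorithms: `get_config_set_for_tuning` returns a config whose list-valued fields describe a default search space. A minimal sketch of how a tuner might expand that space into concrete candidates follows; the import path and the attribute access on the returned config are assumptions for illustration, not something this diff guarantees.

# Minimal sketch only; the import path is assumed, not taken from this diff.
from neural_compressor.torch.quantization import RTNConfig

# Default search space added in this diff: RTNConfig(weight_bits=[4, 6]).
config_set = RTNConfig.get_config_set_for_tuning()

# Hypothetical expansion of the list-valued field into concrete candidates,
# assuming the constructor stores weight_bits as an attribute.
candidates = [RTNConfig(weight_bits=bits) for bits in config_set.weight_bits]
for candidate in candidates:
    print(candidate.weight_bits)  # 4, then 6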
@@ -276,6 +289,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: |
276 | 289 | logger.debug(f"Get model info: {filter_result}") |
277 | 290 | return filter_result |
278 | 291 |
|
| 292 | + @classmethod |
| 293 | + def get_config_set_for_tuning(cls) -> Union[None, "GPTQConfig", List["GPTQConfig"]]: |
| 294 | + # TODO fwk owner needs to update it. |
| 295 | + return GPTQConfig(weight_bits=[4, 6]) |
| 296 | + |
279 | 297 |
|
280 | 298 | # TODO(Yi) run `register_supported_configs` for all registered config. |
281 | 299 | GPTQConfig.register_supported_configs() |
@@ -352,6 +370,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: |
352 | 370 | logger.debug(f"Get model info: {filter_result}") |
353 | 371 | return filter_result |
354 | 372 |
|
| 373 | + @classmethod |
| 374 | + def get_config_set_for_tuning(cls) -> Union[None, "StaticQuantConfig", List["StaticQuantConfig"]]: |
| 375 | + # TODO fwk owner needs to update it. |
| 376 | + return StaticQuantConfig(w_sym=[True, False]) |
| 377 | + |
355 | 378 |
|
356 | 379 | # TODO(Yi) run `register_supported_configs` for all registered config. |
357 | 380 | StaticQuantConfig.register_supported_configs() |
@@ -461,6 +484,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: |
461 | 484 | logger.debug(f"Get model info: {filter_result}") |
462 | 485 | return filter_result |
463 | 486 |
|
| 487 | + @classmethod |
| 488 | + def get_config_set_for_tuning(cls) -> Union[None, "SmoothQuantConfig", List["SmoothQuantConfig"]]: |
| 489 | + # TODO fwk owner needs to update it. |
| 490 | + return SmoothQuantConfig(alpha=[0.1, 0.5]) |
| 491 | + |
464 | 492 |
|
465 | 493 | # TODO(Yi) run `register_supported_configs` for all registered config. |
466 | 494 | SmoothQuantConfig.register_supported_configs() |
@@ -541,6 +569,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: |
541 | 569 | logger.debug(f"Get model info: {filter_result}") |
542 | 570 | return filter_result |
543 | 571 |
|
| 572 | + @classmethod |
| 573 | + def get_config_set_for_tuning(cls) -> Union[None, "FP8QConfig", List["FP8QConfig"]]: |
| 574 | + # TODO fwk owner needs to update it. |
| 575 | + return FP8QConfig(act_dtype=[torch.float8_e4m3fn]) |
| 576 | + |
544 | 577 | # TODO(Yi) run `register_supported_configs` for all registered config. |
545 | 578 | FP8QConfig.register_supported_configs() |
546 | 579 |
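
For completeness, a short sketch of querying all five new `get_config_set_for_tuning` hooks uniformly, together with one of the helpers re-exported through `__all__`. The import path is again an assumption, and whether every class listed here is exported at package level is not confirmed by this diff.

# Minimal sketch only; names are imported from an assumed location.
from neural_compressor.torch.quantization import (
    RTNConfig,
    GPTQConfig,
    StaticQuantConfig,
    SmoothQuantConfig,
    FP8QConfig,
    get_default_rtn_config,
)

# Each config class now exposes a default tuning search space via the same hook.
for cfg_cls in (RTNConfig, GPTQConfig, StaticQuantConfig, SmoothQuantConfig, FP8QConfig):
    print(cfg_cls.__name__, cfg_cls.get_config_set_for_tuning())

# The non-tuning default remains available through the exported helper.
print(get_default_rtn_config())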
|
|