Commit

Add default config set for tuning (#1562)
Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: chen, suyue <[email protected]>
yiliu30 and chensuyue authored Jan 25, 2024
1 parent 8ea2fd3 commit fa8e66a
Showing 10 changed files with 210 additions and 24 deletions.
38 changes: 33 additions & 5 deletions neural_compressor/common/base_config.py
@@ -21,9 +21,8 @@
import re
from abc import ABC, abstractmethod
from collections import OrderedDict
from copy import deepcopy
from itertools import product
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from neural_compressor.common import Logger
from neural_compressor.common.utils import (
@@ -44,13 +43,12 @@
"register_config",
"BaseConfig",
"ComposableConfig",
"Options",
"get_all_config_set_from_config_registry",
"options",
]

# Dictionary to store registered configurations


# Config registry to store all registered configs.
class ConfigRegistry:
registered_configs = {}

@@ -104,6 +102,13 @@ def get_cls_configs(cls) -> Dict[str, Dict[str, object]]:
cls_configs[framework_name][algo_name] = config_data["cls"]
return cls_configs

@classmethod
def get_all_config_cls_by_fwk_name(cls, fwk_name: str) -> List[Type[BaseConfig]]:
configs_cls = []
for algo_name, config_pairs in cls.registered_configs.get(fwk_name, {}).items():
configs_cls.append(config_pairs["cls"])
return configs_cls


config_registry = ConfigRegistry()

@@ -373,6 +378,11 @@ def _is_op_type(name: str) -> bool:
# TODO (Yi), ort and tf need override it
return not isinstance(name, str)

@classmethod
@abstractmethod
def get_config_set_for_tuning(cls):
raise NotImplementedError


class ComposableConfig(BaseConfig):
name = COMPOSABLE_CONFIG
@@ -420,6 +430,24 @@ def register_supported_configs(cls):
"""Add all supported configs."""
raise NotImplementedError

@classmethod
def get_config_set_for_tuning(cls) -> None:
# TODO (Yi) handle the composable config in `tuning_config`
return None


def get_all_config_set_from_config_registry(fwk_name: str) -> Union[BaseConfig, List[BaseConfig]]:
all_registered_config_cls: List[BaseConfig] = config_registry.get_all_config_cls_by_fwk_name(fwk_name)
config_set = []
for config_cls in all_registered_config_cls:
config_set.append(config_cls.get_config_set_for_tuning())
return config_set


#######################################################
#### Options
#######################################################


def _check_value(name, src, supported_type, supported_value=[]):
"""Check if the given object is the given supported type and in the given supported value.
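
A minimal sketch of how the new registry helper is intended to be queried (the "torch" framework name is only an illustration; any framework whose config module has been imported, so that its configs are registered, works the same way):

from neural_compressor.common.base_config import (
    config_registry,
    get_all_config_set_from_config_registry,
)

# One candidate per registered config class, each produced by that class's
# get_config_set_for_tuning() classmethod.
config_set = get_all_config_set_from_config_registry(fwk_name="torch")
assert len(config_set) == len(config_registry.registered_configs["torch"])
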
14 changes: 7 additions & 7 deletions neural_compressor/common/base_tuning.py
@@ -129,8 +129,8 @@ class Sampler:


class ConfigLoader:
def __init__(self, quant_configs, sampler: Sampler) -> None:
self.quant_configs = quant_configs
def __init__(self, config_set, sampler: Sampler) -> None:
self.config_set = config_set
self.sampler = sampler

@staticmethod
@@ -146,7 +146,7 @@ def parse_quant_config(quant_config: BaseConfig) -> List[BaseConfig]:
def parse_quant_configs(self) -> List[BaseConfig]:
# TODO (Yi) separate this functionality into `Sampler` in the next PR
quant_config_list = []
for quant_config in self.quant_configs:
for quant_config in self.config_set:
quant_config_list.extend(ConfigLoader.parse_quant_config(quant_config))
return quant_config_list

@@ -210,14 +210,14 @@ class TuningConfig:
"""Base Class for Tuning Criterion.
Args:
quant_configs: quantization configs. Default value is empty.
config_set: quantization configs. Default value is empty.
timeout: Tuning timeout (seconds). Default value is 0 which means early stop.
max_trials: Max tuning times. Default value is 100. Combine with timeout field to decide when to exit.
"""

def __init__(self, quant_configs=None, timeout=0, max_trials=100, sampler: Sampler = None) -> None:
def __init__(self, config_set=None, timeout=0, max_trials=100, sampler: Sampler = None) -> None:
"""Init a TuneCriterion object."""
self.quant_configs = quant_configs
self.config_set = config_set
self.timeout = timeout
self.max_trials = max_trials
self.sampler = sampler
@@ -265,7 +265,7 @@ def need_stop(self) -> bool:


def init_tuning(tuning_config: TuningConfig) -> Tuple[ConfigLoader, TuningLogger, TuningMonitor]:
config_loader = ConfigLoader(quant_configs=tuning_config.quant_configs, sampler=tuning_config.sampler)
config_loader = ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler)
tuning_logger = TuningLogger()
tuning_monitor = TuningMonitor(tuning_config)
return config_loader, tuning_logger, tuning_monitor
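
A hedged sketch of the renamed argument in use, wiring an explicit config_set through init_tuning (RTNConfig is used purely as an illustration of a framework config class):

from neural_compressor.common.base_tuning import TuningConfig, init_tuning
from neural_compressor.torch.quantization.config import RTNConfig

# `config_set` replaces the old `quant_configs` keyword.
tune_config = TuningConfig(config_set=[RTNConfig(weight_bits=[4, 8])], max_trials=5)
config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config=tune_config)
for trial_index, quant_config in enumerate(config_loader):
    ...  # each expanded candidate config, ready to be quantized and evaluated
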
5 changes: 5 additions & 0 deletions neural_compressor/onnxrt/quantization/config.py
@@ -144,6 +144,11 @@ def get_model_info(model: Union[onnx.ModelProto, Path, str]) -> List[Tuple[str,
logger.debug(f"Get model info: {filter_result}")
return filter_result

@classmethod
def get_config_set_for_tuning(cls) -> Union[None, "RTNConfig", List["RTNConfig"]]: # pragma: no cover
# TODO fwk owner needs to update it.
return RTNConfig(weight_bits=[4, 6])


# TODO(Yi) run `register_supported_configs` for all registered config.
RTNConfig.register_supported_configs()
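
A sketch of what the ONNX Runtime default set looks like when queried directly (the returned object is the placeholder stub above, which the framework owner is expected to refine):

from neural_compressor.onnxrt.quantization.config import RTNConfig

default_set = RTNConfig.get_config_set_for_tuning()
# -> RTNConfig(weight_bits=[4, 6]) per the current stub; each listed value is
#    expected to become a separate tuning candidate once the config is expanded.
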
7 changes: 7 additions & 0 deletions neural_compressor/tensorflow/quantization/config.py
@@ -103,6 +103,13 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
supported_configs.append(OperatorConfig(config=static_quant_config, operators=operators))
cls.supported_configs = supported_configs

@classmethod
def get_config_set_for_tuning(
cls,
) -> Union[None, "StaticQuantConfig", List["StaticQuantConfig"]]: # pragma: no cover
# TODO fwk owner needs to update it.
return StaticQuantConfig(weight_sym=[True, False])


# TODO(Yi) run `register_supported_configs` for all registered config.
StaticQuantConfig.register_supported_configs()
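
The TensorFlow stub can be queried the same way (a sketch; the value comes straight from the placeholder above):

from neural_compressor.tensorflow.quantization.config import StaticQuantConfig

tf_default_set = StaticQuantConfig.get_config_set_for_tuning()
# -> StaticQuantConfig(weight_sym=[True, False]) per the current stub.
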
2 changes: 1 addition & 1 deletion neural_compressor/torch/__init__.py
@@ -28,4 +28,4 @@
)

from neural_compressor.common.base_tuning import TuningConfig
from neural_compressor.torch.quantization.autotune import autotune, get_default_tune_config
from neural_compressor.torch.quantization.autotune import autotune, get_all_config_set
19 changes: 11 additions & 8 deletions neural_compressor/torch/quantization/autotune.py
@@ -12,28 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Dict, List, Optional, Union

import torch

from neural_compressor.common import Logger
from neural_compressor.common.base_config import BaseConfig
from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning
from neural_compressor.torch import quantize
from neural_compressor.torch.quantization.config import GPTQConfig, RTNConfig
from neural_compressor.torch.quantization.config import FRAMEWORK_NAME

logger = Logger().get_logger()


__all__ = [
"get_default_tune_config",
"autotune",
"get_all_config_set",
]


def get_default_tune_config() -> TuningConfig:
# TODO use the registered default tuning config in the next PR
return TuningConfig(quant_configs=[GPTQConfig(weight_bits=[4, 8]), RTNConfig(weight_bits=[4, 8])])
def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
return get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)


def autotune(
@@ -52,15 +52,18 @@ def autotune(
for trial_index, quant_config in enumerate(config_loader):
tuning_logger.trial_start(trial_index=trial_index)
tuning_logger.quantization_start()
q_model = quantize(model, quant_config=quant_config, run_fn=run_fn, run_args=run_args)
logger.info(f"quant config: {quant_config}")
# !!! Make sure to use deepcopy only when inplace is set to `True`.
q_model = quantize(deepcopy(model), quant_config=quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
tuning_logger.quantization_end()
tuning_logger.evaluation_start()
eval_result: float = evaluator.evaluate(q_model)
tuning_logger.evaluation_end()
tuning_monitor.add_trial_result(trial_index, eval_result, quant_config)
if tuning_monitor.need_stop():
best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
quantize(model, quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
# !!! Make sure to use deepcopy only when inplace is set to `True`.
quantize(deepcopy(model), quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
best_quant_model = model # quantize model inplace
tuning_logger.trial_end(trial_index)
tuning_logger.tuning_end()
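
A hedged end-to-end sketch of the new entry point. The autotune keyword names below are assumptions based on this diff, not the confirmed signature; check autotune() itself before relying on them:

from neural_compressor.torch import TuningConfig, autotune, get_all_config_set

# Tune over the default set advertised by every registered torch config.
custom_tune_config = TuningConfig(config_set=get_all_config_set(), max_trials=10)
# Hypothetical call with assumed parameter names (model, tune_config, run_fn):
# best_model = autotune(model=float_model, tune_config=custom_tune_config, run_fn=calib_fn)
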
33 changes: 33 additions & 0 deletions neural_compressor/torch/quantization/config.py
@@ -36,6 +36,14 @@
from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN
from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger

__all__ = [
"RTNConfig",
"get_default_rtn_config",
"GPTQConfig",
"get_default_gptq_config",
]


FRAMEWORK_NAME = "torch"
DTYPE_RANGE = Union[torch.dtype, List[torch.dtype]]

@@ -153,6 +161,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
logger.debug(f"Get model info: {filter_result}")
return filter_result

@classmethod
def get_config_set_for_tuning(cls) -> Union[None, "RTNConfig", List["RTNConfig"]]:
# TODO fwk owner needs to update it.
return RTNConfig(weight_bits=[4, 6])


# TODO(Yi) run `register_supported_configs` for all registered config.
RTNConfig.register_supported_configs()
@@ -276,6 +289,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
logger.debug(f"Get model info: {filter_result}")
return filter_result

@classmethod
def get_config_set_for_tuning(cls) -> Union[None, "GPTQConfig", List["GPTQConfig"]]:
# TODO fwk owner needs to update it.
return GPTQConfig(weight_bits=[4, 6])


# TODO(Yi) run `register_supported_configs` for all registered config.
GPTQConfig.register_supported_configs()
@@ -352,6 +370,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
logger.debug(f"Get model info: {filter_result}")
return filter_result

@classmethod
def get_config_set_for_tuning(cls) -> Union[None, "StaticQuantConfig", List["StaticQuantConfig"]]:
# TODO fwk owner needs to update it.
return StaticQuantConfig(w_sym=[True, False])


# TODO(Yi) run `register_supported_configs` for all registered config.
StaticQuantConfig.register_supported_configs()
@@ -461,6 +484,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
logger.debug(f"Get model info: {filter_result}")
return filter_result

@classmethod
def get_config_set_for_tuning(cls) -> Union[None, "SmoothQuantConfig", List["SmoothQuantConfig"]]:
# TODO fwk owner needs to update it.
return SmoothQuantConfig(alpha=[0.1, 0.5])


# TODO(Yi) run `register_supported_configs` for all registered config.
SmoothQuantConfig.register_supported_configs()
@@ -541,6 +569,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
logger.debug(f"Get model info: {filter_result}")
return filter_result

@classmethod
def get_config_set_for_tuning(cls) -> Union[None, "FP8QConfig", List["FP8QConfig"]]:
# TODO fwk owner needs to update it.
return FP8QConfig(act_dtype=[torch.float8_e4m3fn])

# TODO(Yi) run `register_supported_configs` for all registered config.
FP8QConfig.register_supported_configs()

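For reference, a sketch of collecting the per-algorithm defaults that these classmethods now provide (the exact set depends on which torch configs are registered at import time):

from neural_compressor.torch import get_all_config_set

for cfg in get_all_config_set():
    # Expected entries per the stubs above, e.g. RTNConfig(weight_bits=[4, 6]),
    # GPTQConfig(weight_bits=[4, 6]), StaticQuantConfig(w_sym=[True, False]),
    # SmoothQuantConfig(alpha=[0.1, 0.5]), FP8QConfig(act_dtype=[torch.float8_e4m3fn]).
    print(type(cfg).__name__)
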
8 changes: 8 additions & 0 deletions test/3x/onnxrt/test_config.py
@@ -328,6 +328,14 @@ def test_expand_config(self):
self.assertEqual(expand_config_list[0].weight_bits, 4)
self.assertEqual(expand_config_list[1].weight_bits, 8)

def test_config_set_api(self):
# *Note: this test is only for improving the code coverage and can be removed once the test_common is enabled.
from neural_compressor.common.base_config import config_registry, get_all_config_set_from_config_registry
from neural_compressor.onnxrt.quantization.config import FRAMEWORK_NAME

config_set = get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
self.assertEqual(len(config_set), len(config_registry.registered_configs[FRAMEWORK_NAME]))


if __name__ == "__main__":
unittest.main()
8 changes: 8 additions & 0 deletions test/3x/tensorflow/test_config.py
@@ -315,6 +315,14 @@ def test_expand_config(self):
self.assertEqual(expand_config_list[0].weight_granularity, "per_channel")
self.assertEqual(expand_config_list[1].weight_granularity, "per_tensor")

def test_config_set_api(self):
# *Note: this test is only for improving the code coverage and can be removed once the test_common is enabled.
from neural_compressor.common.base_config import config_registry, get_all_config_set_from_config_registry
from neural_compressor.tensorflow.quantization.config import FRAMEWORK_NAME

config_set = get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
self.assertEqual(len(config_set), len(config_registry.registered_configs[FRAMEWORK_NAME]))


if __name__ == "__main__":
unittest.main()