From 30b36b83a195c6ea350692c7ac0bfec1b52ee419 Mon Sep 17 00:00:00 2001
From: Yi Liu <106061964+yiliu30@users.noreply.github.com>
Date: Thu, 16 May 2024 20:25:07 +0800
Subject: [PATCH] Add pt2e dynamic quantization (#1795)

Signed-off-by: yiliu30
---
 .../torch/algorithms/pt2e_quant/__init__.py    |  2 +-
 .../torch/algorithms/pt2e_quant/core.py        | 14 ++--
 .../torch/quantization/__init__.py             |  2 +
 .../torch/quantization/algorithm_entry.py      | 28 ++++++-
 .../torch/quantization/config.py               | 77 ++++++++++++++++++-
 neural_compressor/torch/utils/constants.py     |  1 +
 neural_compressor/torch/utils/environ.py       | 18 ++---
 neural_compressor/torch/utils/utility.py       | 29 ++++---
 .../algorithms/pt2e_quant/test_pt2e_w8a8.py    |  8 +-
 test/3x/torch/quantization/test_pt2e_quant.py  | 29 +++++--
 10 files changed, 167 insertions(+), 41 deletions(-)

diff --git a/neural_compressor/torch/algorithms/pt2e_quant/__init__.py b/neural_compressor/torch/algorithms/pt2e_quant/__init__.py
index b3ed5d11fe3..b6187ba214a 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/__init__.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/__init__.py
@@ -13,4 +13,4 @@
 # limitations under the License.
 
 
-from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
+from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
diff --git a/neural_compressor/torch/algorithms/pt2e_quant/core.py b/neural_compressor/torch/algorithms/pt2e_quant/core.py
index c0983ee7aad..5608a29f150 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/core.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/core.py
@@ -18,9 +18,7 @@
 
 from typing import Any
 
-import torch
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
-from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
 from torch.fx.graph_module import GraphModule
@@ -30,15 +28,21 @@
 from neural_compressor.torch.utils import create_xiq_quantizer_from_pt2e_config
 
 
-class W8A8StaticQuantizer(Quantizer):
+class W8A8PT2EQuantizer(Quantizer):
+    is_dynamic = False
+
+    def __init__(self, quant_config=None):
+        super().__init__(quant_config)
 
     @staticmethod
     def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuantizer:
         if not quant_config:
             quantizer = X86InductorQuantizer()
-            quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
+            quantizer.set_global(
+                xiq.get_default_x86_inductor_quantization_config(is_dynamic=W8A8PT2EQuantizer.is_dynamic)
+            )
         else:
-            quantizer = create_xiq_quantizer_from_pt2e_config(quant_config)
+            quantizer = create_xiq_quantizer_from_pt2e_config(quant_config, is_dynamic=W8A8PT2EQuantizer.is_dynamic)
         return quantizer
 
     def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args, **kwargs) -> GraphModule:
diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py
index e9b51a1a99a..b89bec51350 100644
--- a/neural_compressor/torch/quantization/__init__.py
+++ b/neural_compressor/torch/quantization/__init__.py
@@ -35,6 +35,8 @@
     get_default_fp8_config,
     get_default_fp8_config_set,
     get_woq_tuning_config,
+    DynamicQuantConfig,
+    get_default_dynamic_config,
 )
 
 from neural_compressor.torch.quantization.autotune import (
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index 8ca5e222877..0334121b5ae 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -14,7 +14,7 @@
 
 from copy import deepcopy
 from types import MethodType
-from typing import Any, Callable, Dict, Tuple
+from typing import Callable, Dict, Tuple
 
 import torch
 
@@ -42,7 +42,7 @@
     TEQConfig,
 )
 from neural_compressor.torch.utils import get_quantizer, is_ipex_imported, logger, postprocess_model, register_algo
-from neural_compressor.torch.utils.constants import PT2E_STATIC_QUANT
+from neural_compressor.torch.utils.constants import PT2E_DYNAMIC_QUANT, PT2E_STATIC_QUANT
 
 
 ###################### RTN Algo Entry ##################################
@@ -186,19 +186,39 @@ def static_quant_entry(
     return model
 
 
+###################### PT2E Dynamic Quant Algo Entry ##################################
+@register_algo(name=PT2E_DYNAMIC_QUANT)
+@torch.no_grad()
+def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
+    logger.info("Quantize model with the PT2E dynamic quant algorithm.")
+    from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
+
+    run_fn = kwargs.get("run_fn", None)
+    example_inputs = kwargs.get("example_inputs", None)
+    inplace = kwargs.get("inplace", True)
+    W8A8PT2EQuantizer.is_dynamic = True
+    for _, quant_config in configs_mapping.items():
+        if quant_config.name == PT2E_DYNAMIC_QUANT:
+            w8a8_quantizer = W8A8PT2EQuantizer(quant_config=quant_config)
+            model = w8a8_quantizer.execute(
+                model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
+            )
+    return model
+
+
 ###################### PT2E Static Quant Algo Entry ##################################
 @register_algo(name=PT2E_STATIC_QUANT)
 @torch.no_grad()
 def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
     logger.info("Quantize model with the PT2E static quant algorithm.")
-    from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
+    from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
 
     run_fn = kwargs.get("run_fn", None)
     example_inputs = kwargs.get("example_inputs", None)
     inplace = kwargs.get("inplace", True)
     for _, quant_config in configs_mapping.items():
         if quant_config.name == STATIC_QUANT:
-            w8a8_quantizer = W8A8StaticQuantizer(quant_config=quant_config)
+            w8a8_quantizer = W8A8PT2EQuantizer(quant_config=quant_config)
             model = w8a8_quantizer.execute(
                 model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
             )
diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index 05f2d629a88..62b98b83a34 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -17,7 +17,7 @@
 # pylint:disable=import-error
 
 from collections import OrderedDict
-from typing import Any, Callable, Dict, List, NamedTuple, Optional
+from typing import Callable, Dict, List, NamedTuple, Optional
 from typing import OrderedDict as OrderedDictType
 from typing import Tuple, Union
 
@@ -50,6 +50,7 @@
     PRIORITY_HQQ,
     PRIORITY_RTN,
     PRIORITY_TEQ,
+    PT2E_DYNAMIC_QUANT,
 )
 
 __all__ = [
@@ -778,6 +779,80 @@ def get_default_AutoRound_config() -> AutoRoundConfig:
     return AutoRoundConfig()
 
 
+######################## Dynamic Quant Config ###############################
+@register_config(framework_name=FRAMEWORK_NAME, algo_name=PT2E_DYNAMIC_QUANT)
+class DynamicQuantConfig(BaseConfig):
+    """Config class for dynamic quantization."""
+
+    name = PT2E_DYNAMIC_QUANT
+    params_list = [
+        "w_dtype",
+        "w_sym",
+        "w_granularity",
+        "w_algo",
+        "act_dtype",
+        "act_sym",
+        "act_granularity",
+        "act_algo",
+    ]
+    supported_configs: List[OperatorConfig] = []
+
+    def __init__(
+        self,
+        w_dtype: str = "int8",
+        w_sym: bool = True,
+        w_granularity: str = "per_tensor",
+        w_algo: str = "minmax",
+        act_dtype: str = "uint8",
+        act_sym: bool = False,
+        act_granularity: str = "per_tensor",
+        act_algo: str = "kl",
+        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+    ):
+        """Init Dynamic Quant Configs."""
+        super().__init__(white_list=white_list)
+        self.w_dtype = w_dtype
+        self.w_sym = w_sym
+        self.w_granularity = w_granularity
+        self.w_algo = w_algo
+        self.act_dtype = act_dtype
+        self.act_sym = act_sym
+        self.act_granularity = act_granularity
+        self.act_algo = act_algo
+        self._post_init()
+
+    @classmethod
+    def register_supported_configs(cls) -> List[OperatorConfig]:
+        supported_configs = []
+        linear_static_config = cls()
+        operators = [torch.nn.Linear]
+        supported_configs.append(OperatorConfig(config=linear_static_config, operators=operators))
+        cls.supported_configs = supported_configs
+
+    @staticmethod
+    def get_model_info(model: torch.nn.Module, example_inputs=None):
+        return None
+
+    def to_config_mapping(
+        self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
+    ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
+        config_mapping = OrderedDict({self.name: self})
+        return config_mapping
+
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "DynamicQuantConfig", List["DynamicQuantConfig"]]:
+        return cls(act_sym=[True, False], act_algo=["kl", "minmax"])
+
+
+def get_default_dynamic_config() -> DynamicQuantConfig:
+    """Generate the default dynamic quant config.
+
+    Returns:
+        the default dynamic quant config.
+    """
+    return DynamicQuantConfig()
+
+
 ######################## Static Quant Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
 class StaticQuantConfig(BaseConfig):
diff --git a/neural_compressor/torch/utils/constants.py b/neural_compressor/torch/utils/constants.py
index 54a68163ded..429851e311b 100644
--- a/neural_compressor/torch/utils/constants.py
+++ b/neural_compressor/torch/utils/constants.py
@@ -52,3 +52,4 @@
 
 
 PT2E_STATIC_QUANT = "pt2e_static_quant"
+PT2E_DYNAMIC_QUANT = "pt2e_dynamic_quant"
diff --git a/neural_compressor/torch/utils/environ.py b/neural_compressor/torch/utils/environ.py
index 3a98d963e09..f2906881040 100644
--- a/neural_compressor/torch/utils/environ.py
+++ b/neural_compressor/torch/utils/environ.py
@@ -31,20 +31,20 @@ def is_hpex_available():
     return _hpex_available
 
 
-try:
-    import intel_extension_for_pytorch as ipex
-
-    _ipex_available = True
-except:
-    _ipex_available = False
-
-
 def is_ipex_available():
+    try:
+        import intel_extension_for_pytorch as ipex
+
+        _ipex_available = True
+    except:
+        _ipex_available = False
     return _ipex_available
 
 
 def get_ipex_version():
-    if _ipex_available:
+    if is_ipex_available():
+        import intel_extension_for_pytorch as ipex
+
         try:
             ipex_version = ipex.__version__.split("+")[0]
         except ValueError as e:  # pragma: no cover
diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py
index af0b4f8b79d..f01345fab5c 100644
--- a/neural_compressor/torch/utils/utility.py
+++ b/neural_compressor/torch/utils/utility.py
@@ -17,7 +17,7 @@
 import torch
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
-from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
+from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver, PlaceholderObserver
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer
 from typing_extensions import TypeAlias
 
@@ -172,30 +172,41 @@ def postprocess_model(model, mode, quantizer):
         del model.quantizer
 
 
-def create_quant_spec_from_config(dtype, sym, granularity, algo) -> QuantizationSpec:
+def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
     dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8}
+    select_dtype = dtype_mapping[dtype]
+    min_max_mapping = {torch.int8: (-128, 127), torch.uint8: (0, 255)}
     qscheme_mapping = {
         "per_channel": {True: torch.per_channel_symmetric, False: torch.per_tensor_affine},
         "per_tensor": {True: torch.per_tensor_symmetric, False: torch.per_tensor_affine},
     }
     observer_mapping = {
+        "placeholder": PlaceholderObserver,
         "minmax": MinMaxObserver,
         "kl": HistogramObserver,
    }
+    # Force to use placeholder observer for dynamic quantization
+    if is_dynamic:
+        algo = "placeholder"
     # algo
     observer_or_fake_quant_ctr = observer_mapping[algo]
     # qscheme
     qscheme = qscheme_mapping[granularity][sym]
     quantization_spec = QuantizationSpec(
-        dtype=dtype_mapping[dtype], observer_or_fake_quant_ctr=observer_or_fake_quant_ctr, qscheme=qscheme
+        dtype=select_dtype,
+        quant_min=min_max_mapping[select_dtype][0],
+        quant_max=min_max_mapping[select_dtype][1],
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+        qscheme=qscheme,
+        is_dynamic=is_dynamic,
     )
     return quantization_spec
 
 
-def _map_inc_config_to_torch_quant_config(inc_config) -> QuantizationConfig:
-    default_quant_config = xiq.get_default_x86_inductor_quantization_config()
+def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> QuantizationConfig:
+    default_quant_config = xiq.get_default_x86_inductor_quantization_config(is_dynamic=is_dynamic)
     input_act_quant_spec = create_quant_spec_from_config(
-        inc_config.act_dtype, inc_config.act_sym, inc_config.act_granularity, inc_config.act_algo
+        inc_config.act_dtype, inc_config.act_sym, inc_config.act_granularity, inc_config.act_algo, is_dynamic=is_dynamic
     )
     weight_quant_spec = create_quant_spec_from_config(
         inc_config.w_dtype, inc_config.w_sym, inc_config.w_granularity, inc_config.w_algo
@@ -210,14 +221,14 @@ def _map_inc_config_to_torch_quant_config(inc_config) -> QuantizationConfig:
     return quant_config
 
 
-def create_xiq_quantizer_from_pt2e_config(config) -> X86InductorQuantizer:
+def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86InductorQuantizer:
     quantizer = xiq.X86InductorQuantizer()
     # set global
-    global_config = _map_inc_config_to_torch_quant_config(config)
+    global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
     quantizer.set_global(global_config)
     # set local
     for module_or_func_name, local_config in config.local_config.items():
-        local_quant_config = _map_inc_config_to_torch_quant_config(local_config)
+        local_quant_config = _map_inc_config_to_torch_quant_config(local_config, is_dynamic)
         if isinstance(module_or_func_name, torch.nn.Module):
             quantizer.set_module_type_qconfig(module_or_func_name, local_quant_config)
         else:
diff --git a/test/3x/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py b/test/3x/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py
index 103d0e3350e..1f8bfb7a882 100644
--- a/test/3x/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py
+++ b/test/3x/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py
@@ -5,12 +5,12 @@
 import torch
 
 from neural_compressor.common.utils import logger
-from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
+from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
 from neural_compressor.torch.export import export_model_for_pt2e_quant
 from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version
 
 
-class TestW8A8StaticQuantizer:
+class TestW8A8PT2EQuantizer:
 
     @staticmethod
     def get_toy_model():
@@ -52,7 +52,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
     def test_quantizer_on_simple_model(self):
         model, example_inputs = self.build_simple_torch_model_and_example_inputs()
-        w8a8_static_quantizer = W8A8StaticQuantizer()
+        w8a8_static_quantizer = W8A8PT2EQuantizer()
         # prepare
         prepare_model = w8a8_static_quantizer.prepare(model, example_inputs=example_inputs)
         # calibrate
@@ -81,7 +81,7 @@ def test_quantizer_on_llm(self):
         model = export_model_for_pt2e_quant(model, example_inputs=example_inputs)
 
         quant_config = None
-        w8a8_static_quantizer = W8A8StaticQuantizer()
+        w8a8_static_quantizer = W8A8PT2EQuantizer()
         # prepare
         prepare_model = w8a8_static_quantizer.prepare(model)
         # calibrate
diff --git a/test/3x/torch/quantization/test_pt2e_quant.py b/test/3x/torch/quantization/test_pt2e_quant.py
index 23e56d7220b..cd2867f72ac 100644
--- a/test/3x/torch/quantization/test_pt2e_quant.py
+++ b/test/3x/torch/quantization/test_pt2e_quant.py
@@ -8,13 +8,25 @@
 from neural_compressor.common.utils import logger
 from neural_compressor.torch.export import export
 from neural_compressor.torch.quantization import (
+    DynamicQuantConfig,
     StaticQuantConfig,
     convert,
+    get_default_dynamic_config,
     get_default_static_config,
     prepare,
     quantize,
 )
-from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version, is_ipex_imported
+from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version
+
+
+@pytest.fixture
+def force_not_import_ipex(monkeypatch):
+    def _is_ipex_imported():
+        return False
+
+    monkeypatch.setattr("neural_compressor.torch.quantization.config.is_ipex_imported", _is_ipex_imported)
+    monkeypatch.setattr("neural_compressor.torch.quantization.algorithm_entry.is_ipex_imported", _is_ipex_imported)
+    monkeypatch.setattr("neural_compressor.torch.export._export.is_ipex_imported", _is_ipex_imported)
 
 
 class TestPT2EQuantization:
@@ -56,9 +68,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         exported_model = export(model, example_inputs=example_inputs)
         return exported_model, example_inputs
 
-    @pytest.mark.skipif(is_ipex_imported(), reason="IPEX is imported")
     @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
-    def test_quantize_simple_model(self):
+    def test_quantize_simple_model(self, force_not_import_ipex):
         model, example_inputs = self.build_simple_torch_model_and_example_inputs()
         quant_config = None
 
@@ -76,9 +87,9 @@ def calib_fn(model):
         logger.warning("out shape is %s", out.shape)
         assert out is not None
 
-    @pytest.mark.skipif(is_ipex_imported(), reason="IPEX is imported")
     @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
-    def test_prepare_and_convert_on_simple_model(self):
+    @pytest.mark.parametrize("is_dynamic", [False, True])
+    def test_prepare_and_convert_on_simple_model(self, is_dynamic, force_not_import_ipex):
         model, example_inputs = self.build_simple_torch_model_and_example_inputs()
         quant_config = None
 
@@ -86,7 +97,10 @@ def calib_fn(model):
             for i in range(2):
                 model(*example_inputs)
 
-        quant_config = get_default_static_config()
+        if is_dynamic:
+            quant_config = get_default_dynamic_config()
+        else:
+            quant_config = get_default_static_config()
         prepared_model = prepare(model, quant_config=quant_config)
         calib_fn(prepared_model)
 
@@ -101,9 +115,8 @@ def calib_fn(model):
         logger.warning("out shape is %s", out.shape)
         assert out is not None
 
-    @pytest.mark.skipif(is_ipex_imported(), reason="IPEX is imported")
     @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
-    def test_prepare_and_convert_on_llm(self):
+    def test_prepare_and_convert_on_llm(self, force_not_import_ipex):
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
         # set TOKENIZERS_PARALLELISM to false
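
Below is a minimal usage sketch of the dynamic quantization flow this patch enables, following the 3.x API calls exercised in the tests above (export, prepare, convert, get_default_dynamic_config). It assumes torch >= 2.3.0; the toy model, its shapes, and the variable names are made up for illustration and are not part of the patch.

import torch

from neural_compressor.torch.export import export
from neural_compressor.torch.quantization import convert, get_default_dynamic_config, prepare

# Hypothetical toy model; only the API sequence mirrors the tests in this patch.
model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 5))
example_inputs = (torch.randn(4, 10),)

# Capture the eager model as a graph for PT2E quantization.
exported_model = export(model, example_inputs=example_inputs)

# Insert observers according to the default dynamic config; with is_dynamic=True the
# activation observer is a PlaceholderObserver, so the forward passes below mainly
# mirror what test_prepare_and_convert_on_simple_model does rather than calibrate.
quant_config = get_default_dynamic_config()
prepared_model = prepare(exported_model, quant_config=quant_config)
for _ in range(2):
    prepared_model(*example_inputs)

# Lower to the quantized graph and run inference.
converted_model = convert(prepared_model)
out = converted_model(*example_inputs)
assert out is not None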
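
For reference, the updated create_quant_spec_from_config in neural_compressor/torch/utils/utility.py reduces, for a dynamic uint8 activation built from the DynamicQuantConfig defaults, to a QuantizationSpec like the standalone sketch below (illustrative only, not a drop-in copy of the helper):

import torch
from torch.ao.quantization.observer import PlaceholderObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# Dynamic activations collect no calibration statistics: the spec forces a
# PlaceholderObserver, sets explicit quant_min/quant_max for the chosen dtype,
# and marks the spec with is_dynamic=True.
act_quant_spec = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    observer_or_fake_quant_ctr=PlaceholderObserver,
    qscheme=torch.per_tensor_affine,
    is_dynamic=True,
)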