From 05272c48591567d0a1d36fe6cfe5c697d836887b Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Tue, 3 Sep 2024 10:21:51 +0800 Subject: [PATCH] add per_channel_minmax (#1990) Signed-off-by: yiliu30 --- .../torch/algorithms/pt2e_quant/utility.py | 15 ++++++++++++--- test/3x/torch/quantization/test_pt2e_quant.py | 7 +++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/neural_compressor/torch/algorithms/pt2e_quant/utility.py b/neural_compressor/torch/algorithms/pt2e_quant/utility.py index 966baf0e53a..e31efabf0a6 100644 --- a/neural_compressor/torch/algorithms/pt2e_quant/utility.py +++ b/neural_compressor/torch/algorithms/pt2e_quant/utility.py @@ -17,7 +17,12 @@ import torch import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq -from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver, PlaceholderObserver +from torch.ao.quantization.observer import ( + HistogramObserver, + MinMaxObserver, + PerChannelMinMaxObserver, + PlaceholderObserver, +) from torch.ao.quantization.quantizer import QuantizationSpec from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer @@ -48,12 +53,15 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=Fals "placeholder": PlaceholderObserver, "minmax": MinMaxObserver, "kl": HistogramObserver, + "per_channel_minmax": PerChannelMinMaxObserver, } # Force to use placeholder observer for dynamic quantization if is_dynamic: algo = "placeholder" - # algo - observer_or_fake_quant_ctr = observer_mapping[algo] + if f"{granularity}_{algo}" in observer_mapping: + observer_or_fake_quant_ctr = observer_mapping[f"{granularity}_{algo}"] + else: + observer_or_fake_quant_ctr = observer_mapping[algo] # qscheme qscheme = qscheme_mapping[granularity][sym] quantization_spec = QuantizationSpec( @@ -61,6 +69,7 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=Fals quant_min=min_max_mapping[select_dtype][0], quant_max=min_max_mapping[select_dtype][1], observer_or_fake_quant_ctr=observer_or_fake_quant_ctr, + ch_axis=0, qscheme=qscheme, is_dynamic=is_dynamic, ) diff --git a/test/3x/torch/quantization/test_pt2e_quant.py b/test/3x/torch/quantization/test_pt2e_quant.py index ab80e991203..2d3b6cabd94 100644 --- a/test/3x/torch/quantization/test_pt2e_quant.py +++ b/test/3x/torch/quantization/test_pt2e_quant.py @@ -98,7 +98,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return exported_model, example_inputs @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0") - def test_quantize_simple_model(self, force_not_import_ipex): + @pytest.mark.parametrize("granularity", ["per_tensor", "per_channel"]) + def test_quantize_simple_model(self, granularity, force_not_import_ipex): + from neural_compressor.torch.quantization import StaticQuantConfig + model, example_inputs = self.build_simple_torch_model_and_example_inputs() float_model_output = model(*example_inputs) quant_config = None @@ -107,7 +110,7 @@ def calib_fn(model): for i in range(4): model(*example_inputs) - quant_config = get_default_static_config() + quant_config = StaticQuantConfig(w_granularity=granularity) q_model = quantize(model=model, quant_config=quant_config, run_fn=calib_fn) from torch._inductor import config