add per_channel_minmax (#1990)
Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored Sep 3, 2024
1 parent 82d8c06 commit 05272c4
Showing 2 changed files with 17 additions and 5 deletions.
15 changes: 12 additions & 3 deletions neural_compressor/torch/algorithms/pt2e_quant/utility.py
@@ -17,7 +17,12 @@
 
 import torch
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
-from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver, PlaceholderObserver
+from torch.ao.quantization.observer import (
+    HistogramObserver,
+    MinMaxObserver,
+    PerChannelMinMaxObserver,
+    PlaceholderObserver,
+)
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer
 
@@ -48,19 +53,23 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False):
         "placeholder": PlaceholderObserver,
         "minmax": MinMaxObserver,
         "kl": HistogramObserver,
+        "per_channel_minmax": PerChannelMinMaxObserver,
     }
     # Force to use placeholder observer for dynamic quantization
     if is_dynamic:
         algo = "placeholder"
     # algo
-    observer_or_fake_quant_ctr = observer_mapping[algo]
+    if f"{granularity}_{algo}" in observer_mapping:
+        observer_or_fake_quant_ctr = observer_mapping[f"{granularity}_{algo}"]
+    else:
+        observer_or_fake_quant_ctr = observer_mapping[algo]
     # qscheme
     qscheme = qscheme_mapping[granularity][sym]
     quantization_spec = QuantizationSpec(
         dtype=select_dtype,
         quant_min=min_max_mapping[select_dtype][0],
         quant_max=min_max_mapping[select_dtype][1],
         observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+        ch_axis=0,
         qscheme=qscheme,
         is_dynamic=is_dynamic,
     )
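For context, the dispatch added above works as follows: when granularity is "per_channel" and algo is "minmax", the combined key "per_channel_minmax" is present in the mapping and selects PerChannelMinMaxObserver; any other combination falls back to the plain algo key. The new ch_axis=0 points the per-channel observer at dimension 0, the output-channel axis of a weight tensor. Below is a minimal standalone sketch of that resolution logic; resolve_observer is a hypothetical helper that mirrors the if/else in the diff, and the mapping is trimmed to the entries shown here.

```python
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)

# Mapping as extended by this commit (trimmed to the entries in the diff).
observer_mapping = {
    "placeholder": PlaceholderObserver,
    "minmax": MinMaxObserver,
    "kl": HistogramObserver,
    "per_channel_minmax": PerChannelMinMaxObserver,
}

def resolve_observer(granularity: str, algo: str):
    # Prefer a granularity-specific observer (e.g. "per_channel_minmax");
    # fall back to the plain algorithm name otherwise.
    key = f"{granularity}_{algo}"
    return observer_mapping.get(key, observer_mapping[algo])

assert resolve_observer("per_channel", "minmax") is PerChannelMinMaxObserver
assert resolve_observer("per_tensor", "minmax") is MinMaxObserver
```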
7 changes: 5 additions & 2 deletions test/3x/torch/quantization/test_pt2e_quant.py
@@ -98,7 +98,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return exported_model, example_inputs
 
     @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
-    def test_quantize_simple_model(self, force_not_import_ipex):
+    @pytest.mark.parametrize("granularity", ["per_tensor", "per_channel"])
+    def test_quantize_simple_model(self, granularity, force_not_import_ipex):
+        from neural_compressor.torch.quantization import StaticQuantConfig
+
         model, example_inputs = self.build_simple_torch_model_and_example_inputs()
         float_model_output = model(*example_inputs)
         quant_config = None
@@ -107,7 +110,7 @@ def calib_fn(model):
             for i in range(4):
                 model(*example_inputs)
 
-        quant_config = get_default_static_config()
+        quant_config = StaticQuantConfig(w_granularity=granularity)
         q_model = quantize(model=model, quant_config=quant_config, run_fn=calib_fn)
         from torch._inductor import config
 
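Taken together, the test change means a user can now request per-channel weight quantization on the PT2E path through StaticQuantConfig. The sketch below is a rough usage example modeled on the parametrized test; ToyModel and the torch.export capture step are illustrative assumptions (the test uses its own build_simple_torch_model_and_example_inputs helper, and the exact capture API varies across torch versions), so treat this as a sketch rather than the repository's documented workflow.

```python
import torch

from neural_compressor.torch.quantization import StaticQuantConfig, quantize

class ToyModel(torch.nn.Module):  # stand-in for the test's simple model
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

example_inputs = (torch.randn(2, 8),)
# Capture step is an assumption; the test exports via its own helper.
exported_model = torch.export.export(ToyModel(), example_inputs).module()

def calib_fn(model):
    for _ in range(4):  # a few forward passes to collect observer statistics
        model(*example_inputs)

# w_granularity selects per-tensor vs per-channel weight observers.
quant_config = StaticQuantConfig(w_granularity="per_channel")
q_model = quantize(model=exported_model, quant_config=quant_config, run_fn=calib_fn)
```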
