From 7120dd4909599b228692415732688b3d5e77206d Mon Sep 17 00:00:00 2001
From: Zixuan Cheng <110808245+violetch24@users.noreply.github.com>
Date: Tue, 28 May 2024 16:00:31 +0800
Subject: [PATCH] bug fix for 3.x sq and static quant (#1823)

Signed-off-by: Cheng, Zixuan
---
 .../torch/algorithms/smooth_quant/smooth_quant.py | 11 ++++++-----
 .../torch/algorithms/static_quant/static_quant.py | 11 ++++++-----
 neural_compressor/torch/quantization/quantize.py  |  6 +++---
 test/3x/torch/quantization/test_smooth_quant.py   |  7 +++++++
 test/3x/torch/quantization/test_static_quant.py   |  7 ++++++-
 5 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py b/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py
index e9d6fde3524..f7bfc9369ce 100644
--- a/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py
+++ b/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py
@@ -389,8 +389,9 @@ def _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=False):
             else:
                 model = torch.jit.trace(model, example_inputs, strict=False)
             model = torch.jit.freeze(model.eval())
-            # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
-            # At the 2nd run, the llga pass will be triggered and the model is turned into
-            # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
-            simple_inference(model, example_inputs, iterations=2)
-            return model
+
+    # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
+    # At the 2nd run, the llga pass will be triggered and the model is turned into
+    # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
+    simple_inference(model, example_inputs, iterations=2)
+    return model
diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py
index b9c476e9e80..7ebbf76f36c 100644
--- a/neural_compressor/torch/algorithms/static_quant/static_quant.py
+++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py
@@ -176,8 +176,9 @@ def _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=False):
             else:
                 model = torch.jit.trace(model, example_inputs, strict=False)
             model = torch.jit.freeze(model.eval())
-            # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
-            # At the 2nd run, the llga pass will be triggered and the model is turned into
-            # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
-            simple_inference(model, example_inputs, iterations=2)
-            return model
+
+    # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
+    # At the 2nd run, the llga pass will be triggered and the model is turned into
+    # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
+    simple_inference(model, example_inputs, iterations=2)
+    return model
diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py
index 8404befdc6f..d694123b359 100644
--- a/neural_compressor/torch/quantization/quantize.py
+++ b/neural_compressor/torch/quantization/quantize.py
@@ -71,10 +71,10 @@ def quantize(
         from neural_compressor.torch.algorithms.smooth_quant import TorchSmoothQuant
 
         sq = TorchSmoothQuant(
-            model, dataloader=None, example_inputs=example_inputs, q_func=run_fn, record_max_info=True
+            q_model, dataloader=None, example_inputs=example_inputs, q_func=run_fn, record_max_info=True
         )
-        model.sq_info = sq
-        model = sq.transform(
+        q_model.sq_info = sq
+        q_model = sq.transform(
             alpha=quant_config.alpha,
             folding=quant_config.folding,
             auto_alpha_args=quant_config.auto_alpha_args,
diff --git a/test/3x/torch/quantization/test_smooth_quant.py b/test/3x/torch/quantization/test_smooth_quant.py
index 2967f386cb1..3e30825e865 100644
--- a/test/3x/torch/quantization/test_smooth_quant.py
+++ b/test/3x/torch/quantization/test_smooth_quant.py
@@ -193,6 +193,13 @@ def test_smooth_quant_mixed_precision(self):
         assert q_model is not None, "Quantization failed!"
 
         # quantize API
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+
         quant_config.excluded_precisions = ["bf16"]
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"
+
+        quant_config.folding = True
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
diff --git a/test/3x/torch/quantization/test_static_quant.py b/test/3x/torch/quantization/test_static_quant.py
index 072f4774e3e..fe13dff60ed 100644
--- a/test/3x/torch/quantization/test_static_quant.py
+++ b/test/3x/torch/quantization/test_static_quant.py
@@ -195,9 +195,14 @@ def test_static_quant_with_quantize_API(self):
     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
     def test_static_quant_mixed_precision(self):
         fp32_model = copy.deepcopy(self.fp32_model)
+        example_inputs = self.input
         quant_config = get_default_static_config()
+        prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
+        run_fn(prepared_model)
+        q_model = convert(prepared_model)
+        assert q_model is not None, "Quantization failed!"
+
         quant_config.excluded_precisions = ["bf16"]
-        example_inputs = self.input
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
         q_model = convert(prepared_model)
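
Note (illustration only, not part of the patch): simple_inference() is the library's own warm-up helper, called here with iterations=2 because the first pass lets the JIT profiling graph executor insert prim::profile nodes and the second pass triggers the LLGA rewrite into an int8 graph with LlgaFusionGroup nodes. A hypothetical stand-in for that warm-up, just to show what the two iterations do, would look roughly like:

    import torch

    def warmup_inference(model, example_inputs, iterations=2):
        # Run the frozen TorchScript module a couple of times without autograd:
        # pass 1 inserts prim::profile nodes, pass 2 fires the LLGA fusion pass.
        with torch.no_grad():
            for _ in range(iterations):
                if isinstance(example_inputs, dict):
                    model(**example_inputs)
                elif isinstance(example_inputs, (tuple, list)):
                    model(*example_inputs)
                else:
                    model(example_inputs)

Hoisting the simple_inference() call and the return out of the tracing branch means they run regardless of which branch produced the jit model, including the bf16 mixed-precision path that the new tests exercise.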
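
Note (sketch with hypothetical names, not part of the patch): the quantize.py hunk makes the SmoothQuant path record info on and transform q_model, the working copy that the surrounding function goes on to return, instead of the caller's original model. The copy-then-transform pattern it relies on is roughly:

    import copy

    import torch

    def quantize_sketch(model: torch.nn.Module, inplace: bool = False) -> torch.nn.Module:
        # Work on a deep copy unless the caller explicitly asks for in-place
        # quantization, so the FP32 model passed in stays untouched.
        q_model = model if inplace else copy.deepcopy(model)
        # ... calibration info is recorded on, and transforms applied to,
        # q_model rather than model ...
        q_model.sq_info = "placeholder for the recorded SmoothQuant max info"
        return q_model

Switching the TorchSmoothQuant(...) argument, the sq_info assignment, and the sq.transform(...) result from model to q_model keeps the smoothing results on the object that is actually returned to the caller.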