From d2031fdcabb8ca0267a541b8665d0887b9cdec77 Mon Sep 17 00:00:00 2001 From: yuwenzho <yuwen.zhou@intel.com> Date: Wed, 8 May 2024 19:09:55 -0700 Subject: [PATCH 1/6] enhance torch 3.x algorithm entry Signed-off-by: yuwenzho <yuwen.zhou@intel.com> --- .../torch/algorithms/weight_only/autoround.py | 11 +- .../torch/quantization/algorithm_entry.py | 132 ++++++++---------- neural_compressor/torch/utils/utility.py | 36 +++++ .../weight_only/test_autoround.py | 2 +- 4 files changed, 98 insertions(+), 83 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py index 59a06240d9d..65fdb397fe9 100644 --- a/neural_compressor/torch/algorithms/weight_only/autoround.py +++ b/neural_compressor/torch/algorithms/weight_only/autoround.py @@ -26,7 +26,7 @@ class AutoRoundQuantizer(Quantizer): def __init__( self, - weight_config: dict = {}, + quant_config: dict = {}, enable_full_range: bool = False, batch_size: int = 8, amp: bool = True, @@ -51,8 +51,8 @@ def __init__( """Init a AutQRoundQuantizer object. Args: - weight_config (dict): Configuration for weight quantization (default is an empty dictionary). - weight_config={ + quant_config (dict): Configuration for weight quantization (default is an empty dictionary). + quant_config={ 'layer1':##layer_name { 'data_type': 'int', @@ -89,9 +89,8 @@ def __init__( scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels have different choices. """ - super().__init__(weight_config) + super().__init__(quant_config) self.tokenizer = None - self.weight_config = weight_config self.enable_full_range = enable_full_range self.batch_size = batch_size self.amp = amp @@ -125,7 +124,7 @@ def prepare(self, model: torch.nn.Module, *args, **kwargs): self.rounder = AutoRoundProcessor( model=model, tokenizer=None, - weight_config=self.weight_config, + weight_config=self.quant_config, enable_full_range=self.enable_full_range, batch_size=self.batch_size, amp=self.amp, diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 41d2593c224..50679a1f45a 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -30,7 +30,13 @@ StaticQuantConfig, TEQConfig, ) -from neural_compressor.torch.utils import Mode, logger, register_algo +from neural_compressor.torch.utils import ( + Mode, + logger, + register_algo, + preprocess_quantizer, + postprocess_model, +) ###################### RTN Algo Entry ################################## @@ -68,17 +74,9 @@ def rtn_entry( "double_quant_group_size": quant_config.double_quant_group_size, } - if getattr(model, "quantizer", False): - quantizer = model.quantizer - else: - quantizer = RTNQuantizer(quant_config=weight_config) - + quantizer = preprocess_quantizer(model, quantizer_cls=RTNQuantizer, quant_config=weight_config) model = quantizer.execute(model, mode=mode) - - if getattr(model, "quantizer", False): - del model.quantizer - else: - model.quantizer = quantizer + postprocess_model(model, mode, quantizer) return model @@ -125,15 +123,12 @@ def gptq_entry( ) kwargs.pop("example_inputs") logger.warning("lm_head in transformer model is skipped by GPTQ") - if getattr(model, "quantizer", False): - quantizer = model.quantizer - else: - quantizer = GPTQuantizer(quant_config=weight_config) + + quantizer = preprocess_quantizer( + model, quantizer_cls=GPTQuantizer, 
quant_config=weight_config) model = quantizer.execute(model, mode=mode, *args, **kwargs) - if getattr(model, "quantizer", False): - del model.quantizer - else: - model.quantizer = quantizer + postprocess_model(model, mode, quantizer) + return model @@ -177,17 +172,11 @@ def static_quant_entry( inplace = kwargs.get("inplace", True) assert example_inputs is not None, "Please provide example_inputs for static quantization." - if getattr(model, "quantizer", False): - quantizer = model.quantizer - else: - quantizer = StaticQuantQuantizer(quant_config=quant_config_mapping) - + quantizer = preprocess_quantizer( + model, quantizer_cls=StaticQuantQuantizer, quant_config=quant_config_mapping) model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace) + postprocess_model(model, mode, quantizer) - if getattr(model, "quantizer", False): - del model.quantizer - else: - model.quantizer = quantizer return model @@ -301,11 +290,7 @@ def awq_quantize_entry( example_inputs = kwargs.get("example_inputs", None) assert example_inputs is not None, "Please provide example_inputs for AWQ quantization." - if getattr(model, "quantizer", False): - quantizer = model.quantizer - else: - quantizer = AWQQuantizer(quant_config=weight_config) - + quantizer = preprocess_quantizer(model, quantizer_cls=AWQQuantizer, quant_config=weight_config) model = quantizer.execute( model, mode=mode, @@ -318,11 +303,8 @@ def awq_quantize_entry( return_int=return_int, use_full_range=use_full_range, ) + postprocess_model(model, mode, quantizer) - if getattr(model, "quantizer", False): - del model.quantizer - else: - model.quantizer = quantizer return model @@ -364,10 +346,17 @@ def teq_quantize_entry( absorb_to_layer = quant_config.absorb_to_layer folding = quant_config.folding assert isinstance(model, torch.nn.Module), "only support torch module" - quantizer = TEQuantizer( - quant_config=weight_config, folding=folding, absorb_to_layer=absorb_to_layer, example_inputs=example_inputs - ) + + quantizer = preprocess_quantizer( + model, + quantizer_cls=TEQuantizer, + quant_config=weight_config, + folding=folding, + absorb_to_layer=absorb_to_layer, + example_inputs=example_inputs) model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace) + postprocess_model(model, mode, quantizer) + return model @@ -414,35 +403,32 @@ def autoround_quantize_entry( scale_dtype = quant_config.scale_dtype kwargs.pop("example_inputs") - if getattr(model, "quantizer", False): - quantizer = model.quantizer - else: - quantizer = AutoRoundQuantizer( - weight_config=weight_config, - enable_full_range=enable_full_range, - batch_size=batch_size, - lr_scheduler=lr_scheduler, - use_quant_input=use_quant_input, - enable_minmax_tuning=enable_minmax_tuning, - lr=lr, - minmax_lr=minmax_lr, - low_gpu_mem_usage=low_gpu_mem_usage, - iters=iters, - seqlen=seqlen, - n_samples=n_samples, - sampler=sampler, - seed=seed, - n_blocks=n_blocks, - gradient_accumulate_steps=gradient_accumulate_steps, - not_use_best_mse=not_use_best_mse, - dynamic_max_gap=dynamic_max_gap, - scale_dtype=scale_dtype, - ) + + quantizer = preprocess_quantizer( + model, + quantizer_cls=AutoRoundQuantizer, + quant_config=weight_config, + enable_full_range=enable_full_range, + batch_size=batch_size, + lr_scheduler=lr_scheduler, + use_quant_input=use_quant_input, + enable_minmax_tuning=enable_minmax_tuning, + lr=lr, + minmax_lr=minmax_lr, + low_gpu_mem_usage=low_gpu_mem_usage, + iters=iters, + seqlen=seqlen, + 
n_samples=n_samples, + sampler=sampler, + seed=seed, + n_blocks=n_blocks, + gradient_accumulate_steps=gradient_accumulate_steps, + not_use_best_mse=not_use_best_mse, + dynamic_max_gap=dynamic_max_gap, + scale_dtype=scale_dtype,) model = quantizer.execute(model=model, mode=mode, *args, **kwargs) - if getattr(model, "quantizer", False): - del model.quantizer - else: - model.quantizer = quantizer + postprocess_model(model, mode, quantizer) + logger.info("AutoRound quantization done.") return model @@ -460,17 +446,11 @@ def hqq_entry( from neural_compressor.torch.algorithms.weight_only.hqq import HQQuantizer logger.info("Quantize model with the HQQ algorithm.") - if getattr(model, "quantizer", False): - quantizer = model.quantizer - else: - quantizer = HQQuantizer(quant_config=configs_mapping) + quantizer = preprocess_quantizer(model, quantizer_cls=HQQuantizer, quant_config=configs_mapping) model = quantizer.execute(model, mode=mode) + postprocess_model(model, mode, quantizer) - if getattr(model, "quantizer", False): - del model.quantizer - else: - model.quantizer = quantizer return model diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py index 135c4025c10..0678416e05e 100644 --- a/neural_compressor/torch/utils/utility.py +++ b/neural_compressor/torch/utils/utility.py @@ -131,3 +131,39 @@ class Mode(Enum): PREPARE = "prepare" CONVERT = "convert" QUANTIZE = "quantize" + + +def preprocess_quantizer(model, quantizer_cls, quant_config=None, *args, **kwargs): + """Process quantizer. + + Initialize a quantizer or get `quantizer` attribute from model. + + Args: + model (torch.nn.Module): pytorch model. + quantizer_cls (Quantizer): quantizer class of a specific algorithm. + quant_config (dict, optional): Specifies how to apply the algorithm on the given model. + Defaults to None. + + Returns: + quantizer object. + """ + if not hasattr(model, "quantizer"): + quantizer = quantizer_cls(quant_config=quant_config, *args, **kwargs) + return quantizer + else: + return model.quantizer + + +def postprocess_model(model, mode, quantizer): + """Process `quantizer` attribute of model according to current mode. + + Args: + model (torch.nn.Module): pytorch model. + mode (Mode): The mode of current phase, including 'prepare', 'convert' and 'quantize'. + quantizer (Quantizer): quantizer object. 
+ """ + if mode == Mode.PREPARE: + model.quantizer = quantizer + elif mode == Mode.CONVERT or mode == Mode.QUANTIZE: + if getattr(model, "quantizer", False): + del model.quantizer diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py index 615e1ce6419..26270f560fd 100644 --- a/test/3x/torch/quantization/weight_only/test_autoround.py +++ b/test/3x/torch/quantization/weight_only/test_autoround.py @@ -102,7 +102,7 @@ def test_quantizer(self): "sym": False, } } - quantizer = AutoRoundQuantizer(weight_config=weight_config) + quantizer = AutoRoundQuantizer(quant_config=weight_config) fp32_model = gpt_j_model # quantizer execute From 59e86b14beb1fecb7714531edc4fd1eb485d2731 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 May 2024 02:28:04 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../torch/quantization/algorithm_entry.py | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 50679a1f45a..9072335bf83 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -30,13 +30,7 @@ StaticQuantConfig, TEQConfig, ) -from neural_compressor.torch.utils import ( - Mode, - logger, - register_algo, - preprocess_quantizer, - postprocess_model, -) +from neural_compressor.torch.utils import Mode, logger, postprocess_model, preprocess_quantizer, register_algo ###################### RTN Algo Entry ################################## @@ -124,8 +118,7 @@ def gptq_entry( kwargs.pop("example_inputs") logger.warning("lm_head in transformer model is skipped by GPTQ") - quantizer = preprocess_quantizer( - model, quantizer_cls=GPTQuantizer, quant_config=weight_config) + quantizer = preprocess_quantizer(model, quantizer_cls=GPTQuantizer, quant_config=weight_config) model = quantizer.execute(model, mode=mode, *args, **kwargs) postprocess_model(model, mode, quantizer) @@ -172,8 +165,7 @@ def static_quant_entry( inplace = kwargs.get("inplace", True) assert example_inputs is not None, "Please provide example_inputs for static quantization." 
- quantizer = preprocess_quantizer( - model, quantizer_cls=StaticQuantQuantizer, quant_config=quant_config_mapping) + quantizer = preprocess_quantizer(model, quantizer_cls=StaticQuantQuantizer, quant_config=quant_config_mapping) model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace) postprocess_model(model, mode, quantizer) @@ -353,7 +345,8 @@ def teq_quantize_entry( quant_config=weight_config, folding=folding, absorb_to_layer=absorb_to_layer, - example_inputs=example_inputs) + example_inputs=example_inputs, + ) model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace) postprocess_model(model, mode, quantizer) @@ -425,7 +418,8 @@ def autoround_quantize_entry( gradient_accumulate_steps=gradient_accumulate_steps, not_use_best_mse=not_use_best_mse, dynamic_max_gap=dynamic_max_gap, - scale_dtype=scale_dtype,) + scale_dtype=scale_dtype, + ) model = quantizer.execute(model=model, mode=mode, *args, **kwargs) postprocess_model(model, mode, quantizer) From 59ab67af310e19fbc407c44352b39954d3d260ba Mon Sep 17 00:00:00 2001 From: yuwenzho <yuwen.zhou@intel.com> Date: Wed, 8 May 2024 20:25:16 -0700 Subject: [PATCH 3/6] fix ut Signed-off-by: yuwenzho <yuwen.zhou@intel.com> --- test/3x/torch/quantization/weight_only/test_mixed_algos.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/3x/torch/quantization/weight_only/test_mixed_algos.py b/test/3x/torch/quantization/weight_only/test_mixed_algos.py index b4789f6c5d9..d465f8cd9c3 100644 --- a/test/3x/torch/quantization/weight_only/test_mixed_algos.py +++ b/test/3x/torch/quantization/weight_only/test_mixed_algos.py @@ -10,8 +10,11 @@ def run_fn(model): - model(torch.tensor([[10, 20, 30]], dtype=torch.long)) - model(torch.tensor([[40, 50, 60]], dtype=torch.long)) + # GPTQ uses ValueError to reduce computation when collecting input data of the first block + # It's special for UTs, no need to add this wrapper in examples. + with pytest.raises(ValueError): + model(torch.tensor([[10, 20, 30]], dtype=torch.long)) + model(torch.tensor([[40, 50, 60]], dtype=torch.long)) class TestMixedTwoAlgo: From cc79eeffa6c03454c3afba66a11303f1b597f7be Mon Sep 17 00:00:00 2001 From: yuwenzho <yuwen.zhou@intel.com> Date: Wed, 8 May 2024 23:20:36 -0700 Subject: [PATCH 4/6] enhance code Signed-off-by: yuwenzho <yuwen.zhou@intel.com> --- .../torch/algorithms/base_algorithm.py | 3 +-- .../torch/algorithms/weight_only/autoround.py | 6 +++--- .../torch/quantization/algorithm_entry.py | 16 ++++++++-------- neural_compressor/torch/quantization/quantize.py | 6 ++---- neural_compressor/torch/utils/utility.py | 11 ++++++++--- 5 files changed, 22 insertions(+), 20 deletions(-) diff --git a/neural_compressor/torch/algorithms/base_algorithm.py b/neural_compressor/torch/algorithms/base_algorithm.py index dd6216079c4..c458c210e33 100644 --- a/neural_compressor/torch/algorithms/base_algorithm.py +++ b/neural_compressor/torch/algorithms/base_algorithm.py @@ -99,14 +99,13 @@ def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any): return model - def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any): # pragma: no cover + def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any): """Execute according to mode. Args: model (torch.nn.Module): The model to be executed. mode (Mode): The mode of current phase, including 'prepare', 'convert' and 'quantize'. 
""" - # TODO: remove '# pragma: no cover' once CI test can cover this function if mode == Mode.PREPARE: model = self.prepare(model, *args, **kwargs) elif mode == Mode.CONVERT: diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py index 65fdb397fe9..5ff78a3413d 100644 --- a/neural_compressor/torch/algorithms/weight_only/autoround.py +++ b/neural_compressor/torch/algorithms/weight_only/autoround.py @@ -26,7 +26,7 @@ class AutoRoundQuantizer(Quantizer): def __init__( self, - quant_config: dict = {}, + quant_config: dict = None, enable_full_range: bool = False, batch_size: int = 8, amp: bool = True, @@ -51,7 +51,7 @@ def __init__( """Init a AutQRoundQuantizer object. Args: - quant_config (dict): Configuration for weight quantization (default is an empty dictionary). + quant_config (dict): Configuration for weight quantization (default is None). quant_config={ 'layer1':##layer_name { @@ -124,7 +124,7 @@ def prepare(self, model: torch.nn.Module, *args, **kwargs): self.rounder = AutoRoundProcessor( model=model, tokenizer=None, - weight_config=self.quant_config, + weight_config=self.quant_config or {}, enable_full_range=self.enable_full_range, batch_size=self.batch_size, amp=self.amp, diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 9072335bf83..b9c53a6cce0 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -30,7 +30,7 @@ StaticQuantConfig, TEQConfig, ) -from neural_compressor.torch.utils import Mode, logger, postprocess_model, preprocess_quantizer, register_algo +from neural_compressor.torch.utils import Mode, logger, postprocess_model, get_quantizer, register_algo ###################### RTN Algo Entry ################################## @@ -68,7 +68,7 @@ def rtn_entry( "double_quant_group_size": quant_config.double_quant_group_size, } - quantizer = preprocess_quantizer(model, quantizer_cls=RTNQuantizer, quant_config=weight_config) + quantizer = get_quantizer(model, quantizer_cls=RTNQuantizer, quant_config=weight_config) model = quantizer.execute(model, mode=mode) postprocess_model(model, mode, quantizer) return model @@ -118,7 +118,7 @@ def gptq_entry( kwargs.pop("example_inputs") logger.warning("lm_head in transformer model is skipped by GPTQ") - quantizer = preprocess_quantizer(model, quantizer_cls=GPTQuantizer, quant_config=weight_config) + quantizer = get_quantizer(model, quantizer_cls=GPTQuantizer, quant_config=weight_config) model = quantizer.execute(model, mode=mode, *args, **kwargs) postprocess_model(model, mode, quantizer) @@ -165,7 +165,7 @@ def static_quant_entry( inplace = kwargs.get("inplace", True) assert example_inputs is not None, "Please provide example_inputs for static quantization." - quantizer = preprocess_quantizer(model, quantizer_cls=StaticQuantQuantizer, quant_config=quant_config_mapping) + quantizer = get_quantizer(model, quantizer_cls=StaticQuantQuantizer, quant_config=quant_config_mapping) model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace) postprocess_model(model, mode, quantizer) @@ -282,7 +282,7 @@ def awq_quantize_entry( example_inputs = kwargs.get("example_inputs", None) assert example_inputs is not None, "Please provide example_inputs for AWQ quantization." 
- quantizer = preprocess_quantizer(model, quantizer_cls=AWQQuantizer, quant_config=weight_config) + quantizer = get_quantizer(model, quantizer_cls=AWQQuantizer, quant_config=weight_config) model = quantizer.execute( model, mode=mode, @@ -339,7 +339,7 @@ def teq_quantize_entry( folding = quant_config.folding assert isinstance(model, torch.nn.Module), "only support torch module" - quantizer = preprocess_quantizer( + quantizer = get_quantizer( model, quantizer_cls=TEQuantizer, quant_config=weight_config, @@ -397,7 +397,7 @@ def autoround_quantize_entry( kwargs.pop("example_inputs") - quantizer = preprocess_quantizer( + quantizer = get_quantizer( model, quantizer_cls=AutoRoundQuantizer, quant_config=weight_config, @@ -441,7 +441,7 @@ def hqq_entry( logger.info("Quantize model with the HQQ algorithm.") - quantizer = preprocess_quantizer(model, quantizer_cls=HQQuantizer, quant_config=configs_mapping) + quantizer = get_quantizer(model, quantizer_cls=HQQuantizer, quant_config=configs_mapping) model = quantizer.execute(model, mode=mode) postprocess_model(model, mode, quantizer) diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py index 7b3db8eca72..4d27ac263d6 100644 --- a/neural_compressor/torch/quantization/quantize.py +++ b/neural_compressor/torch/quantization/quantize.py @@ -91,7 +91,7 @@ def prepare( quant_config: BaseConfig, inplace: bool = True, example_inputs: Any = None, -): # pragma: no cover +): """Prepare the model for calibration. Insert observers into the model so that it can monitor the input and output tensors during calibration. @@ -105,7 +105,6 @@ def prepare( Returns: prepared and calibrated module. """ - # TODO: remove '# pragma: no cover' once CI test can cover this function prepared_model = model if inplace else copy.deepcopy(model) registered_configs = config_registry.get_cls_configs() if isinstance(quant_config, dict): @@ -148,7 +147,7 @@ def convert( model: torch.nn.Module, quant_config: BaseConfig = None, inplace: bool = True, -): # pragma: no cover +): """Convert the prepared model to a quantized model. Args: @@ -159,7 +158,6 @@ def convert( Returns: The quantized model. """ - # TODO: remove '# pragma: no cover' once CI test can cover this function q_model = model if inplace else copy.deepcopy(model) # TODO: Optimize the check for prepared flag after adding HQT FP8 Quant diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py index 0678416e05e..04a27aa0684 100644 --- a/neural_compressor/torch/utils/utility.py +++ b/neural_compressor/torch/utils/utility.py @@ -133,8 +133,8 @@ class Mode(Enum): QUANTIZE = "quantize" -def preprocess_quantizer(model, quantizer_cls, quant_config=None, *args, **kwargs): - """Process quantizer. +def get_quantizer(model, quantizer_cls, quant_config=None, *args, **kwargs): + """Get the quantizer. Initialize a quantizer or get `quantizer` attribute from model. @@ -155,7 +155,12 @@ def preprocess_quantizer(model, quantizer_cls, quant_config=None, *args, **kwarg def postprocess_model(model, mode, quantizer): - """Process `quantizer` attribute of model according to current mode. + """Process `quantizer` attribute of model according to current phase. + + In `prepare` phase, the `quantizer` is set as an attribute of the model + to avoid redundant initialization during `convert` phase. + + In 'convert' or 'quantize' phase, the unused `quantizer` attribute is removed. Args: model (torch.nn.Module): pytorch model. 
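
The helper pair that patches 1-4 converge on reduces every algorithm entry to the same three steps: fetch (or reuse) a quantizer, execute it for the current phase, and cache or drop it afterwards. The self-contained sketch below illustrates that flow; the Mode, get_quantizer and postprocess_model bodies mirror the ones added to neural_compressor/torch/utils/utility.py, while DummyQuantizer, DummyModel and fake_algo_entry are hypothetical stand-ins for the real quantizer classes and entries in algorithm_entry.py.

    from enum import Enum


    class Mode(Enum):
        PREPARE = "prepare"
        CONVERT = "convert"
        QUANTIZE = "quantize"


    def get_quantizer(model, quantizer_cls, quant_config=None, *args, **kwargs):
        # Reuse the quantizer cached on the model by a previous `prepare` call,
        # otherwise build a fresh one from the given config.
        if not hasattr(model, "quantizer"):
            return quantizer_cls(quant_config=quant_config, *args, **kwargs)
        return model.quantizer


    def postprocess_model(model, mode, quantizer):
        # After `prepare`, cache the quantizer on the model so `convert` can pick it
        # up again; after `convert` or a one-shot `quantize`, drop the attribute.
        if mode == Mode.PREPARE:
            model.quantizer = quantizer
        elif mode in (Mode.CONVERT, Mode.QUANTIZE):
            if getattr(model, "quantizer", False):
                del model.quantizer


    class DummyQuantizer:
        """Hypothetical stand-in for RTNQuantizer, GPTQuantizer, etc."""

        def __init__(self, quant_config=None):
            self.quant_config = quant_config

        def execute(self, model, mode):
            print(f"running {mode.value} with config {self.quant_config}")
            return model


    class DummyModel:
        """Hypothetical stand-in for a torch.nn.Module."""


    def fake_algo_entry(model, quant_config, mode):
        # The shape every refactored entry now shares.
        quantizer = get_quantizer(model, quantizer_cls=DummyQuantizer, quant_config=quant_config)
        model = quantizer.execute(model, mode=mode)
        postprocess_model(model, mode, quantizer)
        return model


    model = DummyModel()
    model = fake_algo_entry(model, {"bits": 4}, Mode.PREPARE)  # quantizer created and cached
    assert hasattr(model, "quantizer")
    model = fake_algo_entry(model, {"bits": 4}, Mode.CONVERT)  # cached quantizer reused, then removed
    assert not hasattr(model, "quantizer")

Caching the quantizer on the prepared model is what lets a separate convert() call keep the state produced during prepare() without re-initializing the algorithm, which is the stated motivation in the postprocess_model docstring.
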
From 7f3fdd8b075a86178420b156757288f0ea038e1e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 May 2024 06:20:48 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/torch/quantization/algorithm_entry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index b9c53a6cce0..9f6c37232df 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -30,7 +30,7 @@ StaticQuantConfig, TEQConfig, ) -from neural_compressor.torch.utils import Mode, logger, postprocess_model, get_quantizer, register_algo +from neural_compressor.torch.utils import Mode, get_quantizer, logger, postprocess_model, register_algo ###################### RTN Algo Entry ################################## From e86c1641d92a15981b00378ab281f350068b9c2c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 May 2024 07:20:45 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/torch/quantization/algorithm_entry.py | 4 ++-- neural_compressor/torch/utils/utility.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 3d2e089fa91..44c5423a34d 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -32,11 +32,11 @@ ) from neural_compressor.torch.utils import ( Mode, + get_quantizer, is_ipex_imported, logger, + postprocess_model, register_algo, - get_quantizer, - postprocess_model ) from neural_compressor.torch.utils.constants import PT2E_STATIC_QUANT diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py index f6542793ee8..3ea196ccdfb 100644 --- a/neural_compressor/torch/utils/utility.py +++ b/neural_compressor/torch/utils/utility.py @@ -176,6 +176,8 @@ def postprocess_model(model, mode, quantizer): elif mode == Mode.CONVERT or mode == Mode.QUANTIZE: if getattr(model, "quantizer", False): del model.quantizer + + def create_quant_spec_from_config(dtype, sym, granularity, algo) -> QuantizationSpec: dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8} qscheme_mapping = {
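
With the "# pragma: no cover" markers removed from prepare() and convert() in quantize.py, the user-facing two-step flow that exercises this caching looks roughly like the sketch below. It assumes the public neural_compressor.torch.quantization exports match the module paths shown in the diffs; the model and RTN config are placeholders, and calibration is only needed for data-driven algorithms.

    import torch

    from neural_compressor.torch.quantization import RTNConfig, convert, prepare

    model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # placeholder model
    quant_config = RTNConfig()  # weight-only RTN; needs no calibration data

    # Mode.PREPARE: the entry creates a quantizer and caches it on the returned model.
    prepared = prepare(model, quant_config)
    # ... for data-driven algorithms (GPTQ, AWQ, static quant, ...) run calibration on `prepared` here ...
    # Mode.CONVERT: the cached quantizer is reused to finish quantization, then removed.
    quantized = convert(prepared)
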