From 72fbce4b34f29c2b6fe0d41a76c4d65edb08719a Mon Sep 17 00:00:00 2001 From: Zixuan Cheng <110808245+violetch24@users.noreply.github.com> Date: Mon, 20 May 2024 13:50:41 +0800 Subject: [PATCH] Smoothquant refactor for 3.x API (#1792) Signed-off-by: Cheng, Zixuan --- .../quantization/llm/run_clm_no_trainer.py | 18 +- .../torch/algorithms/base_algorithm.py | 14 +- .../torch/algorithms/smooth_quant/__init__.py | 2 +- .../algorithms/smooth_quant/smooth_quant.py | 264 +++++++++++++----- .../torch/algorithms/smooth_quant/utility.py | 133 ++++++++- .../torch/algorithms/static_quant/utility.py | 4 +- .../torch/quantization/algorithm_entry.py | 25 +- .../torch/quantization/config.py | 12 +- .../torch/quantization/quantize.py | 14 + .../torch/quantization/test_smooth_quant.py | 58 ++-- test/3x/torch/requirements.txt | 1 + 11 files changed, 408 insertions(+), 137 deletions(-) diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py index 562b5215dd6..5ebae5fc020 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py @@ -361,20 +361,12 @@ def run_fn(model): from utils import get_example_inputs example_inputs = get_example_inputs(user_model, calib_dataloader) - if args.sq: - # currently, smooth quant only support quantize API - # TODO: support prepare/convert API for smooth quant - from neural_compressor.torch.quantization import quantize - user_model = quantize( - model=user_model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn - ) - else: - from neural_compressor.torch.quantization import prepare, convert - - user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs) - run_fn(user_model) - user_model = convert(user_model) + from neural_compressor.torch.quantization import prepare, convert + user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs) + run_fn(user_model) + user_model = convert(user_model) + user_model.save(args.output_dir) diff --git a/neural_compressor/torch/algorithms/base_algorithm.py b/neural_compressor/torch/algorithms/base_algorithm.py index 50a8d189233..48e1b390db4 100644 --- a/neural_compressor/torch/algorithms/base_algorithm.py +++ b/neural_compressor/torch/algorithms/base_algorithm.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy from abc import ABC, abstractmethod -from collections import OrderedDict from typing import Any, Optional import torch @@ -111,5 +111,15 @@ def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any): elif mode == Mode.CONVERT: model = self.convert(model, *args, **kwargs) elif mode == Mode.QUANTIZE: - model = self.quantize(model, *args, **kwargs) + if not isinstance(self.quant_config, dict): + user_cfg = copy.deepcopy(self.quant_config).to_dict() + else: + user_cfg = copy.deepcopy(self.quant_config) + if "recipe_cfgs" in user_cfg: # keep quantize API for smoothquant + run_fn = kwargs.get("run_fn", None) + example_inputs = kwargs.get("example_inputs", None) + inplace = kwargs.get("inplace", True) + model = self.quantize(model, self.quant_config, run_fn, example_inputs, inplace) + else: + model = self.quantize(model, *args, **kwargs) return model diff --git a/neural_compressor/torch/algorithms/smooth_quant/__init__.py b/neural_compressor/torch/algorithms/smooth_quant/__init__.py index 13074a77fc0..bb420d9b673 100644 --- a/neural_compressor/torch/algorithms/smooth_quant/__init__.py +++ b/neural_compressor/torch/algorithms/smooth_quant/__init__.py @@ -14,5 +14,5 @@ # limitations under the License. from .utility import * -from .smooth_quant import smooth_quantize +from .smooth_quant import SmoothQuantQuantizer from .save_load import save, load, recover_model_from_json diff --git a/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py b/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py index e49d1bfbab8..f2b2cdf8542 100644 --- a/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py +++ b/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py @@ -21,11 +21,16 @@ try: import intel_extension_for_pytorch as ipex -except: +except: # pragma: no cover assert False, "Please install IPEX for smooth quantization." +from collections import OrderedDict +from types import MethodType + from packaging.version import Version +from neural_compressor.torch.algorithms import Quantizer + from .utility import ( TorchSmoothQuant, cfg_to_qconfig, @@ -41,88 +46,199 @@ ipex_ver = get_ipex_version() -def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True): - """Execute the quantize process on the specified model. +class SmoothQuantQuantizer(Quantizer): + def __init__(self, quant_config: OrderedDict = {}): + """Init a SmoothQuantQuantizer object. - Args: - model: a float model to be quantized. - tune_cfg: quantization config for ops. - run_fn: a calibration function for calibrating the model. - example_inputs: used to trace torch model. - inplace: whether to carry out model transformations in-place. + Args: + quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}. + """ + super().__init__(quant_config) - Returns: - A quantized model. - """ - assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant." + def prepare(self, model, example_inputs, inplace=True, *args, **kwargs): + """Prepares a given model for quantization. + + Args: + model: A float model to be quantized. + example_inputs: Used to trace torch model. + inplace: Whether to carry out model transformations in-place. Defaults to True. + + Returns: + A prepared model. + """ + assert example_inputs is not None, "Please provide example_inputs for smooth quantization." + assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant." 
+ + # Note: we should make sure smoothquant is only executed once with inplacing fp32 model. + if hasattr(model, "_smoothquant_optimized") and model._smoothquant_optimized: # pragma: no cover + logger.info("The model is already optimized by SmoothQuant algorithm, skip it.") + return model + + cfgs, op_infos_from_cfgs, output_tensor_id_op_name = ( + model.cfgs, + model.op_infos_from_cfgs, + model.output_tensor_id_op_name, + ) + + # Update json file in ipex_config_path + cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name) + model.eval() + + # check smoothquant alpha and act_algo value + recipe_cfgs = self.quant_config.get("recipe_cfgs", None) + alpha = recipe_cfgs["smooth_quant_args"]["alpha"] + for op, _ in self.quant_config["op"].items(): + act_algo = self.quant_config["op"][op]["activation"]["algorithm"] - _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(model, example_inputs) + # Check save_qconf_summary part is a workaround for IPEX bug. + # Sometimes the prepared model from get_op_capablitiy loss this attribute. + if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): + from torch.ao.quantization.observer import MinMaxObserver - # check smoothquant folding value - recipe_cfgs = tune_cfg.get("recipe_cfgs", None) - if "smooth_quant_args" in recipe_cfgs and "folding" in recipe_cfgs["smooth_quant_args"]: - if recipe_cfgs["smooth_quant_args"]["folding"] is None: - if ipex_ver.release < Version("2.1").release: - folding = True + if ipex_ver.release >= Version("2.1.1").release: + static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping( + alpha=alpha, act_observer=MinMaxObserver + ) + else: # pragma: no cover + if act_algo == "minmax": + static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping( + alpha=alpha, act_observer=MinMaxObserver() + ) + logger.warning( + "The int8 model accuracy will be close to 0 with MinMaxobserver, " + + "the suggested IPEX version is higher or equal than 2.1.100+cpu." + ) + else: + static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=alpha) + + if isinstance(example_inputs, dict): + model = ipex.quantization.prepare( + model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace + ) else: - folding = False - else: - folding = recipe_cfgs["smooth_quant_args"]["folding"] + model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace) - # Note: we should make sure smoothquant is only executed once with inplacing fp32 model. 
- if hasattr(model, "_smoothquant_optimized") and model._smoothquant_optimized: - logger.info("The model is already optimized by SmoothQuant algorithm, skip it.") + cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, smooth_quant=True) + model.load_qconf_summary(qconf_summary=ipex_config_path) return model - sq = TorchSmoothQuant(model, dataloader=None, example_inputs=example_inputs, q_func=run_fn, record_max_info=True) - model = sq.transform( - alpha=recipe_cfgs["smooth_quant_args"]["alpha"], - folding=folding, - auto_alpha_args=recipe_cfgs["smooth_quant_args"]["auto_alpha_args"], - scale_sharing=recipe_cfgs["smooth_quant_args"]["scale_sharing"], - ) - - # Update model parameter when smoothquant folding = False - if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and not folding: - return qdq_quantize( - model, tune_cfg, run_fn, example_inputs, inplace, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, sq - ) + def convert(self, model, example_inputs, inplace=True, *args, **kwargs): + """Converts a prepared model to a quantized model. - # Update model parameter when smoothquant folding = True - if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and folding: - _apply_pre_optimization(model, tune_cfg, sq) - model.eval() + Args: + model: The prepared model to be converted. + example_inputs: Used to trace torch model. + inplace: Whether to carry out model transformations in-place. Defaults to True. - # Check save_qconf_summary part is a workaround for IPEX bug. - # Sometimes the prepared model from get_op_capablitiy loss this attribute - if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): - static_qconfig = ipex.quantization.default_static_qconfig_mapping - if isinstance(example_inputs, dict): - model = ipex.quantization.prepare( - model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace + Returns: + A quantized model. + """ + model.save_qconf_summary(qconf_summary=ipex_config_path) + model = _ipex_post_quant_process(model, example_inputs, inplace=inplace) + + with open(ipex_config_path, "r") as f: + model.tune_cfg = json.load(f) + model.ipex_config_path = ipex_config_path + dump_model_op_stats(self.quant_config["op"]) + + from neural_compressor.torch.algorithms.smooth_quant import save + + logger.info("Smooth quantization done.") + model.ori_save = model.save + model.save = MethodType(save, model) + return model + + def quantize(self, model, tune_cfg, run_fn, example_inputs, inplace=True, *args, **kwargs): + """Execute the quantize process on the specified model. + + Args: + model: a float model to be quantized. + tune_cfg: quantization config for ops. + run_fn: a calibration function for calibrating the model. + example_inputs: used to trace torch model. + inplace: whether to carry out model transformations in-place. + + Returns: + A quantized model. + """ + assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant." 
+ + cfgs, op_infos_from_cfgs, output_tensor_id_op_name = ( + model.cfgs, + model.op_infos_from_cfgs, + model.output_tensor_id_op_name, + ) + + # check smoothquant folding value + recipe_cfgs = tune_cfg.get("recipe_cfgs", None) + if "smooth_quant_args" in recipe_cfgs and "folding" in recipe_cfgs["smooth_quant_args"]: + if recipe_cfgs["smooth_quant_args"]["folding"] is None: # pragma: no cover + if ipex_ver.release < Version("2.1").release: + folding = True + else: + folding = False + else: + folding = recipe_cfgs["smooth_quant_args"]["folding"] + + # Note: we should make sure smoothquant is only executed once with inplacing fp32 model. + if hasattr(model, "_smoothquant_optimized") and model._smoothquant_optimized: # pragma: no cover + logger.info("The model is already optimized by SmoothQuant algorithm, skip it.") + return model + + sq_info = model.sq_info + + # Update model parameter when smoothquant folding = False + if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and not folding: + return qdq_quantize( + model, + tune_cfg, + run_fn, + example_inputs, + inplace, + cfgs, + op_infos_from_cfgs, + output_tensor_id_op_name, + sq_info, ) - else: - model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace) - model.load_qconf_summary(qconf_summary=ipex_config_path) - run_fn(model) - model.save_qconf_summary(qconf_summary=ipex_config_path) - model = _ipex_post_quant_process(model, example_inputs, inplace=inplace) + # Update model parameter when smoothquant folding = True + if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and folding: + _apply_pre_optimization(model, tune_cfg, sq_info) - # Recover model parameter when smoothquant folding = True - if ( - recipe_cfgs - and recipe_cfgs.get("smooth_quant", False) - and recipe_cfgs["smooth_quant_args"]["folding"] - and not inplace - ): # pragma: no cover - _apply_pre_optimization(model, tune_cfg, sq, recover=True) + # Update json file in ipex_config_path + cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name) + model.eval() - with open(ipex_config_path, "r") as f: - model.tune_cfg = json.load(f) - model.ipex_config_path = ipex_config_path - dump_model_op_stats(tune_cfg["op"]) - return model + # Check save_qconf_summary part is a workaround for IPEX bug. 
+ # Sometimes the prepared model from get_op_capablitiy loss this attribute + if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): # pragma: no cover + static_qconfig = ipex.quantization.default_static_qconfig_mapping + if isinstance(example_inputs, dict): + model = ipex.quantization.prepare( + model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace + ) + else: + model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace) + + model.load_qconf_summary(qconf_summary=ipex_config_path) + run_fn(model) + model.save_qconf_summary(qconf_summary=ipex_config_path) + model = _ipex_post_quant_process(model, example_inputs, inplace=inplace) + + # Recover model parameter when smoothquant folding = True + if ( + recipe_cfgs + and recipe_cfgs.get("smooth_quant", False) + and recipe_cfgs["smooth_quant_args"]["folding"] + and not inplace + ): # pragma: no cover + _apply_pre_optimization(model, tune_cfg, sq_info, recover=True) + + with open(ipex_config_path, "r") as f: + model.tune_cfg = json.load(f) + model.ipex_config_path = ipex_config_path + dump_model_op_stats(tune_cfg["op"]) + return model def qdq_quantize( @@ -133,12 +249,12 @@ def qdq_quantize( # Check save_qconf_summary part is a workaround for IPEX bug. # Sometimes the prepared model from get_op_capablitiy loss this attribute - if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): + if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): # pragma: no cover from torch.ao.quantization.observer import MinMaxObserver if ipex_ver.release >= Version("2.1.1").release: static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver) - else: + else: # pragma: no cover if sq_minmax_init: static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping( alpha=0.5, act_observer=MinMaxObserver() @@ -169,7 +285,7 @@ def qdq_quantize( # IPEX may raise an error on the second iteration. 
# OverflowError: cannot convert float infinity to integer run_fn(model) - except: + except: # pragma: no cover logger.warning( "The calibration failed when calibrating with ipex, " + "using scale info from SmoothQuant for Linear and " @@ -197,7 +313,7 @@ def _apply_pre_optimization(model, tune_cfg, sq, recover=False): tsq = TorchSmoothQuant(model, None) alpha = tune_cfg["recipe_cfgs"]["smooth_quant_args"]["alpha"] for op_name, info in sq_max_info.items(): - if alpha == "auto": + if alpha == "auto": # pragma: no cover alpha = info["alpha"] absorb_layer = op_name absorbed_layer = info["absorbed_layer"] @@ -237,7 +353,7 @@ def _ipex_post_quant_process(model, example_inputs, inplace=False): else: model = torch.jit.trace(model, example_inputs) model = torch.jit.freeze(model.eval()) - except: + except: # pragma: no cover if isinstance(example_inputs, dict): model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False) else: diff --git a/neural_compressor/torch/algorithms/smooth_quant/utility.py b/neural_compressor/torch/algorithms/smooth_quant/utility.py index 3448d705ea7..51af8ba43cf 100644 --- a/neural_compressor/torch/algorithms/smooth_quant/utility.py +++ b/neural_compressor/torch/algorithms/smooth_quant/utility.py @@ -28,9 +28,11 @@ TransformerBasedModelBlockPatternDetector, dump_model_op_stats, generate_activation_observer, - get_quantizable_ops_recursively, + get_quantizable_ops_from_cfgs, ipex_config_path, + parse_cfgs, simple_inference, + unify_op_type_mapping_ipex, ) from neural_compressor.torch.utils import get_ipex_version, get_torch_version, logger @@ -38,6 +40,125 @@ ipex_ver = get_ipex_version() +def get_quantizable_ops_recursively(model, example_inputs, alpha, act_algo, inplace=True): # pragma: no cover + """Get all quantizable ops from model. + + Args: + model (object): input model + example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model. + alpha (float|str): smoothquant alpha. + act_algo (str): activation algorithm, minmax or kl. + inplace (bool): whether to carry out model transformations in-place. Defaults to True. + + Returns: + quantizable_ops (list): list of tuples of op_name and op_type. 
+ cfgs (dict): dict of configuration + """ + quantizable_ops = [] + # group ops by position for transform-based model + detector = TransformerBasedModelBlockPatternDetector(model) + detect_result = detector.detect_block() + attention_block = detect_result.get("attention_blocks", None) + ffn_blocks = detect_result.get("ffn_blocks", None) + logger.info(f"Attention Blocks: {len(attention_block)}") + logger.info(f"FFN Blocks: {len(ffn_blocks)}") + if not os.path.exists(ipex_config_path): + assert isinstance(model, torch.nn.Module), "The model passed in is not the instance of torch.nn.Module" + + if hasattr(model, "save_qconf_summary"): + os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True) + model.save_qconf_summary(qconf_summary=ipex_config_path) + else: # pragma: no cover + model.eval() + + # create a quantization config file for intel pytorch extension model + os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True) + assert example_inputs is not None, "IPEX need q_dataloader or example_inputs to prepare the model" + + from torch.ao.quantization import MinMaxObserver + + if ipex_ver.release >= Version("2.1.1").release: + static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping( + alpha=alpha, act_observer=MinMaxObserver + ) + else: # pragma: no cover + if act_algo == "minmax": + static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping( + alpha=alpha, act_observer=MinMaxObserver() + ) + logger.warning( + "The int8 model accuracy will be close to 0 with MinMaxobserver, " + + "the suggested IPEX version is higher or equal than 2.1.100+cpu." + ) + else: + static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=alpha) + + if isinstance(example_inputs, dict): + model = ipex.quantization.prepare( + model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace + ) + else: + model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace) + + simple_inference(model, example_inputs, iterations=1) + model.save_qconf_summary(qconf_summary=ipex_config_path) + + map_op_name_to_fqn = {} + with open(ipex_config_path, "r") as f: + cfgs = json.load(f) + ( + ops_name, + op_infos_from_cfgs, + input_tensor_id_op_name, + output_tensor_id_op_name, + ) = parse_cfgs(cfgs) + quantizable_op_names = get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_id_op_name) + for name in quantizable_op_names: + # name : list + if len(name) == 1: + module_key = name[0][0] + op_cfg_id = name[0][2] + ipex_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"] + module_fqn = cfgs[module_key]["q_op_infos"][op_cfg_id].get("fqn", None) + + if ipex_op_type in unify_op_type_mapping_ipex: + quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type])) + map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn + else: + re_flag = False + for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items(): + if re.match(pattern, ipex_op_type): + re_flag = True + quantizable_ops.append((tuple(name), unify_op_type)) + map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn + break + if not re_flag: + quantizable_ops.append((tuple(name), ipex_op_type)) + map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn + else: # pragma: no cover + op_type = "" + for op_name in name: + module_key = op_name[0] + op_cfg_id = op_name[2] + single_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"] + if single_op_type in unify_op_type_mapping_ipex: + single_op_type = 
unify_op_type_mapping_ipex[single_op_type] + op_type += "&" + single_op_type if op_type else single_op_type + quantizable_ops.append((tuple(name), op_type)) + _module_key = name[0][0] + _op_cfg_id = name[0][2] + module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"] + map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn + + logger.debug("Map op name to fqn: ") + logger.debug(map_op_name_to_fqn) + logger.info("Attention Blocks : ") + logger.info(attention_block) + logger.info("FFN Blocks : ") + logger.info(ffn_blocks) + return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name + + def check_cfg_and_qconfig( tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name, smooth_quant=False ): # pragma: no cover @@ -539,8 +660,6 @@ def calibrate(self, calib_iter, op_types=[torch.nn.Conv2d, torch.nn.Linear]): # class GraphTrace: # pragma: no cover - """""" - def __init__(self): self.supported_torch_module_to_aten = { "Linear": "aten::linear", @@ -729,7 +848,7 @@ def remove_unsupported_layers(self, model, absorb_to_layer, no_absorb_layers): @register_autotune("version1") -class AutoAlpha: +class AutoAlpha: # pragma: no cover def __init__( self, model, @@ -1354,7 +1473,7 @@ def _auto_tune_alpha_blockwise(self): return best_alphas -class TorchSmoothQuant: +class TorchSmoothQuant: # pragma: no cover """Fake input channel quantization, for more details please refer to [1] SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models @@ -1929,7 +2048,7 @@ def _trace(self, op_types, skip_unsupported_layers=True): return absorb_to_layer, no_absorb_layers -class SQLinearWrapper(torch.nn.Module): +class SQLinearWrapper(torch.nn.Module): # pragma: no cover def __init__(self, module, input_scale, input_minmax, alpha=0.5, dtype=torch.quint8): super().__init__() self.register_buffer("input_scale", input_scale) @@ -1990,7 +2109,7 @@ def _recover_sq_linear(self): self.sq_linear.weight *= scale -class WrapperLayer(torch.nn.Module): +class WrapperLayer(torch.nn.Module): # pragma: no cover def __init__(self, layer, input_min, input_max, save_q_input=False): super(WrapperLayer, self).__init__() self.add_module("orig_layer", layer) # set orig_layer in get/set_module diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py index 842ac6d889f..7cc3ecbb41e 100644 --- a/neural_compressor/torch/algorithms/static_quant/utility.py +++ b/neural_compressor/torch/algorithms/static_quant/utility.py @@ -277,7 +277,7 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover op_infos_from_cfgs, input_tensor_id_op_name, output_tensor_id_op_name, - ) = paser_cfgs(cfgs) + ) = parse_cfgs(cfgs) quantizable_op_names = get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_id_op_name) for name in quantizable_op_names: # name : list @@ -426,7 +426,7 @@ def get_element_under_depth(d, ops_lst): ops_lst.append(d) -def paser_cfgs(cfgs): # pragma: no cover +def parse_cfgs(cfgs): # pragma: no cover """Parse configs. 
Args: diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 688db134c0e..93aa351cf08 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -235,10 +235,14 @@ def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, @register_algo(name=SMOOTH_QUANT) @torch.no_grad() def smooth_quant_entry( - model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], SmoothQuantConfig], *args, **kwargs + model: torch.nn.Module, + configs_mapping: Dict[Tuple[str, callable], SmoothQuantConfig], + mode: Mode = Mode.QUANTIZE, + *args, + **kwargs, ) -> torch.nn.Module: logger.info("Quantize model with the smooth quant algorithm.") - from neural_compressor.torch.algorithms.smooth_quant import save, smooth_quantize + from neural_compressor.torch.algorithms.smooth_quant import SmoothQuantQuantizer, TorchSmoothQuant # convert the user config into internal format quant_config_mapping = {} @@ -277,17 +281,12 @@ def smooth_quant_entry( example_inputs = kwargs.get("example_inputs", None) inplace = kwargs.get("inplace", True) assert example_inputs is not None, "Please provide example_inputs for smooth quantization." - q_model = smooth_quantize( - model=model, - tune_cfg=quant_config_mapping, - run_fn=run_fn, - example_inputs=example_inputs, - inplace=inplace, - ) - logger.info("Smooth quantization done.") - q_model.ori_save = q_model.save - q_model.save = MethodType(save, q_model) - return q_model + + quantizer = get_quantizer(model, quantizer_cls=SmoothQuantQuantizer, quant_config=quant_config_mapping) + model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace) + postprocess_model(model, mode, quantizer) + + return model ###################### AWQ Algo Entry ################################## diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index dc850b04d2c..61fac568d30 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -1012,11 +1012,17 @@ def register_supported_configs(cls) -> List[OperatorConfig]: supported_configs.append(OperatorConfig(config=linear_sq_config, operators=operators)) cls.supported_configs = supported_configs - @staticmethod - def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]: + def get_model_info(self, model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]: from neural_compressor.torch.algorithms.smooth_quant import get_quantizable_ops_recursively - model_info, _, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs) + model_info, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively( + model, example_inputs, alpha=self.alpha, act_algo=self.act_algo, inplace=True + ) + model.cfgs, model.op_infos_from_cfgs, model.output_tensor_id_op_name = ( + cfgs, + op_infos_from_cfgs, + output_tensor_id_op_name, + ) return model_info @classmethod diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py index 47f1e89667b..8404befdc6f 100644 --- a/neural_compressor/torch/quantization/quantize.py +++ b/neural_compressor/torch/quantization/quantize.py @@ -67,6 +67,20 @@ def quantize( if is_ipex_available and ( isinstance(quant_config, StaticQuantConfig) or isinstance(quant_config, 
SmoothQuantConfig) ): + if isinstance(quant_config, SmoothQuantConfig): + from neural_compressor.torch.algorithms.smooth_quant import TorchSmoothQuant + + sq = TorchSmoothQuant( + model, dataloader=None, example_inputs=example_inputs, q_func=run_fn, record_max_info=True + ) + model.sq_info = sq + model = sq.transform( + alpha=quant_config.alpha, + folding=quant_config.folding, + auto_alpha_args=quant_config.auto_alpha_args, + scale_sharing=quant_config.scale_sharing, + ) + model_info = quant_config.get_model_info(q_model, example_inputs) else: model_info = quant_config.get_model_info(model=q_model) diff --git a/test/3x/torch/quantization/test_smooth_quant.py b/test/3x/torch/quantization/test_smooth_quant.py index 2fbe59221c5..7d8b1730ff1 100644 --- a/test/3x/torch/quantization/test_smooth_quant.py +++ b/test/3x/torch/quantization/test_smooth_quant.py @@ -4,7 +4,7 @@ import pytest import torch -from neural_compressor.torch.quantization import SmoothQuantConfig, get_default_sq_config, quantize +from neural_compressor.torch.quantization import SmoothQuantConfig, convert, get_default_sq_config, prepare, quantize from neural_compressor.torch.utils import is_ipex_available if is_ipex_available(): @@ -41,22 +41,16 @@ def test_smooth_quant_default(self): fp32_model = copy.deepcopy(model) quant_config = get_default_sq_config() example_inputs = torch.randn([1, 3]) - q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) + prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) + run_fn(prepared_model) + q_model = convert(prepared_model) assert q_model is not None, "Quantization failed!" - @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") - def test_smooth_quant_auto(self): fp32_model = copy.deepcopy(model) - auto_alpha_args = { - "alpha_min": 0.45, - "alpha_max": 0.55, - "alpha_step": 0.01, - "shared_criterion": "mean", - "do_blockwise": True, - } - quant_config = SmoothQuantConfig(alpha="auto", auto_alpha_args=auto_alpha_args, folding=False) - example_inputs = torch.randn([1, 3]) - q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) + example_dict = {"x": example_inputs} + prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_dict) + run_fn(prepared_model) + q_model = convert(prepared_model) assert q_model is not None, "Quantization failed!" @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") @@ -66,7 +60,9 @@ def test_smooth_quant_fallback(self): example_inputs = torch.randn([1, 3]) # fallback by op_type quant_config.set_local(torch.nn.Linear, SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32")) - q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) + prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) + run_fn(prepared_model) + q_model = convert(prepared_model) assert q_model is not None, "Quantization failed!" 
for op, op_info in q_model.tune_cfg[" "]["q_op_infos"].items(): @@ -96,7 +92,9 @@ def test_sq_linear_params(self, act_sym, act_algo, alpha, folding, scale_sharing def run_fn(model): model(example_inputs) - q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) + prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) + run_fn(prepared_model) + q_model = convert(prepared_model) assert q_model is not None, "Quantization failed!" output1 = fp32_model(example_inputs) output2 = q_model(example_inputs) @@ -104,12 +102,10 @@ def run_fn(model): @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") def test_sq_ipex_accuracy(self): - from intel_extension_for_pytorch.quantization import convert, prepare - example_inputs = torch.zeros([1, 3]) qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5) user_model = copy.deepcopy(model) - user_model = prepare(user_model.eval(), qconfig, example_inputs=example_inputs, inplace=True) + user_model = ipex.quantization.prepare(user_model.eval(), qconfig, example_inputs=example_inputs, inplace=True) def run_fn(model): model(example_inputs) @@ -117,7 +113,7 @@ def run_fn(model): run_fn(user_model) user_model.save_qconf_summary(qconf_summary="ipex.json") with torch.no_grad(): - user_model = convert(user_model.eval(), inplace=True).eval() + user_model = ipex.quantization.convert(user_model.eval(), inplace=True).eval() user_model(example_inputs) user_model = torch.jit.trace(user_model.eval(), example_inputs, strict=False) user_model = torch.jit.freeze(user_model.eval()) @@ -127,7 +123,9 @@ def run_fn(model): fp32_model = copy.deepcopy(model) quant_config = get_default_sq_config() - q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) + prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) + run_fn(prepared_model) + q_model = convert(prepared_model) assert q_model is not None, "Quantization failed!" q_model.save("saved_results") @@ -147,7 +145,9 @@ def test_sq_save_load(self): fp32_model = copy.deepcopy(model) quant_config = get_default_sq_config() example_inputs = torch.zeros([1, 3]) - q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) + prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) + run_fn(prepared_model) + q_model = convert(prepared_model) assert q_model is not None, "Quantization failed!" q_model.save("saved_results") inc_out = q_model(example_inputs) @@ -162,6 +162,20 @@ def test_sq_save_load(self): assert torch.allclose(inc_out, loaded_out, atol=2e-02), "Unexpected result. Please double check." # compare saved json file + fp32_model = copy.deepcopy(model) loaded_model = recover_model_from_json(fp32_model, "saved_results/qconfig.json", example_inputs=example_inputs) loaded_out = loaded_model(example_inputs) assert torch.allclose(inc_out, loaded_out, atol=1e-05), "Unexpected result. Please double check." + + @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + def test_smooth_quant_with_quantize_API(self): + fp32_model = copy.deepcopy(model) + quant_config = get_default_sq_config() + example_inputs = torch.randn([1, 3]) + q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) + assert q_model is not None, "Quantization failed!" 
+ + fp32_model = copy.deepcopy(model) + example_dict = {"x": example_inputs} + q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_dict) + assert q_model is not None, "Quantization failed!" diff --git a/test/3x/torch/requirements.txt b/test/3x/torch/requirements.txt index f26aa3cec64..454ce56efd2 100644 --- a/test/3x/torch/requirements.txt +++ b/test/3x/torch/requirements.txt @@ -1,4 +1,5 @@ numpy +peft==0.10.0 prettytable psutil pytest
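
Below is a minimal, self-contained usage sketch of the prepare/convert SmoothQuant flow this patch introduces, assembled from the updated example script and tests above. The ToyModel and run_fn calibration function are illustrative stand-ins, not part of the PR; the quantization API names (get_default_sq_config, prepare, convert, save) are taken directly from the diff, and IPEX >= 2.1 is required per the assertion added in SmoothQuantQuantizer.

# Usage sketch (illustrative only): SmoothQuant via the 3.x prepare/convert API
# refactored in this patch. ToyModel and run_fn below are hypothetical stand-ins;
# the neural_compressor API calls mirror the updated tests/example in the diff.
import torch

from neural_compressor.torch.quantization import convert, get_default_sq_config, prepare


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.fc(x)


fp32_model = ToyModel().eval()
example_inputs = torch.randn([1, 3])
quant_config = get_default_sq_config()


def run_fn(model):
    # Calibration: feed representative data through the prepared model.
    model(example_inputs)


# New flow: prepare -> calibrate -> convert. The one-shot quantize() path is
# still supported, as covered by test_smooth_quant_with_quantize_API above.
prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(prepared_model)
q_model = convert(prepared_model)
q_model.save("saved_results")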