From a5e5f5f64855b85e2a374c8b808b317448318113 Mon Sep 17 00:00:00 2001 From: Zixuan Cheng <110808245+violetch24@users.noreply.github.com> Date: Thu, 29 Feb 2024 17:03:38 +0800 Subject: [PATCH] Migrate SmoothQuant for IPEX to 3.x API (#1629) Signed-off-by: Cheng, Zixuan Signed-off-by: Lu, Yintong --- .../torch/algorithms/smooth_quant/__init__.py | 17 + .../algorithms/smooth_quant/smooth_quant.py | 249 ++ .../algorithms/smooth_quant/smoothquant.py | 1596 ----------- .../torch/algorithms/smooth_quant/utility.py | 2341 +++++++++++++++++ .../algorithms/static_quant/static_quant.py | 4 +- .../torch/algorithms/static_quant/utility.py | 727 +++-- .../torch/quantization/algorithm_entry.py | 62 +- .../torch/quantization/config.py | 23 +- requirements_pt.txt | 3 + .../torch/quantization/test_smooth_quant.py | 84 + .../torch/quantization/test_static_quant.py | 33 +- 11 files changed, 3157 insertions(+), 1982 deletions(-) create mode 100644 neural_compressor/torch/algorithms/smooth_quant/__init__.py create mode 100644 neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py delete mode 100644 neural_compressor/torch/algorithms/smooth_quant/smoothquant.py create mode 100644 neural_compressor/torch/algorithms/smooth_quant/utility.py create mode 100644 test/3x/torch/quantization/test_smooth_quant.py diff --git a/neural_compressor/torch/algorithms/smooth_quant/__init__.py b/neural_compressor/torch/algorithms/smooth_quant/__init__.py new file mode 100644 index 00000000000..921e0a6cda9 --- /dev/null +++ b/neural_compressor/torch/algorithms/smooth_quant/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 Intel Corporation + +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .utility import * +from .smooth_quant import smooth_quantize diff --git a/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py b/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py new file mode 100644 index 00000000000..30fdb8f532e --- /dev/null +++ b/neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import torch + +try: + import intel_extension_for_pytorch as ipex +except: + assert False, "Please install IPEX for smooth quantization." 
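# A minimal usage sketch for the smooth_quantize() entry point defined below,
# assuming IPEX >= 2.1 is installed. `_example_usage`, the toy model and
# `calib_fn` are hypothetical; the tune_cfg only spells out the recipe keys that
# smooth_quantize() reads directly, while the real 3.x flow also attaches the
# per-op entries consumed by cfg_to_qconfig().
def _example_usage():
    import torch

    fp32_model = torch.nn.Sequential(torch.nn.LayerNorm(32), torch.nn.Linear(32, 32))
    example_inputs = torch.randn(4, 32)

    def calib_fn(model):
        # calibration: a few forward passes so observers can record min/max
        for _ in range(8):
            model(example_inputs)

    tune_cfg = {
        "recipe_cfgs": {
            "smooth_quant": True,
            "smooth_quant_args": {
                "alpha": 0.5,        # or "auto" for per-layer alpha search
                "folding": False,    # read unconditionally below, so keep the key present
                "scale_sharing": False,
                "auto_alpha_args": {
                    "alpha_min": 0.0,
                    "alpha_max": 1.0,
                    "alpha_step": 0.1,
                    "shared_criterion": "mean",
                    "do_blockwise": False,
                },
            },
        },
        "act_algo": "minmax",
    }
    # smooth_quantize is defined below in this module and re-exported from
    # neural_compressor.torch.algorithms.smooth_quant.__init__
    return smooth_quantize(fp32_model, tune_cfg, calib_fn, example_inputs, inplace=True)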
+ +from packaging.version import Version + +from .utility import ( + TorchSmoothQuant, + cfg_to_qconfig, + dump_model_op_stats, + get_ipex_version, + get_quantizable_ops_recursively, + ipex_config_path, + logger, + simple_inference, + update_sq_scale, +) + +ipex_ver = get_ipex_version() + + +def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True): + """Execute the quantize process on the specified model. + + Args: + model: a float model to be quantized. + tune_cfg: quantization config for ops. + run_fn: a calibration function for calibrating the model. + example_inputs: used to trace torch model. + inplace: whether to carry out model transformations in-place. + + Returns: + A quantized model. + """ + assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant." + + _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(model, example_inputs) + + # check smoothquant folding value + recipe_cfgs = tune_cfg.get("recipe_cfgs", None) + if "smooth_quant_args" in recipe_cfgs and "folding" in recipe_cfgs["smooth_quant_args"]: + if recipe_cfgs["smooth_quant_args"]["folding"] is None: + if ipex_ver.release < Version("2.1").release: + folding = True + else: + folding = False + else: + folding = recipe_cfgs["smooth_quant_args"]["folding"] + + # Note: we should make sure smoothquant is only executed once with inplacing fp32 model. + if hasattr(model, "_smoothquant_optimized") and model._smoothquant_optimized: + logger.info("The model is already optimized by SmoothQuant algorithm, skip it.") + return model + + sq = TorchSmoothQuant(model, dataloader=None, example_inputs=example_inputs, q_func=run_fn, record_max_info=True) + model = sq.transform( + alpha=recipe_cfgs["smooth_quant_args"]["alpha"], + folding=folding, + auto_alpha_args=recipe_cfgs["smooth_quant_args"]["auto_alpha_args"], + scale_sharing=recipe_cfgs["smooth_quant_args"]["scale_sharing"], + ) + + # Update model parameter when smoothquant folding = False + if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and not folding: + return qdq_quantize( + model, tune_cfg, run_fn, example_inputs, inplace, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, sq + ) + + # Update model parameter when smoothquant folding = True + if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and folding: + _apply_pre_optimization(model, tune_cfg, sq) + model.eval() + + # Check save_qconf_summary part is a workaround for IPEX bug. 
+ # Sometimes the prepared model from get_op_capablitiy loss this attribute + if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): + static_qconfig = ipex.quantization.default_static_qconfig_mapping + if isinstance(example_inputs, dict): + model = ipex.quantization.prepare( + model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace + ) + else: + model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace) + + model.load_qconf_summary(qconf_summary=ipex_config_path) + run_fn(model) + model.save_qconf_summary(qconf_summary=ipex_config_path) + model = _ipex_post_quant_process(model, example_inputs, inplace=inplace) + + # Recover model parameter when smoothquant folding = True + if ( + recipe_cfgs + and recipe_cfgs.get("smooth_quant", False) + and recipe_cfgs["smooth_quant_args"]["folding"] + and not inplace + ): # pragma: no cover + _apply_pre_optimization(model, tune_cfg, sq, recover=True) + + with open(ipex_config_path, "r") as f: + model.tune_cfg = json.load(f) + model.ipex_config_path = ipex_config_path + dump_model_op_stats(tune_cfg) + return model + + +def qdq_quantize( + model, tune_cfg, run_fn, example_inputs, inplace, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, sq +): + smoothquant_scale_info = sq.sq_scale_info + sq_minmax_init = True if tune_cfg.get("act_algo", "kl") == "minmax" else False + + # Check save_qconf_summary part is a workaround for IPEX bug. + # Sometimes the prepared model from get_op_capablitiy loss this attribute + if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): + from torch.ao.quantization.observer import MinMaxObserver + + if ipex_ver.release >= Version("2.1.1").release: + static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver) + else: + if sq_minmax_init: + static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping( + alpha=0.5, act_observer=MinMaxObserver() + ) + logger.warning( + "The int8 model accuracy will be close to 0 with MinMaxobserver, " + + "the suggested IPEX version is higher or equal than 2.1.100+cpu." + ) + else: + static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5) + if isinstance(example_inputs, dict): + model = ipex.quantization.prepare( + model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace + ) + else: + model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace) + + # The IPEX SmoothQuant observer can only use save/load_qconf_summary once. + # The save_qconf_summary API will freeze the scale used in model and calibration won't work anymore. + # The load_qconf_summary will overwrite the scales used in model but only work in the first call. + # Here, we use INC collected scale for Linear and set normal observer instead of SQObserver \ + # to make sure calibration works for other ops, like add, bmm. + cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, smooth_quant=True) + update_sq_scale(ipex_config_path, smoothquant_scale_info) + model.load_qconf_summary(qconf_summary=ipex_config_path) + # real calibration for other operators + try: + # IPEX may raise an error on the second iteration. 
+ # OverflowError: cannot convert float infinity to integer + run_fn(model) + except: + logger.warning( + "The calibration failed when calibrating with ipex, " + + "using scale info from SmoothQuant for Linear and " + + "one iter calibration for other ops." + ) + + if ipex_ver.release > Version("2.1.0").release: + update_sq_scale(ipex_config_path, smoothquant_scale_info) + model.load_qconf_summary(qconf_summary=ipex_config_path) + _ipex_post_quant_process(model, example_inputs, inplace=inplace) + + with open(ipex_config_path, "r") as f: + model.tune_cfg = json.load(f) + model.ipex_config_path = ipex_config_path + dump_model_op_stats(tune_cfg) + return model + + +def _apply_pre_optimization(model, tune_cfg, sq, recover=False): + sq_max_info = {} + if sq.record_max_info: + sq_max_info = sq.max_value_info + if sq_max_info: + tsq = TorchSmoothQuant(model, None) + alpha = tune_cfg["recipe_cfgs"]["smooth_quant_args"]["alpha"] + for op_name, info in sq_max_info.items(): + if alpha == "auto": + alpha = info["alpha"] + absorb_layer = op_name + absorbed_layer = info["absorbed_layer"] + input_minmax = info["input_minmax"] + weight_max = info["weight_max"] + if sq.weight_clip: + weight_max = weight_max.clamp(min=1e-5) + abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1])) + input_power = torch.pow(abs_input_max, alpha) + weight_power = torch.pow(weight_max, 1 - alpha) + scale = torch.clip(input_power / weight_power, min=1e-5) + with torch.no_grad(): + if recover: + scale = 1.0 / scale + for layer in absorbed_layer: + tsq._scale_layer_weight(layer, scale) + tsq._absorb_scales(absorb_layer, 1.0 / scale) + logger.debug(f"Current smoothquant scale of {op_name} is {scale}, alpha is {alpha}") + + +def _ipex_post_quant_process(model, example_inputs, inplace=False): + """Convert to a jit model. + + Args: + model: a prepared model. + example_inputs: used to trace torch model. + inplace: whether to carry out model transformations in-place. + + Returns: + A converted jit model. + """ + model = ipex.quantization.convert(model, inplace=inplace) + with torch.no_grad(): + try: + if isinstance(example_inputs, dict): + model = torch.jit.trace(model, example_kwarg_inputs=example_inputs) + else: + model = torch.jit.trace(model, example_inputs) + model = torch.jit.freeze(model.eval()) + except: + if isinstance(example_inputs, dict): + model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False) + else: + model = torch.jit.trace(model, example_inputs, strict=False) + model = torch.jit.freeze(model.eval()) + # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile + # At the 2nd run, the llga pass will be triggered and the model is turned into + # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph + simple_inference(model, example_inputs, iterations=2) + return model diff --git a/neural_compressor/torch/algorithms/smooth_quant/smoothquant.py b/neural_compressor/torch/algorithms/smooth_quant/smoothquant.py deleted file mode 100644 index 9de5bbb40f9..00000000000 --- a/neural_compressor/torch/algorithms/smooth_quant/smoothquant.py +++ /dev/null @@ -1,1596 +0,0 @@ -# -# -*- coding: utf-8 -*- -# -# Copyright (c) 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import copy -import json -import logging - -import torch - -logger = logging.getLogger() -from collections import UserDict, defaultdict - -import numpy -from tqdm import tqdm - - -def enough_memo_store_scale(device, need_space): - if device == "cuda": # pragma: no cover - current_gpu_index = torch.cuda.current_device() - total_memory = torch.cuda.get_device_properties(current_gpu_index).total_memory - used_memory = torch.cuda.memory_allocated(current_gpu_index) - free_space = total_memory - used_memory - else: - import psutil - - free_space = psutil.virtual_memory().free - return free_space >= need_space - - -def move_input_to_device(input, device=torch.device("cpu")): - if isinstance(input, dict) or isinstance(input, UserDict): - tmp_input = {} - for k, inp in input.items(): - tmp_input[k] = move_input_to_device(inp, device) - input = tmp_input - elif isinstance(input, list) or isinstance(input, tuple): - is_tuple = isinstance(input, tuple) - tmp_input = [] - for inp in input: - tmp_input.append(move_input_to_device(inp, device)) - input = tuple(tmp_input) if is_tuple else tmp_input - elif isinstance(input, torch.Tensor): - input = input.to(device) # pylint: disable=no-member - return input - - -##TODO potential bug, data typeR -def forward_wrapper(model, input, device=torch.device("cpu")): - try: - model = model.to(device) - input = move_input_to_device(input, device) - except Exception as e: - logger.warning(e) - logger.warning("Please check the input device if the error raised.") - if isinstance(input, dict) or isinstance(input, UserDict): - output = model(**input) - elif isinstance(input, list) or isinstance(input, tuple): - try: - output = model(*input) - except: - output = model(input) - else: - output = model(input) - return output - - -def model_forward(model, dataloader, iters, device): - try: - cnt = 0 - for idx, (input, label) in enumerate(dataloader): - output = forward_wrapper(model, input, device) - cnt += 1 - if iters != -1 and cnt >= iters: - break - except Exception as e: - cnt = 0 - for idx, input in enumerate(dataloader): - output = forward_wrapper(model, input, device) - cnt += 1 - if iters != -1 and cnt >= iters: - break - - -def model_forward_per_sample(model, sample, device): - try: - output = forward_wrapper(model, sample, device) - return output - - except Exception as e: - output = forward_wrapper(model, sample[0], device) - return output - - -def quant_dequant_w(m, num_bits=8, scheme="sym"): - eps = torch.finfo(torch.float32).eps - if isinstance(m, torch.nn.Linear): - x = m.weight - tmp = torch.zeros(torch.max(x, dim=1).values.size()) - if scheme == "sym": - q_min, q_max = -(2.0 ** (num_bits - 1)), 2.0 ** (num_bits - 1) - 1.0 - x_max = torch.max(torch.abs(x), dim=1).values - scale = x_max / (float(q_max - q_min) / 2) - else: - q_min, q_max = 0, 2.0**num_bits - 1.0 - x_max = torch.maximum(torch.max(x, dim=1).values, tmp) - x_min = torch.minimum(torch.min(x, dim=1).values, tmp) - scale = (x_max - x_min) / (2**num_bits - 1) - - scale = torch.clip(scale, min=eps) - - if scheme == "sym": - bias = 0 - else: - bias = torch.round(0 - (torch.min(x, 
dim=1).values) / scale) - bias = bias.unsqueeze(dim=-1) - scale = scale.unsqueeze(dim=-1) - q_x = torch.round(x / scale + bias) - q_x.clamp_(q_min, q_max) - return (q_x - bias) * scale - elif isinstance(m, torch.nn.Conv2d): - x = m.weight - x = torch.permute(x, (0, 2, 3, 1)) - x = x.reshape(-1, x.shape[-1]) - tmp = torch.zeros(torch.max(x, dim=0).values.size()) - if scheme == "sym": - q_min, q_max = -(2.0 ** (num_bits - 1)), 2.0 ** (num_bits - 1) - 1.0 - x_max = torch.max(torch.abs(x), dim=0).values - scale = x_max / (2 ** (num_bits - 1) - 1) - else: - q_min, q_max = 0, 2.0**num_bits - 1.0 - x_max = torch.maximum(torch.max(x, dim=0).values, tmp) - x_min = torch.minimum(torch.min(x, dim=0).values, tmp) - scale = (x_max - x_min) / (2**num_bits - 1) - scale = torch.clip(scale, min=eps) - if scheme == "sym": - bias = 0 - else: - bias = torch.round(0 - (torch.min(x, dim=0).values) / scale) - bias = bias.unsqueeze(dim=0) - scale = scale.unsqueeze(dim=0) - - q_x = x / scale + bias - q_x.clamp_(q_min, q_max).round_() - q_dq_x = (q_x - bias) * scale - q_dq_x = q_dq_x.view(m.weight.shape[0], m.weight.shape[2], m.weight.shape[3], m.weight.shape[1]) - q_dq_x = torch.permute(q_dq_x, (0, 3, 1, 2)) - return q_dq_x - else: - logger.warning("unsupported layer type, please have a check") - - -def quant_dequant_x(x, min_x=None, max_x=None, num_bits=8): - eps = torch.finfo(torch.float32).eps - q_min, q_max = 0, 2.0**num_bits - 1.0 - if max_x is None or min_x is None: - max_x, min_x = torch.max(x), torch.min(x) - else: - max_x = torch.max(max_x) - min_x = torch.min(min_x) - scale = (max_x - min_x) / (2**num_bits - 1) - scale = torch.clip(scale, min=eps) - bias = torch.round((0 - min_x) / scale) - q_x = torch.round(x / scale + bias) - q_x.clamp_(q_min, q_max) - return scale * (q_x - bias) - - -def get_module(model, key): - """Get module from model by key name. - - Args: - model (torch.nn.Module): original model - key (str): module name to be replaced - """ - module = model - name_list = key.split(".") - for name in name_list: - if hasattr(module, name): - module = getattr(module, name) - elif hasattr(module, "sq_linear"): # for peft models - module = getattr(module, "sq_linear") - module = getattr(module, name) - elif hasattr(module, "orig_layer"): # for peft models and auto alpha - module = getattr(module, "orig_layer") - module = getattr(module, name) - else: - module = module - return module - - -def set_module(model, key, new_module): - """Set new module into model by key name. 
- - Args: - model (torch.nn.Module): original model - key (str): module name to be replaced - new_module (torch.nn.Module): new module to be inserted - """ - module = model - name_list = key.split(".") - for name in name_list[:-1]: - if hasattr(module, name): - module = getattr(module, name) - elif hasattr(module, ("sq_linear")): # for peft models that Linears are contained in Linear - module = getattr(module, "sq_linear") - module = getattr(module, name) - elif hasattr(module, ("orig_layer")): # for peft models and auto alpha - module = getattr(module, "orig_layer") - module = getattr(module, name) - else: - module = module - - if hasattr(module, "sq_linear") and name_list[-1] != "sq_linear": # for peft models - module = getattr(module, "sq_linear") - if hasattr(module, "orig_layer") and name_list[-1] != "orig_layer": # for peft models and auto alpha - module = getattr(module, "orig_layer") - setattr(module, name_list[-1], new_module) - - -def cal_scale(input_max, weights, alpha, scale_type="orig"): - if scale_type == "orig": # same as the paper - weights = torch.cat(weights, dim=0) - weight_max = torch.max(torch.abs(weights), dim=0)[0] - input_power = torch.pow(input_max, alpha) - logger.debug(f"{max(input_max)}, {min(input_max)}") - weight_power = torch.pow(weight_max, 1 - alpha) - scale = torch.clip(input_power / weight_power, min=1e-5) - scale[input_power == 0] = 1.0 - if input_power.size() == weight_power.size(): - scale[weight_power == 0] = 0.0 ##FIXME - return scale - - -class WrapperLayer(torch.nn.Module): - def __init__(self, layer, input_min, input_max, save_q_input=False): - super(WrapperLayer, self).__init__() - self.add_module("orig_layer", layer) # set orig_layer in get/set_module - self.quant = False - self.q_input = None - self.fp32_output = None - self.input_max = input_max - self.input_min = input_min - self.weight_scale = None - self.input_scale = None - self.save_q_input = save_q_input - self.do_blockwise = False - - def enable_quant(self): - self.quant = True - - def disable_quant(self): - self.quant = False - - def update_scale(self, input_scale, weight_scale): - self.input_scale = input_scale - self.weight_scale = weight_scale - - ##TODO better tradeoff performance and memory, currently it's too slow - def q_dq_forward(self, x, input_scale, weight_scale): - layer_copy = copy.deepcopy(self.orig_layer) - if weight_scale is not None: - layer_copy.weight *= weight_scale - q_dq_weight = quant_dequant_w(layer_copy) - layer_copy.weight.data.copy_(q_dq_weight) - if input_scale is None: - x = quant_dequant_x(x, self.input_min, self.input_max) - else: - x = input_scale * x - x = quant_dequant_x(x, self.input_min * input_scale, self.input_max * input_scale) ##FIXME - output = layer_copy(x) - return output - - def q_dq_forward_blockwise(self, x, input_scale): - layer_copy = copy.deepcopy(self.orig_layer) - if input_scale is None: - x = quant_dequant_x(x, self.input_min, self.input_max) - else: - x = input_scale * x - x = quant_dequant_x(x, self.input_min * input_scale, self.input_max * input_scale) ##FIXME - output = layer_copy(x) - return output - - def forward(self, x): - if self.quant: - # self.q_input = x * scale ##save the q_input - if self.save_q_input: - self.q_input = x - if not self.do_blockwise: - output = self.q_dq_forward(x, self.input_scale, self.weight_scale) - else: - output = self.q_dq_forward_blockwise(x, self.input_scale) - - else: - output = self.orig_layer(x) - self.output = output - return output - - -class TorchSmoothQuant: - """Fake input channel 
quantization, for more details please refer to - [1] SmoothQuant: Accurate and Efficient - Post-Training Quantization for Large Language Models - [2] SPIQ: Data-Free Per-Channel Static Input Quantization - Currently, we only handle the layers whose smooth scale could be absorbed, we will support other layers later. - - We only support inplace mode which means the model weights will be changed, you can call recover function - to recover the weights if needed - """ - - def __init__(self, model, dataloader=None, example_inputs=None, q_func=None, traced_model=None): - """ - :param model: Torch model :param dataloader: Calibration dataloader :param traced_model: A specific model - shares the same architecture as the model and could be traced by torch.jit. If not supplied, we use model - instead. - """ - self.model = model - if not isinstance(self.model, torch.nn.Module): - return - device, dtype = self._get_device() - self.model = self.model.to(device) - self.model.eval() - self.device = device - self.dtype = dtype - self.dataloader = dataloader - self.example_inputs = example_inputs - self.q_func = q_func - self.input_maxes = {} - self.input_mins = {} - self.input_maxes_abs = {} - self.traced_model = traced_model - if self.traced_model is None: - self.traced_model = self.model - self.weight_scale_info = {} - self.absorb_scales_info = {} - self.insert_mul = False - self.allow_absorb = True - self.record_max_info = False - self.max_value_info = {} # to record max values for alpha tune - self.self_absorb_layers = {} - self.absorb_to_layer = {} - self.adjust_alpha_space = False - self.weight_clip = True - self.default_alpha = 0.5 - - self._save_scale = False - self.weight_scale_dict = {} - - self.do_blockwise = False - self.block_inputs = {} - self.block_outputs = {} - - def _get_device(self): - """Get the model device - :return:Model device.""" - for _, p in self.model.named_parameters(): - return p.data.device, p.data.dtype - - def _save_input_pc_hook(self, name, percentile=100): - """A forward hook to save input max of a module - :param name: the module name - :return: A hook function.""" - - def save_input_hook(module, inputs, outputs): - input = inputs[0] - ##TODO check input channel is correct - if len(module.weight.shape) == 4: ##conv3d or conv1d not supported now, need better way - input = input.permute(0, 2, 3, 1) - input = input.reshape(-1, input.shape[-1]) - max_tensor = torch.max(input, dim=0)[0] - min_tensor = torch.min(input, dim=0)[0] - k_index = int(input.shape[0] * percentile / 100) - res, _ = torch.kthvalue(torch.abs(input), k_index, dim=0) - ##res = torch.max(torch.abs(input),dim=0)[0] - if name not in self.input_maxes.keys(): - self.input_mins[name], self.input_maxes[name] = min_tensor, max_tensor - self.input_maxes_abs[name] = res - else: - self.input_mins[name] = torch.min(self.input_mins[name], min_tensor) - self.input_maxes[name] = torch.max(self.input_maxes[name], max_tensor) - self.input_maxes_abs[name] = torch.max(self.input_maxes_abs[name], res) - - return save_input_hook - - def _add_min_max_observer(self, modules, percentile=100): - """ - :param modules: the modules which the observer will insert to - :return: - """ - self.hook_handles = [] - for key in modules.keys(): - hook_func = self._save_input_pc_hook(key, percentile) - hook_handle = modules[key].register_forward_hook(hook_func) - self.hook_handles.append(hook_handle) - - def _remove_observer(self): - """Remove the observer from the model - :return:""" - for hook_handle in self.hook_handles: - 
hook_handle.remove() - - def _calibrate(self, absorb_to_layer, calib_iter, percentile): - """ - :param absorb_to_layer: A dict,key is the absorb layer, val is a list of the to be smoothed layer - :param calib_iter: Data size for calibration - :return: A dict that saved the layer name and the channel-wise max value info - """ - ##hook all the module - hook_modules = {} - for n, module in self.model.named_modules(): - if isinstance(module, tuple(self.op_types)): - hook_modules[n] = module - - self._add_min_max_observer(hook_modules, percentile) - - self._dump_min_max(calib_iter=calib_iter) - self._remove_observer() - return self.input_maxes_abs - - def _dump_min_max(self, calib_iter=100): - """Dump min max per channel information, the min max value will be saved in input_maxes attribute - :param calibration_method: only support min_max currently - :param calib_iter: Sample size for calibration - :return:""" - logger.info("Calibrating...") - if self.q_func: - self.q_func(self.model) - else: - assert self.dataloader, "Please set dataloader for calibration." - model_forward(self.model, self.dataloader, calib_iter, self.device) - - def _reshape_in_channel_to_last(self, layer_name): - """Move the input channel to the last dim - :param layer_name: Layer name - :return: The reshaped weight.""" - layer = get_module(self.model, layer_name) - if layer.__class__.__name__ == "WrapperLayer": - layer = layer.orig_layer - - weight = layer.weight ##TODO oc*ic, support transposed conv - if len(weight.shape) == 4: - weight = weight.permute(0, 2, 3, 1) - weight = weight.reshape(-1, weight.shape[-1]) - return weight - - def _reshape_scale_for_weight(self, layer, scale): - """Reshape the scale for weight input channel, depthwise output channel - :param layer: torch module - :param scale: orig scale - :return: reshaped scale.""" - if hasattr(layer, "orig_layer"): - layer = layer.orig_layer - if isinstance(layer, torch.nn.Conv2d) and layer.groups > 1: ##only depthwise conv could hit here - scale = scale.view(scale.shape[0], 1, 1, 1) ##mount on output channel - - elif isinstance(layer, torch.nn.Conv2d): - scale = scale.view(1, scale.shape[0], 1, 1) - - elif isinstance(layer, torch.nn.Linear): - scale = scale.view(1, scale.shape[0]) - - return scale - - def get_blocks(self): - block_names = [] - for n, m in self.model.named_modules(): - if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__: - for nn, mm in m.named_children(): - block_name = n + "." 
+ nn - block_names.append(block_name) - return block_names - - def _reshape_scale_for_input(self, layer, scale): - """Reshape the scale for input feature in channel - :param layer: - - :param scale: - :return: - """ - if hasattr(layer, "orig_layer"): - layer = layer.orig_layer - if isinstance(layer, torch.nn.Conv2d): - scale = scale.view(1, scale.shape[0], 1, 1) - - elif isinstance(layer, torch.nn.Linear): - scale = scale.view(1, scale.shape[0]) - - return scale - - def _scale_layer_weight(self, layer_name, scale, alpha=0.5, input_minmax=None): ##input channel - """Scale the layer weights at input channel, depthwise conv output channel - :param layer_name: The layer name - :param scale: The scale to be multiplied - :param alpha: alpha for SQLinearWrapper - :param input_minmax: input_minmax for SQLinearWrapper - :return:""" - layer = get_module(self.model, layer_name) - if self.insert_mul: - from .model_wrapper import SQLinearWrapper - - layer = get_module(self.model, layer_name) - if isinstance(layer, SQLinearWrapper): - layer._recover_sq_linear() - set_module(self.model, layer_name, layer.sq_linear) ##recover - else: - new_module = SQLinearWrapper(layer, 1.0 / scale, input_minmax, alpha) - set_module(self.model, layer_name, new_module) - elif self.allow_absorb: - scale = self._reshape_scale_for_weight(layer, scale) - layer.weight = torch.nn.Parameter(layer.weight * scale) - return scale - - def _absorb_scales(self, layer_name, scale): ##output channel - """Absorb the scale to the layer at output channel - :param layer_name: The module name - :param scale: The scale to be absorbed - :param alpha_key: The alpha passed to SQLinearWrapper - :return:""" - if self.insert_mul or not self.allow_absorb: - return # absorb is updated in SQLinearWrapper in def _scale_layer_weight - - ##if self.allow absorb - layer = get_module(self.model, layer_name) - if layer.__class__.__name__ == "WrapperLayer": - layer = layer.orig_layer - if ( - isinstance(layer, torch.nn.BatchNorm2d) - or isinstance(layer, torch.nn.GroupNorm) - or isinstance(layer, torch.nn.InstanceNorm2d) - ): - if layer.affine: - layer.weight *= scale - layer.bias *= scale - else: - layer.affine = True - weight = torch.ones(layer.num_features, device=self.device, dtype=self.dtype) * scale - layer.weight = torch.nn.Parameter(weight, requires_grad=False) - bias = torch.zeros(layer.num_features, device=self.device, dtype=self.dtype) - layer.bias = torch.nn.Parameter(bias, requires_grad=False) - elif isinstance(layer, torch.nn.LayerNorm): - if layer.elementwise_affine: - layer.weight *= scale - layer.bias *= scale - else: - layer.elementwise_affine = True - weight = torch.ones(layer.num_features, device=self.device, dtype=self.dtype) * scale - layer.weight = torch.nn.Parameter(torch.ones(weight, requires_grad=False)) - bias = torch.zeros(layer.num_features, device=self.device, dtype=self.dtype) - layer.bias = torch.nn.Parameter(bias, requires_grad=False) - - elif isinstance(layer, torch.nn.Conv2d): - ##the order could not be changed - if hasattr(layer, "bias") and (layer.bias is not None): - layer.bias *= scale - scale = scale.view(scale.shape[0], 1, 1, 1) - layer.weight *= scale - - elif isinstance(layer, torch.nn.Linear): - if hasattr(layer, "bias") and (layer.bias is not None): - layer.bias *= scale - scale = scale.view(scale.shape[0], 1) - layer.weight *= scale - - elif layer.__class__.__name__ == "LlamaRMSNorm" or layer.__class__.__name__ == "T5LayerNorm": ##quite tricky - layer.weight *= scale - - else: - logger.warning( - f"found 
unsupported layer {type(layer)}, try to multiply scale to " - f"weight and bias directly, this may introduce accuracy issue, please have a check " - ) - if hasattr(layer, "weight") and layer.weight is not None: - layer.weight *= scale - if hasattr(layer, "bias") and layer.bias is not None: - layer.bias *= scale - - def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False): - """Cal the adjust scales - :param absorb_to_layer: A dict mapping absorb layer to smooth quantized layer - :param input_maxes: The channel-wise input max info for layers - :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict - :return:""" - absorb_to_input_maxes = {} - for key in absorb_to_layer.keys(): - layer_name = absorb_to_layer[key][0] - absorb_to_input_maxes[key] = input_maxes[layer_name] - - weight_scales_info = {} - absorb_scales_info = {} - for index, key in enumerate(absorb_to_layer.keys()): - alpha_tmp = alpha[key] if isinstance(alpha, dict) else alpha - if alpha_tmp < 0: - scale = torch.ones((1), device=self.device) - else: - input_max = absorb_to_input_maxes[key] - layer_names = absorb_to_layer[key] - weights = [] - for layer_name in layer_names: - weight = self._reshape_in_channel_to_last(layer_name) - weights.append(weight) - - weight_max_per_channel = torch.max(torch.abs(torch.cat(weights, dim=0)), dim=0)[0] - if self.weight_clip: - weight_max_per_channel = weight_max_per_channel.clamp(min=1e-5) - if self.record_max_info and not tuning: - # the input of layers with same absorb layer is the same. - input_minmax = [self.input_mins[layer_names[0]], self.input_maxes[layer_names[0]]] - self.max_value_info[key] = {} - self.max_value_info[key]["alpha"] = alpha_tmp - self.max_value_info[key]["input_minmax"] = input_minmax - self.max_value_info[key]["weight_max"] = weight_max_per_channel - self.max_value_info[key]["absorbed_layer"] = layer_names - continue - - if self._save_scale: - if key in self.weight_scale_dict and alpha_tmp in self.weight_scale_dict[key]: - scale = self.weight_scale_dict[key][alpha_tmp] - else: - scale = cal_scale(input_max, weights, alpha_tmp) - else: - scale = cal_scale(input_max, weights, alpha_tmp) - - absorb_scales_info[key] = 1.0 / scale - absorb_scales_info[key][scale == 0] = 0 - layer_names = absorb_to_layer[key] - for layer_name in layer_names: - ##self._scale_layer_weight(layer_name, scale) - weight_scales_info[layer_name] = scale - if self._save_scale: - if layer_name not in self.weight_scale_dict: - self.weight_scale_dict[layer_name] = {} - self.weight_scale_dict[layer_name][alpha_tmp] = scale - return absorb_scales_info, weight_scales_info - - def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False): - """Adjust the weights and biases - :param absorb_to_layer: A dict mapping absorb layer to smooth quantized layer - :param input_maxes: The channel-wise input max info for layers - :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict - :return:""" - absorb_scales_info, weight_scales_info = self._cal_scales(absorb_to_layer, input_maxes, alpha, tuning) - if not absorb_scales_info or not weight_scales_info: - return weight_scales_info, absorb_scales_info - for index, key in enumerate(absorb_to_layer.keys()): - if isinstance(alpha, float): - alpha_tmp = alpha - elif isinstance(alpha, dict): - alpha_tmp = alpha[key] - absorb_scale = absorb_scales_info[key] - self._absorb_scales(key, absorb_scale) - layer_names = 
absorb_to_layer[key] - for layer_name in layer_names: - input_minmax = [self.input_mins[layer_names[0]], self.input_maxes[layer_names[0]]] - self._scale_layer_weight(layer_name, weight_scales_info[layer_name], alpha_tmp, input_minmax) - return weight_scales_info, absorb_scales_info - - def _check_need_calibration(self, alpha, percentile, op_types, scales_per_op, calib_iter): - """ - check need calibration or not - :param alpha: current alpha - :param percentile: current percentile - :param op_types: current op_types - :param scales_per_op: current scales_per_op - :param calib_iter:: current scales_per_op - :return: - """ - need_calib = True - if len(self.input_maxes) == 0: ## the first time - need_calib = True - self.alpha = alpha - self.percentile = percentile - self.op_types = op_types - self.scales_per_op = scales_per_op - self.calib_iter = calib_iter - return need_calib - - if ( - self.percentile == percentile - and self.op_types == op_types - and self.scales_per_op == scales_per_op - and self.calib_iter == calib_iter - ): - if isinstance(alpha, float) or self.alpha == "auto": - need_calib = False - - self.alpha, self.percentile = alpha, percentile - self.op_types, self.scales_per_op = op_types, scales_per_op - self.calib_iter = calib_iter - return need_calib - - def _get_auto_loss(self, output, output_q, loss_type="abs", loss_alpha=1.0): - """Get the loss for auto tuning - :param output: Fp32 output for one layer - :param output_q: Quant output for one layer - :param loss_type: The type of loss - :param loss_alpha: Loss alpha i for mean scale error - :return: A tensor of the loss.""" - if len(output.shape) <= 2: - max_value = torch.max(torch.abs(output)) - else: - output = output.reshape(output.shape[0], -1) - output_q = output_q.reshape(output_q.shape[0], -1) - max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1) - max_value = torch.clip(max_value, 1e-5) - output = output / max_value ##FIXME need copy not replace - output_q = output_q / max_value - # if loss_type == "nsr": # nsr is unused at this point. 
- # output[output == 0] = 1e-5 - # loss = torch.sum(torch.log(1.0 + torch.abs(output - output_q) / torch.abs(output))) - # return loss - if loss_type == "abs": - return torch.sum(torch.pow(torch.abs(output - output_q), 0.5)) - else: - return torch.sum((output - output_q) ** 2) - - def _get_sq_layer_names(self): - """Get the all the hook sq layer - :return: All the sq layer names.""" - ##TODO this may not fit for folding=False - module_names = [] - for key in self.absorb_to_layer: - module_names += self.absorb_to_layer[key] - return module_names - - def _get_all_hook_module_names(self): - module_names = [] - for n, module in self.model.named_modules(): - if isinstance(module, tuple(self.op_types)): - module_names.append(n) - return module_names - - def _qdq_model_wrapper_for_auto(self, save_q_input=False): - """Wrapper all the module with qdq - :return:""" - module_names = self._get_all_hook_module_names() - self.to_unwrap_module_names = module_names - for name in module_names: - if name not in self.input_mins: # skip module if it's not used in calibration - continue - module = get_module(self.model, name) - new_module = WrapperLayer(module, self.input_mins[name], self.input_maxes[name], save_q_input=save_q_input) - set_module(self.model, name, new_module) - - def _qdq_model_unwrapper_for_auto(self): - module_names = self.to_unwrap_module_names - for name in module_names: - module = get_module(self.model, name) - if not hasattr(module, "orig_layer"): # skip module if it's not used in calibration - continue - set_module(self.model, name, module.orig_layer) - - def _change_qdq_for_auto(self, enable=True): - module_names = self._get_all_hook_module_names() - for name in module_names: - name = name.split(".orig_layer")[0] - module = get_module(self.model, name) - if not hasattr(module, "orig_layer"): # skip module if it's not used in calibration - continue - if enable: - module.enable_quant() - else: - module.disable_quant() - - def _update_scales_for_auto(self, absorb_scales, weight_scales): - for key in self.absorb_to_layer.keys(): - layer_names = self.absorb_to_layer[key] - for layer_name in layer_names: - layer = get_module(self.model, layer_name) - input_scale = absorb_scales[key] - weight_scale = weight_scales[layer_name] - input_scale = self._reshape_scale_for_input(layer, input_scale) - weight_scale = self._reshape_scale_for_weight(layer, weight_scale) - layer.update_scale(input_scale, weight_scale) ##FIXME - - def _add_blockwise_observer(self, block_modules): - """ - :param block_modules: the block modules which the observer will insert to - :return: - """ - self.blockwise_hook_handles = [] - for key in block_modules.keys(): - hook_func = self._save_blockwise_hook(key) - hook_handle = block_modules[key].register_forward_hook(hook_func) - self.blockwise_hook_handles.append(hook_handle) - - def _save_blockwise_hook(self, name): - """A forward hook to save inputs/outputs of a block - :param name: the block name - :return: A hook function.""" - - def save_blockwise_hook(module, inputs, outputs): - self.block_inputs[name] = inputs[0] - self.block_outputs[name] = outputs[0] - - return save_blockwise_hook - - def _get_one_batch_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes): - self._change_qdq_for_auto(enable=False) - module_names = self._get_sq_layer_names() - - if self.do_blockwise: - block_modules = {} - for key in self.block_names: - block_modules[key] = get_module(self.model, key) - self._add_blockwise_observer(block_modules) - - forward_wrapper(self.model, input, 
self.device) ##disable quant and get fp32 output - - fp32_output = {} - if not self.do_blockwise: - for name in module_names: - module = get_module(self.model, name) - fp32_output[name] = module.output - module.output = None - else: - for block_name in self.block_names: - fp32_output[block_name] = self.block_outputs[block_name] - self._change_qdq_for_auto(enable=True) - absorb_input_scales, weight_scales = self._cal_scales( - self.absorb_to_layer, input_maxes, orig_best_alpha, tuning=True - ) - self._update_scales_for_auto(absorb_input_scales, weight_scales) - forward_wrapper(self.model, input, self.device) ##save quant_input - for mod_name in module_names: # save fp32 values - mod = get_module(self.model, mod_name) - if mod_name in self.fp32_output_val: - self.fp32_output_val[mod_name].append(torch.norm(mod.output)) - else: - self.fp32_output_val[mod_name] = [torch.norm(mod.output)] - del mod - - loss_alphas = {} - if not self.do_blockwise: - for name in module_names: - module = get_module(self.model, name) - loss = self._get_auto_loss(fp32_output[name], module.output) - cur_alpha = orig_best_alpha - if isinstance(orig_best_alpha, dict): - cur_alpha = orig_best_alpha[name] - key_name = str(cur_alpha) - loss_alphas[name] = {key_name: loss} - else: - for block_name in self.block_names: - block = get_module(self.model, block_name) - loss = self._get_auto_loss(fp32_output[block_name], self.block_outputs[block_name]) - cur_alpha = orig_best_alpha - if isinstance(orig_best_alpha, dict): - cur_alpha = orig_best_alpha[self.block_to_module[block_name][0]] - key_name = str(cur_alpha) - loss_alphas[block_name] = {key_name: loss} - # for name in module_names: - # loss_alphas[name]={} - for alpha in alpha_space: - absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, alpha, tuning=True) - self._update_scales_for_auto(absorb_input_scales, weight_scales) - if not self.do_blockwise: - for name in module_names: - losses = loss_alphas[name] - if str(alpha) in losses.keys(): - continue - module = get_module(self.model, name) - output = module.q_dq_forward(module.q_input, module.input_scale, module.weight_scale) - loss = self._get_auto_loss(fp32_output[name], output) - loss_alphas[name][str(alpha)] = loss - else: - for block_name in self.block_names: - losses = loss_alphas[block_name] - if str(alpha) in losses.keys(): - continue - block = get_module(self.model, block_name) - block_copy = copy.deepcopy(block) - for name in self.block_to_module[block_name]: - if name == block_name and len(self.block_to_module[block_name]) == 1: - module, module_copy = block, block_copy - else: - module = get_module(block, name) - module_copy = copy.deepcopy(module) - if module.weight_scale is not None: - module_copy.orig_layer.weight *= module.weight_scale - q_dq_weight = quant_dequant_w(module_copy.orig_layer) - module_copy.orig_layer.weight.data.copy_(q_dq_weight) - module_copy.do_blockwise = True - if not (name == block_name and len(self.block_to_module[block_name]) == 1): - set_module(block_copy, name, module_copy) - try: - output = block_copy(self.block_inputs[block_name])[0] - except: # Llama model decoder_layer forward requires position_id - position_ids = torch.arange(self.block_inputs[block_name].size()[1]) - position_ids = position_ids.view(self.block_inputs[block_name].size()[0], -1) - output = block_copy(self.block_inputs[block_name], position_ids=position_ids)[0] - loss = self._get_auto_loss(fp32_output[block_name], output) - loss_alphas[block_name][str(alpha)] = loss - del 
block_copy # release memory - return loss_alphas - - def _get_best_alpha(self, absorb_to_layer, loss_alphas, shared_criterion): - def dict_to_list(dic): - res = [] - for key in dic.keys(): - res.append((key, dic[key])) - return res - - best_alpha = {} - for ln_name in absorb_to_layer.keys(): - layer_names = absorb_to_layer[ln_name] - cur_shared_criterion = shared_criterion - if len(layer_names) == 1: - cur_shared_criterion = "min" - if cur_shared_criterion == "mean": - loss_tmp = {} - for alpha in loss_alphas[layer_names[0]].keys(): - if alpha not in loss_tmp.keys(): - loss_tmp[alpha] = 0 - for layer_name in layer_names: - loss_tmp[alpha] += loss_alphas[layer_name][alpha] - res = dict_to_list(loss_tmp) - res.sort(key=lambda x: x[1]) - - best_alpha[ln_name] = float(res[0][0]) - - elif cur_shared_criterion == "min" or cur_shared_criterion == "max": - tmp_best_alpha = [] - for layer_name in layer_names: - res = dict_to_list(loss_alphas[layer_name]) - res.sort(key=lambda x: x[1]) - tmp_best_alpha.append(float(res[0][0])) - if cur_shared_criterion == "min": - best_alpha[ln_name] = min(tmp_best_alpha) - else: - best_alpha[ln_name] = max(tmp_best_alpha) - - else: - raise NotImplementedError - return best_alpha - - def _auto_tune_alpha( - self, - input_maxes, - calib_sample_num=32, - alpha_min=0.3, - alpha_max=0.7, - alpha_step=0.05, - shared_criterion="min", - do_blockwise=False, - ): - """Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly. - - This function takes quantization of the former layers into consideration when qdq one layer - Also, it reduces the memory usage at the cost of increasingtuning time - TODO may have compatibility issue when setting folding=True, check whether having issues when bs!=1 - :param input_maxes: calibration data, input max - :param calib_sample_num: sample count used to auto tuning alpha - :param alpha_min: the min value of alpha - :param alpha_max: the max value of alpha - :param alpha_step: the alpha step in search space - :param shared_criterion: the criterion to choose alpha when multiple layers must share one same alpha - :return: - """ - logger.info("start sq auto tuning") - round_num = max( - len(str(alpha_min).split(".")[1]), len(str(alpha_max).split(".")[1]), len(str(alpha_step).split(".")[1]) - ) - alpha_space = numpy.round(numpy.arange(alpha_min, alpha_max + alpha_step, alpha_step), round_num).tolist() - ##wrapper new module - self._qdq_model_wrapper_for_auto(save_q_input=True) - ##set alpha to 0.5 as default - default_alpha = alpha_space[len(alpha_space) // 2] - if 0.5 in alpha_space: - default_alpha = 0.5 - default_alpha = self.default_alpha - absorb_input_scales, weight_scales = self._cal_scales( - self.absorb_to_layer, input_maxes, default_alpha, tuning=True - ) - self._update_scales_for_auto(absorb_input_scales, weight_scales) - total_cnt = 0 - tmp_cnt = 0 - alpha_update_iter = 0 - # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha - tune_cnt = 4 - multiply_factor = calib_sample_num // tune_cnt if calib_sample_num >= tune_cnt else calib_sample_num - self.fp32_output_val = {} - - best_alphas = default_alpha - if not self.dataloader: - logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.") - self._qdq_model_unwrapper_for_auto() - return best_alphas - bar = tqdm(self.dataloader, total=calib_sample_num, desc="auto tune alpha") - try: - for input, label in bar: - loss_alphas = {} - best_alphas_per_module = best_alphas - if 
isinstance(best_alphas, dict): - for key in self.absorb_to_layer.keys(): - layer_names = self.absorb_to_layer[key] - for layer_name in layer_names: - best_alphas_per_module[layer_name] = best_alphas_per_module[key] - - loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes) - if self.do_blockwise: - if loss_alphas == {}: - for block_name in self.block_names: - for key in self.block_to_module[block_name]: - loss_alphas[key] = loss_tmp[block_name] - else: - for block_name in self.block_names: - for key in self.block_to_module[block_name]: - cur_loss = loss_alphas[key] - for alpha_key in cur_loss.keys(): - cur_loss[alpha_key] += loss_tmp[block_name][alpha_key] - else: - if loss_alphas == {}: - loss_alphas = loss_tmp - else: - for key in loss_alphas.keys(): - cur_loss = loss_alphas[key] - for alpha_key in cur_loss.keys(): - cur_loss[alpha_key] += loss_tmp[key][alpha_key] - total_cnt += self.dataloader.batch_size - tmp_cnt += self.dataloader.batch_size - if tmp_cnt // multiply_factor >= 1: - alpha_update_iter += 1 - tmp_cnt = 0 - best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion) - for key in best_alphas.keys(): - logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}") - absorb_input_scales, weight_scales = self._cal_scales( - self.absorb_to_layer, input_maxes, best_alphas, tuning=True - ) - self._update_scales_for_auto(absorb_input_scales, weight_scales) - # does not need to reset the weight_scale_dict, because use the weight of ori_layer, no change - # self.weight_scale_dict = {} - if total_cnt >= calib_sample_num: - break - except: - for input in bar: - loss_alphas = {} - best_alphas_per_module = best_alphas - if isinstance(best_alphas, dict): - for key in self.absorb_to_layer.keys(): - layer_names = self.absorb_to_layer[key] - for layer_name in layer_names: - best_alphas_per_module[layer_name] = best_alphas_per_module[key] - - loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes) - if self.do_blockwise: - if loss_alphas == {}: - for block_name in self.block_names: - for key in self.block_to_module[block_name]: - loss_alphas[key] = loss_tmp[block_name] - else: - for block_name in self.block_names: - for key in self.block_to_module[block_name]: - cur_loss = loss_alphas[key] - for alpha_key in cur_loss.keys(): - cur_loss[alpha_key] += loss_tmp[block_name][alpha_key] - else: - if loss_alphas == {}: - loss_alphas = loss_tmp - else: - for key in loss_alphas.keys(): - cur_loss = loss_alphas[key] - for alpha_key in cur_loss.keys(): - cur_loss[alpha_key] += loss_tmp[key][alpha_key] - total_cnt += self.dataloader.batch_size - tmp_cnt += self.dataloader.batch_size - if tmp_cnt // multiply_factor >= 1: - alpha_update_iter += 1 - tmp_cnt = 0 - - best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion) - for key in best_alphas.keys(): - logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}") - absorb_input_scales, weight_scales = self._cal_scales( - self.absorb_to_layer, input_maxes, best_alphas, tuning=True - ) - self._update_scales_for_auto(absorb_input_scales, weight_scales) - # self.weight_scale_dict = {} - if total_cnt >= calib_sample_num: - break - - best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion) - for key in best_alphas.keys(): - logger.info(f"Final alpha {key}:{best_alphas[key]}") - max_op, max_ratio, max_key = "", 0, "" - ratio_info = {} - for 
key in self.absorb_to_layer: - for op_name in self.absorb_to_layer[key]: - fp32_norm, loss_ = ( - torch.sum(torch.stack(self.fp32_output_val[op_name])), - loss_alphas[op_name][str(best_alphas[key])], - ) - ratio = loss_ / fp32_norm - max_op = op_name if ratio > max_ratio else max_op - max_key = key if ratio > max_ratio else max_key - max_ratio = max(ratio, max_ratio) - ratio_info[op_name] = ratio - logger.debug( - f"final loss: {op_name}: {loss_}; @alpha {best_alphas[key]}; \ - fp32_output norm: {fp32_norm}; ratio: {ratio}" - ) - import operator - - ratio_info = dict(sorted(ratio_info.items(), key=operator.itemgetter(1), reverse=True)) - for key in list(ratio_info.keys()): - logger.debug(f"sorted opname-ratio: {key}: {ratio_info[key]}") - if max_op != "": - logger.debug( - f"max loss: {max_op}: {loss_alphas[max_op][str(best_alphas[max_key])]} @alpha {best_alphas[max_key]}\ - fp32_output norm: {torch.sum(torch.stack(self.fp32_output_val[max_op]))}; ratio: {max_ratio}" - ) - self._qdq_model_unwrapper_for_auto() - logger.info("auto tuning done") - return best_alphas - - def transform( - self, - alpha=0.5, - folding=False, - percentile=100, - op_types=[torch.nn.Linear, torch.nn.Conv2d], - scales_per_op=False, - calib_iter=100, - auto_alpha_args={ - "alpha_min": 0.0, - "alpha_max": 1.0, - "alpha_step": 0.1, - "shared_criterion": "mean", - "do_blockwise": False, - }, - weight_clip=True, - default_alpha=0.5, - ): - """The main entry of smooth quant - :param alpha: Alpha value to balance the quantization difficulty of activation and weight, please refer - to the paper for more details - :param folding: whether insert mul(False) or just allow foldable layers(True) for SmoothQuant - :param percentile: remove the activation outlier when calculating the scale - :param op_types: The op typed to be smooth quantized - :param scales_per_op: Not supported now - :param calib_iter: Data size for calibration - :param weight_clip: Whether to clip weight_max when calculating scales. - - :param auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning. - By default the search space is 0.0-1.0 with step_size 0.1. - do_blockwise: Whether to do blockwise auto-tuning. - :param default_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5. - :return: A FP32 model with the same architecture as the orig model but with different weight which will be - benefit to quantization. 
- """ - if isinstance(auto_alpha_args, dict): - self.do_blockwise = auto_alpha_args.get("do_blockwise", False) - else: - self.do_blockwise = False - if self.do_blockwise: - self.block_names = self.get_blocks() - logger.info("Blockwise auto-tuning will be performed") - if not isinstance(self.model, torch.nn.Module): - logger.warning("smooth quant is ignored since the model is not a torch module") - return self.model - - if folding: - self.insert_mul, self.allow_absorb = False, True - else: - self.insert_mul, self.allow_absorb = True, False - if isinstance(alpha, float) and (alpha < 0 or alpha > 1): - logger.warning("reset alpha to in range [0.0, 1.0]") - - alpha = numpy.clip(alpha, 0.0, 1.0) - - self.weight_clip = weight_clip - self.default_alpha = default_alpha - self.auto_alpha_args = auto_alpha_args - self.recover() - need_calibration = self._check_need_calibration(alpha, percentile, op_types, scales_per_op, calib_iter) - with torch.no_grad(): - str_op_types = [i.__name__ for i in op_types] - input_maxes_abs = self.input_maxes_abs - if need_calibration: ##avoid multiple calibaration during tuning if the only difference is alpha - if self.insert_mul: - self.self_absorb_layers = self._get_all_layer_names(op_types) # TODO: only support linear now. - # fetch modules with the same input - group_modules = self._trace(str_op_types, skip_unsupported_layers=False) - if group_modules is not None: - # use one input for qkv - for k, v in group_modules.items(): - for i in v: - if i in self.self_absorb_layers: - self.self_absorb_layers.pop(i) - self.self_absorb_layers[v[0]] = v - logger.debug(f"self_absorb_layers:{self.self_absorb_layers}") - if self.allow_absorb: - self.absorb_to_layer, no_absorb_layers = self._trace( - str_op_types - ) ##TODO we need to insert mul layer for no_absorb_layers later - if self.absorb_to_layer is None and no_absorb_layers is None: - return self.model - - # remove self.self_absorb_layers if it exists in self.absorb_to_layer - for k, v in self.absorb_to_layer.items(): - for i in v: - if i in self.self_absorb_layers: - self.self_absorb_layers.pop(i) - self.absorb_to_layer.update(self.self_absorb_layers) - - if self.absorb_to_layer is None and no_absorb_layers is None: - logger.warning( - "sorry, could not trace the model, smooth quant is ignored." - "If you are using huggingface model," - "you could set torchscript to True " - ) - return self.model - - if self.do_blockwise: - module_names = self._get_sq_layer_names() - block_names, self.block_to_module = self.block_names, {} - for block in block_names: - self.block_to_module[block] = [] - for module in module_names: - checked = False - for block in block_names: - if block + "." 
in module: - self.block_to_module[block].append(module) - checked = True - if not checked: - self.block_to_module[module] = [module] - self.block_names = list(self.block_to_module.keys()) - logger.info(f"Blockwise auto-tuning: {len(self.block_names)} blocks found") - logger.debug(f"Blockwise auto-tuning blocks info: {self.block_to_module}") - - input_maxes_abs = self._calibrate(self.absorb_to_layer, calib_iter, percentile) - - # Check if input_maxes match self.absorb_to_layer - # (due to self._get_all_layer_names use layer tree instead of forward_path) - if not folding: - diff_modules = set(self.absorb_to_layer.keys()).difference(input_maxes_abs.keys()) - for d in diff_modules: - del self.absorb_to_layer[d] - - scale_memo_use = 0 - for key in self.absorb_to_layer: - layer_name = self.absorb_to_layer[key][0] - input_max = input_maxes_abs[layer_name] - scale_memo_use += 4 * input_max.shape[0] * len(self.absorb_to_layer[key]) - if alpha == "auto": - alpha_space = (auto_alpha_args["alpha_max"] - auto_alpha_args["alpha_min"]) / auto_alpha_args[ - "alpha_step" - ] + 1 - scale_memo_use *= alpha_space - self._save_scale = enough_memo_store_scale(self.device, scale_memo_use) - - if alpha == "auto": - self.alpha_per_layer = self._auto_tune_alpha( - input_maxes_abs, calib_sample_num=32, **auto_alpha_args - ) ##save the alpha - - if alpha == "auto": - alpha = self.alpha_per_layer - example_inputs = self._get_example_input() - if example_inputs is not None: - out_pre_sq = model_forward_per_sample(self.model, example_inputs, self.device) - - if folding: - self._save_scale = False - if self.record_max_info: - # max_info is recorded in self.max_value_info - self._adjust_parameters(self.absorb_to_layer, input_maxes_abs, alpha) - self.model._smoothquant_optimized = False - return self.model - - self.weight_scale_info, self.absorb_scales_info = self._adjust_parameters( - self.absorb_to_layer, input_maxes_abs, alpha - ) - - self.model._smoothquant_optimized = True - if example_inputs is not None: - # Check mathematical equivelancy - out_post_sq = model_forward_per_sample(self.model, example_inputs, self.device) - - if not self.output_is_equal(out_post_sq, out_pre_sq): - logger.warning( - "Mathematical equivelancy of Smoothquant is not preserved. " - "Please kindly report this issue to https://github.com/intel/neural-compressor." - ) - else: - logger.warning(" Could not get example input, equivelancy check is skipped") - - return self.model - - def output_is_equal(self, out1, out2, atol=1e-04): - try: - if isinstance(out1, tuple): - return all(torch.all(torch.isclose(out1[i], out2[i], atol=atol)) for i in range(len(out1))) - elif isinstance(out1, dict): - return all(torch.all(torch.isclose(out1[k], out2[k], atol=atol)) for k in out1.keys()) - elif isinstance(out1, torch.Tensor): - return torch.all(torch.isclose(out1, out2, atol=atol)) - return False - except: - logger.warning( - "Automatically check failed, Please check equivelancy manually " - "between out_pre_sq and out_post_sq if necessary." - ) - return True - - def recover(self): - """Recover the model weights - :return:""" - with torch.no_grad(): - for key in self.weight_scale_info: - self._scale_layer_weight(key, 1.0 / self.weight_scale_info[key]) - for key in self.absorb_scales_info: - self._absorb_scales(key, 1.0 / self.absorb_scales_info[key]) - self.weight_scale_info = {} ##clear the data - self.absorb_scales_info = {} - - def _get_all_layer_names(self, op_types=[torch.nn.Linear]): - """Try the model to find the layers which can be smooth quantized. 
- - :param op_types: The op types to be smooth quantized - :return: - self_absorb_layer: A dict, absorb layer name (itself): layers to be smooth quantized - """ - self_absorb_layer = {} - op_types = [torch.nn.Linear] # TODOļ¼š only support SQLinearWrapper - for name, module in self.model.named_modules(): - if isinstance(module, tuple(op_types)): - self_absorb_layer[name] = [name] - return self_absorb_layer - - def _get_example_input(self): - if self.dataloader is None and self.example_inputs is None: - return None - if self.example_inputs is None: - try: - for idx, (input, label) in enumerate(self.dataloader): - self.example_inputs = input - break - except: - for idx, input in enumerate(self.dataloader): - self.example_inputs = input - break - - return self.example_inputs - - def _trace(self, op_types, skip_unsupported_layers=True): - """Try the model to find the layers which can be smooth quantized. - - :param op_types: The op types to be smooth quantized - :return: - absorb_to_layer: A dict, absorb layer name:layers to be smooth quantized - no_absorb_layers: A list saving the layers which could not find the absorb layer - """ - tg = GraphTrace() - self._get_example_input() - absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer( - self.traced_model, - self.example_inputs, - op_types, - skip_unsupported_layers=skip_unsupported_layers, - ) - if not skip_unsupported_layers: - return absorb_to_layer - if absorb_to_layer is None and no_absorb_layers is None: - logger.warning( - "sorry, could not trace the model, smooth quant is skipped." - "If you are using huggingface model," - "you could set torchscript to True " - "when loading the model or set the return_dict to False" - ) - elif absorb_to_layer == {}: - logger.warning("could not find any layer to be absorbed") - else: - to_absorb_cnt = 0 - for key, item in absorb_to_layer.items(): - to_absorb_cnt += len(item) - logger.info( - f" {to_absorb_cnt} out of {to_absorb_cnt + len(no_absorb_layers)} " - f"layers could be absorbed in smooth quant" - ) - return absorb_to_layer, no_absorb_layers - - -def get_parent(node, all_parents=False): - if node.inputs() is None: - return None - elif len(list(node.inputs())) == 0: - return None - if not all_parents: - return list(node.inputs())[0].node() - else: - return list(node.inputs()) - - -class GraphTrace: - """""" - - def __init__(self): - self.supported_torch_module_to_aten = { - "Linear": "aten::linear", - "Conv2d": "aten::_convolution", - "ConvTranspose2d": "aten::_convolution", - "LayerNorm": "aten::layer_norm", - "BatchNorm2d": "aten::batch_norm", - "GroupNorm": "aten::group_norm", - "InstanceNorm2d": "aten::instance_norm", - "LlamaRMSNorm": "aten::mul", - "T5LayerNorm": "aten::mul", - "LPLayerNorm": "aten::layer_norm", ##mpt_chat - } - - ##TODO potential bug, need to check only have one bug - ##TODO, must satisfy af(x)=f(ax),current skip layer may be incomplete - self.skip_ops_to_find_absorb = ["aten::to", "aten::relu", "aten::leaky_relu", "aten::hardtanh"] - - self.could_absorb_layers = [ - "aten::layer_norm", - "aten::batch_norm", - "aten::linear", - "aten::_convolution", - "aten::group_norm", - "aten::instance_norm", - "aten::mul", - ] ##TODO,support more norm - - def trace(self, model, dummy_input): - traced_model = None - optimize_numerics = False - orig_device = str(next(model.parameters()).device) - if orig_device != "cpu" and orig_device != "meta": # pragma: no cover - model = model.to("cpu") - dummy_input = move_input_to_device(dummy_input, "cpu") - if isinstance(dummy_input, dict) 
or isinstance(dummy_input, UserDict): - try: - traced_model = torch.jit.trace( - model, example_kwarg_inputs=dict(dummy_input), strict=False, check_trace=False - ) - traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) - except Exception as e: - logger.warning(e) - logger.warning("Jit trace in GraphTrace failed, absorb layer detection is skipped") - else: - try: - traced_model = torch.jit.trace(model, dummy_input, strict=False) - traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) - except: - try: - traced_model = torch.jit.trace(model, dummy_input[0], strict=False) - traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) - except Exception as e: - logger.warning(e) - logger.warning("Jit trace in GraphTrace failed, absorb layer detection is skipped") - model = model.to(orig_device) - return traced_model - - def get_nodes(self, traced_model, op_types=["Linear"]): - if isinstance(op_types, str): - op_types = [op_types] - nodes = [] - for node in traced_model.graph.nodes(): - node_type = node.kind() - for op_type in op_types: - if node_type == op_type: - nodes.append((node, op_type)) - break - return nodes - - def get_prev_absorb_layer(self, nodes): - prev_absorb_layer = [] - for node in nodes: - parent = get_parent(node) - while 1: - if parent.kind() in self.skip_ops_to_find_absorb: - parent = get_parent(parent) - continue - if parent.kind() in self.could_absorb_layers: - parent_out_kinds = [] - for val_user in list(parent.outputs())[0].uses(): - next_node = val_user.user - parent_out_kinds.append(next_node.kind()) - parent_out_kinds = set(parent_out_kinds) - parent_out_kinds.discard("aten::size") - - if parent_out_kinds == parent_out_kinds.intersection(self.could_absorb_layers): - prev_absorb_layer.append(parent) - elif parent_out_kinds.intersection(self.skip_ops_to_find_absorb): - res = self.skip_op_absorb_helper(parent) - prev_absorb_layer.append(parent) if res else prev_absorb_layer.append(None) - else: # When parent to multiple ops, sq transformation could be wrong. 
- prev_absorb_layer.append(None) - else: - prev_absorb_layer.append(None) - break - return prev_absorb_layer - - def skip_op_absorb_helper(self, parent_node): - for val_user in list(parent_node.outputs())[0].uses(): - next_node = val_user.user - if next_node.kind() == "aten::size": - continue - elif next_node.kind() in self.could_absorb_layers: - continue - elif next_node.kind() in self.skip_ops_to_find_absorb: - node_res = self.skip_op_absorb_helper(next_node) - if not node_res: - return False - else: - return False - return True - - def mapping_torch_module_to_aten(self, op_types): - res = [] - for op in op_types: - if op not in self.supported_torch_module_to_aten.keys(): - logger.warning(f"{op} is not supported in smooth quant, ignoring...") - continue - res.append(self.supported_torch_module_to_aten[op]) - res = list(set(res)) - return res - - def _check_valid_conv(self, module): - """Remove group conv except depthwise conv - :param module: - - :return: - """ - if not isinstance(module, torch.nn.Conv2d): - return True - if module.groups > 1: - if module.in_channels == module.out_channels and module.groups == module.in_channels: - return True - else: - return False - return True - - def get_absorb_to_layer(self, model, example_input, op_types, skip_unsupported_layers=True): - traced_model = self.trace(model, example_input) - if traced_model is None: - return None, None - - aten_op_types = self.mapping_torch_module_to_aten(op_types) - nodes_types = self.get_nodes(traced_model, aten_op_types) - nodes = [node_type[0] for node_type in nodes_types] - nodes_prev_absorb = self.get_prev_absorb_layer(nodes) - absorb_to_layer = {} - no_absorb_layers = [] - for index, absorb in enumerate(nodes_prev_absorb): - if absorb is None: - no_absorb_layers.append(".".join(nodes[index].scopeName().split("/")[-1].split(".")[1:])) - continue - node = nodes[index] - layer_name = ".".join(node.scopeName().split("/")[-1].split(".")[1:]) - absorb_name = ".".join(absorb.scopeName().split("/")[-1].split(".")[1:]) - if layer_name == "" or absorb_name == "": - continue - if absorb_name in absorb_to_layer.keys(): - absorb_to_layer[absorb_name].append(layer_name) - else: - absorb_to_layer[absorb_name] = [layer_name] - if skip_unsupported_layers: - absorb_to_layer = self.remove_unsupported_layers(model, absorb_to_layer, no_absorb_layers) - return absorb_to_layer, no_absorb_layers - - def remove_unsupported_layers(self, model, absorb_to_layer, no_absorb_layers): - res = {} - for key in absorb_to_layer.keys(): - absorb_layer = get_module(model, key) - layer_type = absorb_layer.__class__.__name__ - if layer_type not in self.supported_torch_module_to_aten.keys(): - no_absorb_layers.extend(absorb_to_layer[key]) - continue - supported = True - for layer_name in absorb_to_layer[key]: - layer = get_module(model, layer_name) - layer_type = layer.__class__.__name__ - if (layer_type not in self.supported_torch_module_to_aten.keys()) or not self._check_valid_conv(layer): - supported = False - no_absorb_layers.extend(absorb_to_layer[key]) - break - if supported: - res[key] = absorb_to_layer[key] - return res diff --git a/neural_compressor/torch/algorithms/smooth_quant/utility.py b/neural_compressor/torch/algorithms/smooth_quant/utility.py new file mode 100644 index 00000000000..ceb2657b89a --- /dev/null +++ b/neural_compressor/torch/algorithms/smooth_quant/utility.py @@ -0,0 +1,2341 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in 
compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import json +import os +import re +import subprocess +from collections import UserDict + +import cpuinfo +import intel_extension_for_pytorch as ipex +import numpy +import psutil +import torch +import tqdm +from packaging.version import Version + +from neural_compressor.torch.algorithms.static_quant import ( + TransformerBasedModelBlockPatternDetector, + dump_model_op_stats, + get_quantizable_ops_from_cfgs, + ipex_config_path, + paser_cfgs, + simple_inference, + unify_op_type_mapping_ipex, +) +from neural_compressor.torch.utils import get_ipex_version, get_torch_version, logger + +version = get_torch_version() +ipex_ver = get_ipex_version() + + +def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False): # pragma: no cover + """This is a helper method to generate an activation observer. + + Args: + scheme (str): Quantization scheme to be used. + algorithm (str): What algorithm for computing the quantization parameters based on. + + Returns: + An observer. + """ + kl_activation_observer = { + "name": "HistogramObserver", + "bins": 2048, + "upsample_rate": 128, + "dtype": "torch.quint8", + "qscheme": "torch.per_tensor_affine", + "reduce_range": False, + "quant_min": 0, + "quant_max": 255, + } + minmax_activation_observer = { + "name": "MinMaxObserver", + "dtype": "torch.quint8", + "qscheme": "torch.per_tensor_affine", + "reduce_range": False, + "quant_min": 0, + "quant_max": 255, + } + smoothquant_kl_activation_observer = { + "name": "SmoothQuantActivationObserver", + "smooth_quant_enabled": smooth_quant_enable, + "dtype": "torch.quint8", + "qscheme": "torch.per_tensor_affine", + "reduce_range": False, + "quant_min": 0, + "quant_max": 255, + "alpha": 0.5, + "act_observer": kl_activation_observer, + "act_ic_observer": { + "name": "PerChannelMinMaxObserver", + "ch_axis": -1, + "dtype": "torch.quint8", + "qscheme": "torch.per_channel_affine", + "reduce_range": False, + "quant_min": 0, + "quant_max": 255, + }, + } + smoothquant_minmax_activation_observer = { + "name": "SmoothQuantActivationObserver", + "smooth_quant_enabled": smooth_quant_enable, + "dtype": "torch.quint8", + "qscheme": "torch.per_tensor_affine", + "reduce_range": False, + "quant_min": 0, + "quant_max": 255, + "alpha": 0.5, + "act_observer": minmax_activation_observer, + "act_ic_observer": { + "name": "PerChannelMinMaxObserver", + "ch_axis": -1, + "dtype": "torch.quint8", + "qscheme": "torch.per_channel_affine", + "reduce_range": False, + "quant_min": 0, + "quant_max": 255, + }, + } + REDUCE_RANGE = False if CpuInfo().vnni else True + if REDUCE_RANGE: + minmax_activation_observer["reduce_range"] = REDUCE_RANGE + kl_activation_observer["reduce_range"] = REDUCE_RANGE + if scheme == "sym": + minmax_activation_observer["qscheme"] = "torch.per_tensor_symmetric" + minmax_activation_observer["dtype"] = "torch.qint8" + minmax_activation_observer["quant_min"] = -128 + minmax_activation_observer["quant_max"] = 127 + kl_activation_observer["qscheme"] = "torch.per_tensor_symmetric" + kl_activation_observer["dtype"] = "torch.qint8" + 
kl_activation_observer["quant_min"] = -128 + kl_activation_observer["quant_max"] = 127 + if smooth_quant and smooth_quant_enable: + if algorithm == "kl": + return smoothquant_kl_activation_observer + if algorithm == "minmax": + return smoothquant_minmax_activation_observer + else: + if algorithm == "kl": + return kl_activation_observer + if algorithm == "minmax": + return minmax_activation_observer + + +def check_cfg_and_qconfig( + tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name, smooth_quant=False +): # pragma: no cover + """Check configs and quantization configs. + + Args: + tune_cfg (dict): dictionary of quantization configuration. + cfgs (dict): the input configs. + op_infos_from_cfgs (dict): op infos from configs. + output_tensor_ids_op_name (dict): dictionary of output tensor op names. + + Returns: + cfgs (dict). + """ + for op_name in tune_cfg: + inc_op_cfg = tune_cfg[op_name] + for i, name in enumerate(op_name[0]): + # to int8 + ipex_op_cfg = op_infos_from_cfgs[name] + input_tensor_infos = ipex_op_cfg["input_tensor_infos"] + if op_name[1] == "Linear" or op_name[1] == "Linear&add": # record op_name for possible op-wise fallback + logger.debug(f"ipex_op_cfg['fqn'] - op_name {ipex_op_cfg['fqn']} {op_name}") + for index, input_tensor_info in enumerate(input_tensor_infos): + if "force_dtype" not in input_tensor_info.keys(): + continue + if ( + input_tensor_info["force_dtype"] == "torch.qint8" + or input_tensor_info["force_dtype"] == "torch.quint8" + ): + # int8 -> int8 + if inc_op_cfg["weight"]["dtype"] == "int8": + inc_scheme = inc_op_cfg["activation"]["scheme"] + inc_algorithm = inc_op_cfg["activation"]["algorithm"] + ipex_op_cfg["input_tensor_infos"] = input_tensor_infos + if ( + "op_type" in ipex_op_cfg + and ipex_op_cfg["op_type"] == "" + ): + smooth_quant_enable = True + else: + smooth_quant_enable = False + activation_observer = generate_activation_observer( + inc_scheme, inc_algorithm, smooth_quant, smooth_quant_enable + ) + if not smooth_quant: + if inc_scheme == "sym": + input_tensor_infos[index]["force_dtype"] = "torch.qint8" + if inc_scheme == "asym": + input_tensor_infos[index]["force_dtype"] = "torch.quint8" + ipex_op_cfg["activation_observer"] = activation_observer + # int8 -> fp32 + else: + input_tensor_infos[index]["force_dtype"] = "torch.float32" + # modify pre_op output inf_dtype + if i == 0: + input_tensor_id = input_tensor_info["id"] + input_tensor_dtype = input_tensor_info["force_dtype"] + if input_tensor_id in output_tensor_ids_op_name.keys(): + pre_op_name = output_tensor_ids_op_name[input_tensor_id] + pre_op_module = pre_op_name[0][0] + pre_op_state = pre_op_name[0][1] + pre_op_index = pre_op_name[0][2] + pre_op_infos = cfgs[pre_op_module][pre_op_state][pre_op_index] + pre_op_output_infos = pre_op_infos["output_tensor_infos"] + for index, pre_op_output in enumerate(pre_op_output_infos): + if pre_op_output["id"] == input_tensor_id: + pre_op_output_infos[index]["inf_dtype"] = input_tensor_dtype + else: + pass + pre_op_infos["output_tensor_infos"] = pre_op_output_infos + cfgs[pre_op_module][pre_op_state][pre_op_index] = pre_op_infos + else: + pass + cfgs[name[0]][name[1]][name[2]] = ipex_op_cfg + return cfgs + + +def cfg_to_qconfig( + tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, smooth_quant=False +): # pragma: no cover + assert cfgs is not None, "No configure for IPEX int8 model..." 
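The scheme-to-dtype mapping handled above (sym → torch.qint8 with range [-128, 127], asym → torch.quint8 with range [0, 255]) is easier to see outside the config plumbing. Below is a minimal, self-contained sketch in plain PyTorch with a made-up activation tensor and hand-computed qparams; it is only an illustration of the per-tensor sym/asym difference, not part of the patch.

```python
# Illustrative only: why "force_dtype" flips between torch.qint8 (sym) and
# torch.quint8 (asym). Plain PyTorch; the tensor and ranges are made up.
import torch

x = torch.randn(4, 8) * 3.0  # fake activation

# Asymmetric / affine: quint8 over [0, 255], like the "asym" observer settings.
scale_a = (x.max() - x.min()) / 255.0
zp_a = int((-x.min() / scale_a).round().clamp(0, 255))
xq_a = torch.quantize_per_tensor(x, float(scale_a), zp_a, torch.quint8)

# Symmetric: qint8 over [-128, 127] with zero point pinned at 0 ("sym" branch).
scale_s = x.abs().max() / 127.0
xq_s = torch.quantize_per_tensor(x, float(scale_s), 0, torch.qint8)

print("asym max err:", (xq_a.dequantize() - x).abs().max().item())
print("sym  max err:", (xq_s.dequantize() - x).abs().max().item())
```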
+ op_infos = copy.deepcopy(op_infos_from_cfgs) + cfgs = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name, smooth_quant) + with open(ipex_config_path, "w") as write_f: + json.dump(cfgs, write_f, indent=4) + return None + + +def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover + """Get all quantizable ops from model. + + Args: + model (object): input model + example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model. + Returns: + quantizable_ops (list): list of tuples of op_name and op_type. + cfgs (dict): dict of configuration + """ + quantizable_ops = [] + # group ops by position for transform-based model + detector = TransformerBasedModelBlockPatternDetector(model) + detect_result = detector.detect_block() + attention_block = detect_result.get("attention_blocks", None) + ffn_blocks = detect_result.get("ffn_blocks", None) + logger.info(f"Attention Blocks: {len(attention_block)}") + logger.info(f"FFN Blocks: {len(ffn_blocks)}") + if not os.path.exists(ipex_config_path): + assert isinstance(model, torch.nn.Module), "The model passed in is not the instance of torch.nn.Module" + + if hasattr(model, "save_qconf_summary"): # pragma: no cover + os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True) + model.save_qconf_summary(qconf_summary=ipex_config_path) + else: + model.eval() + + # create a quantization config file for intel pytorch extension model + os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True) + assert example_inputs is not None, "IPEX need q_dataloader or example_inputs to prepare the model" + from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig + + if ipex_ver.release >= Version("2.1").release: + # HistogramObserver will cause a performance issue. 
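Because HistogramObserver is avoided here for speed, the static qconfig built just below relies on MinMaxObserver and PerChannelMinMaxObserver. The following standalone sketch shows what those observers produce during a calibration pass; the tensors are fabricated, and only the observer classes and calculate_qparams() are real torch.ao.quantization API.

```python
# Illustrative only: min/max observers turn observed statistics into
# (scale, zero_point) quantization parameters.
import torch
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver

act_obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
w_obs = PerChannelMinMaxObserver(
    ch_axis=0, dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)

act_obs(torch.randn(16, 64))  # "calibration" pass over a fake activation
w_obs(torch.randn(32, 64))    # per-output-channel stats for a fake weight

act_scale, act_zp = act_obs.calculate_qparams()
w_scale, w_zp = w_obs.calculate_qparams()
print(act_scale.item(), act_zp.item())  # one scale/zero_point per tensor
print(w_scale.shape)                    # one scale per output channel
```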
+ # static_qconfig = ipex.quantization.default_static_qconfig_mapping + qconfig = QConfig( + activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), + weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric), + ) + from torch.ao.quantization import QConfigMapping + + static_qconfig = QConfigMapping().set_global(qconfig) + else: + static_qconfig = QConfig( + activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), + weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric), + ) + + if isinstance(example_inputs, dict): + model = ipex.quantization.prepare(model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=True) + else: + model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=True) + simple_inference(model, example_inputs, iterations=1) + model.save_qconf_summary(qconf_summary=ipex_config_path) + + map_op_name_to_fqn = {} + with open(ipex_config_path, "r") as f: + cfgs = json.load(f) + if ipex_ver.release < Version("1.12.0").release: # pragma: no cover + for op_cfg in cfgs: + if op_cfg["name"] in unify_op_type_mapping_ipex: + quantizable_ops.append((op_cfg["id"], unify_op_type_mapping_ipex[op_cfg["name"]])) + else: + re_flag = False + for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items(): + if re.match(pattern, op_cfg["name"]): + re_flag = True + quantizable_ops.append((op_cfg["id"], unify_op_type)) + break + if not re_flag: + quantizable_ops.append((op_cfg["id"], op_cfg["name"])) + else: + ( + ops_name, + op_infos_from_cfgs, + input_tensor_id_op_name, + output_tensor_id_op_name, + ) = paser_cfgs(cfgs) + quantizable_op_names = get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_id_op_name) + for name in quantizable_op_names: + # name : list + if len(name) == 1: + module_key = name[0][0] + op_cfg_id = name[0][2] + ipex_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"] + module_fqn = cfgs[module_key]["q_op_infos"][op_cfg_id].get("fqn", None) + + if ipex_op_type in unify_op_type_mapping_ipex: + quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type])) + map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn + else: + re_flag = False + for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items(): + if re.match(pattern, ipex_op_type): + re_flag = True + quantizable_ops.append((tuple(name), unify_op_type)) + map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn + break + if not re_flag: + quantizable_ops.append((tuple(name), ipex_op_type)) + map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn + else: + op_type = "" + for op_name in name: + module_key = op_name[0] + op_cfg_id = op_name[2] + single_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"] + if single_op_type in unify_op_type_mapping_ipex: + single_op_type = unify_op_type_mapping_ipex[single_op_type] + op_type += "&" + single_op_type if op_type else single_op_type + quantizable_ops.append((tuple(name), op_type)) + _module_key = name[0][0] + _op_cfg_id = name[0][2] + module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"] + map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn + + logger.debug("Map op name to fqn: ") + logger.debug(map_op_name_to_fqn) + logger.info("Attention Blocks : ") + logger.info(attention_block) + logger.info("FFN Blocks : ") + logger.info(ffn_blocks) + return quantizable_ops, cfgs, 
op_infos_from_cfgs, output_tensor_id_op_name + + +def get_parent(node, all_parents=False): # pragma: no cover + if node.inputs() is None: + return None + elif len(list(node.inputs())) == 0: + return None + if not all_parents: + return list(node.inputs())[0].node() + else: + return list(node.inputs()) + + +def get_module(model, key): # pragma: no cover + """Get module from model by key name. + + Args: + model (torch.nn.Module): original model + key (str): module name to be replaced + """ + module = model + name_list = key.split(".") + for name in name_list: + if hasattr(module, name): + module = getattr(module, name) + elif hasattr(module, "sq_linear"): # for peft models + module = getattr(module, "sq_linear") + module = getattr(module, name) + elif hasattr(module, "orig_layer"): # for peft models and auto alpha + module = getattr(module, "orig_layer") + module = getattr(module, name) + else: + module = module + return module + + +def set_module(model, key, new_module): # pragma: no cover + """Set new module into model by key name. + + Args: + model (torch.nn.Module): original model + key (str): module name to be replaced + new_module (torch.nn.Module): new module to be inserted + """ + module = model + name_list = key.split(".") + for name in name_list[:-1]: + if hasattr(module, name): + module = getattr(module, name) + elif hasattr(module, ("sq_linear")): # for peft models that Linears are contained in Linear + module = getattr(module, "sq_linear") + module = getattr(module, name) + elif hasattr(module, ("orig_layer")): # for peft models and auto alpha + module = getattr(module, "orig_layer") + module = getattr(module, name) + else: + module = module + + if hasattr(module, "sq_linear") and name_list[-1] != "sq_linear": # for peft models + module = getattr(module, "sq_linear") + if hasattr(module, "orig_layer") and name_list[-1] != "orig_layer": # for peft models and auto alpha + module = getattr(module, "orig_layer") + setattr(module, name_list[-1], new_module) + + +def update_sq_scale(ipex_config_path, smoothquant_scale_info): # pragma: no cover + """Update ipex_config.json with smoothquant scale info generated by our algorithm. + + Args: + ipex_config_path (str): a path to temporary ipex_config.json file. + smoothquant_scale_info (dict): a dict contains smoothquant scale info. + """ + with open(ipex_config_path, "r") as f: + ipex_config = json.load(f) + for module_name, v in ipex_config.items(): + if "q_op_infos" in v and v["q_op_infos"]: + for op_num, v1 in v["q_op_infos"].items(): + # update alpha data instead of updating weight scale + op_name = v1["fqn"] # fqn always exists even it's empty. 
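The get_module()/set_module() helpers above boil down to walking a dotted module path with getattr/setattr (the sq_linear/orig_layer fallbacks for PEFT-wrapped layers are omitted here). A minimal sketch on a toy model follows; the helper names fetch/swap are hypothetical and only mirror the traversal idea.

```python
# Illustrative only: dotted-name lookup and in-place replacement of a submodule.
import torch

model = torch.nn.Sequential()
model.add_module("blocks", torch.nn.Sequential(torch.nn.Linear(8, 8)))

def fetch(model, key):
    mod = model
    for name in key.split("."):
        mod = getattr(mod, name)
    return mod

def swap(model, key, new_module):
    *parents, leaf = key.split(".")
    mod = model
    for name in parents:
        mod = getattr(mod, name)
    setattr(mod, leaf, new_module)  # nn.Module.__setattr__ re-registers the child

print(fetch(model, "blocks.0"))               # Linear(8, 8)
swap(model, "blocks.0", torch.nn.Identity())  # replace it in place
print(fetch(model, "blocks.0"))               # Identity()
```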
+ if op_name in smoothquant_scale_info and v1["op_type_is_module"]: + input_scale_for_mul = smoothquant_scale_info[op_name]["input_scale_for_mul"].tolist() + input_scale_after_mul = smoothquant_scale_info[op_name]["input_scale_after_mul"].tolist() + input_zero_point_after_mul = smoothquant_scale_info[op_name][ + "input_zero_point_after_mul" + ].tolist() + weight_scale_for_mul = (1 / smoothquant_scale_info[op_name]["input_scale_for_mul"]).tolist() + weight_scale_after_mul = smoothquant_scale_info[op_name]["weight_scale_after_mul"].tolist() + v1["input_tensor_infos"][0]["scale"] = input_scale_after_mul + v1["input_tensor_infos"][0]["zero_point"] = input_zero_point_after_mul + v1["input_tensor_infos"][0]["smooth_quant_scaling_factor"] = input_scale_for_mul + v1["weight_tensor_infos"][0]["smooth_quant_scaling_factor"] = weight_scale_for_mul + v1["weight_tensor_infos"][0]["scale"] = weight_scale_after_mul + # # observers were overridden by the fallback step, setting it back. + f.close() + # overwrite ipex_config_path + with open(ipex_config_path, "w") as f1: + json.dump(ipex_config, f1, indent=4) + f1.close() + + +def enough_memo_store_scale(device, need_space): # pragma: no cover + if device == "cuda": # pragma: no cover + current_gpu_index = torch.cuda.current_device() + total_memory = torch.cuda.get_device_properties(current_gpu_index).total_memory + used_memory = torch.cuda.memory_allocated(current_gpu_index) + free_space = total_memory - used_memory + else: + import psutil + + free_space = psutil.virtual_memory().free + return free_space >= need_space + + +def move_input_to_device(input, device=torch.device("cpu")): # pragma: no cover + if isinstance(input, dict) or isinstance(input, UserDict): + tmp_input = {} + for k, inp in input.items(): + tmp_input[k] = move_input_to_device(inp, device) + input = tmp_input + elif isinstance(input, list) or isinstance(input, tuple): + is_tuple = isinstance(input, tuple) + tmp_input = [] + for inp in input: + tmp_input.append(move_input_to_device(inp, device)) + input = tuple(tmp_input) if is_tuple else tmp_input + elif isinstance(input, torch.Tensor): + input = input.to(device) # pylint: disable=no-member + return input + + +def forward_wrapper(model, input, device=torch.device("cpu")): # pragma: no cover + try: + model = model.to(device) + input = move_input_to_device(input, device) + except Exception as e: + logger.warning(e) + logger.warning("Please check the input device if the error raised.") + if isinstance(input, dict) or isinstance(input, UserDict): + output = model(**input) + elif isinstance(input, list) or isinstance(input, tuple): + try: + output = model(*input) + except: + output = model(input) + else: + output = model(input) + return output + + +def model_forward(model, dataloader, iters, device): # pragma: no cover + try: + cnt = 0 + for idx, (input, label) in enumerate(dataloader): + output = forward_wrapper(model, input, device) + cnt += 1 + if iters != -1 and cnt >= iters: + break + except Exception as e: + cnt = 0 + for idx, input in enumerate(dataloader): + output = forward_wrapper(model, input, device) + cnt += 1 + if iters != -1 and cnt >= iters: + break + + +def cal_scale(input_max_abs, weights, alpha, weight_max_lb=1e-5): # pragma: no cover + weights = torch.cat(weights, dim=0) + weight_max = torch.max(torch.abs(weights), dim=0)[0] + weight_max = torch.clip(weight_max, weight_max_lb) + input_power = torch.pow(input_max_abs, alpha) + logger.debug(f"{max(input_max_abs)}, {min(input_max_abs)}") + weight_power = torch.pow(weight_max, 
1 - alpha) + weight_scale = torch.clip(input_power / weight_power, min=1e-5) + weight_scale[input_power == 0] = 1.0 + return weight_scale + + +def model_forward_per_sample(model, sample, device): # pragma: no cover + try: + output = forward_wrapper(model, sample, device) + return output + + except Exception as e: + output = forward_wrapper(model, sample[0], device) + return output + + +def quant_dequant_w_v1(m, num_bits=8, scheme="sym"): # pragma: no cover + eps = torch.finfo(torch.float32).eps + if isinstance(m, torch.nn.Linear): + x = m.weight + tmp = torch.zeros(torch.max(x, dim=1).values.size()) + if scheme == "sym": + q_min, q_max = -(2.0 ** (num_bits - 1)), 2.0 ** (num_bits - 1) - 1.0 + x_max = torch.max(torch.abs(x), dim=1).values + scale = x_max / (float(q_max - q_min) / 2) + else: + q_min, q_max = 0, 2.0**num_bits - 1.0 + x_max = torch.maximum(torch.max(x, dim=1).values, tmp) + x_min = torch.minimum(torch.min(x, dim=1).values, tmp) + scale = (x_max - x_min) / (2**num_bits - 1) + + scale = torch.clip(scale, min=eps) + + if scheme == "sym": + bias = 0 + else: + bias = torch.round(0 - (torch.min(x, dim=1).values) / scale) + bias = bias.unsqueeze(dim=-1) + scale = scale.unsqueeze(dim=-1) + q_x = torch.round(x / scale + bias) + q_x.clamp_(q_min, q_max) + return (q_x - bias) * scale + elif isinstance(m, torch.nn.Conv2d): + x = m.weight + x = torch.permute(x, (0, 2, 3, 1)) + x = x.reshape(-1, x.shape[-1]) + tmp = torch.zeros(torch.max(x, dim=0).values.size()) + if scheme == "sym": + q_min, q_max = -(2.0 ** (num_bits - 1)), 2.0 ** (num_bits - 1) - 1.0 + x_max = torch.max(torch.abs(x), dim=0).values + scale = x_max / (2 ** (num_bits - 1) - 1) + else: + q_min, q_max = 0, 2.0**num_bits - 1.0 + x_max = torch.maximum(torch.max(x, dim=0).values, tmp) + x_min = torch.minimum(torch.min(x, dim=0).values, tmp) + scale = (x_max - x_min) / (2**num_bits - 1) + scale = torch.clip(scale, min=eps) + if scheme == "sym": + bias = 0 + else: + bias = torch.round(0 - (torch.min(x, dim=0).values) / scale) + bias = bias.unsqueeze(dim=0) + scale = scale.unsqueeze(dim=0) + + q_x = x / scale + bias + q_x.clamp_(q_min, q_max).round_() + q_dq_x = (q_x - bias) * scale + q_dq_x = q_dq_x.view(m.weight.shape[0], m.weight.shape[2], m.weight.shape[3], m.weight.shape[1]) + q_dq_x = torch.permute(q_dq_x, (0, 3, 1, 2)) + return q_dq_x + else: + logger.warning("unsupported layer type, please have a check") + + +def quant_dequant_x_v1(x, min_x=None, max_x=None, num_bits=8): # pragma: no cover + eps = torch.finfo(torch.float32).eps + q_min, q_max = 0, 2.0**num_bits - 1.0 + if max_x is None or min_x is None: + max_x, min_x = torch.max(x), torch.min(x) + else: + max_x = torch.max(max_x) + min_x = torch.min(min_x) + scale = (max_x - min_x) / (2**num_bits - 1) + scale = torch.clip(scale, min=eps) + bias = torch.round((0 - min_x) / scale) + q_x = torch.round(x / scale + bias) + q_x.clamp_(q_min, q_max) + return scale * (q_x - bias) + + +def reshape_scale_as_weight(layer, scale): # pragma: no cover + """Reshape the scale for weight input channel, depthwise output channel + :param layer: torch module + :param scale: orig scale + :return: reshaped scale.""" + if hasattr(layer, "orig_layer"): + layer = layer.orig_layer + if isinstance(layer, torch.nn.Conv2d) and layer.groups > 1: ##only depthwise conv could hit here + scale = scale.view(scale.shape[0], 1, 1, 1) ##mount on output channel + + elif isinstance(layer, torch.nn.Conv2d): + scale = scale.view(1, scale.shape[0], 1, 1) + + elif isinstance(layer, torch.nn.Linear): + scale = 
scale.view(1, scale.shape[0]) + + return scale + + +def reshape_in_channel_to_last(layer_name, model): # pragma: no cover + """Move the input channel to the last dim + :param layer_name: Layer name + :return: The reshaped weight.""" + layer = get_module(model, layer_name) + if layer.__class__.__name__ == "WrapperLayer": + layer = layer.orig_layer + + weight = layer.weight ##TODO oc*ic, support transposed conv + if len(weight.shape) == 4: + weight = weight.permute(0, 2, 3, 1) + weight = weight.reshape(-1, weight.shape[-1]) + return weight + + +def reshape_scale_as_input(layer, scale): # pragma: no cover + """Reshape the scale for input feature in channel + :param layer: + + :param scale: + :return: + """ + if hasattr(layer, "orig_layer"): + layer = layer.orig_layer + if isinstance(layer, torch.nn.Conv2d): + scale = scale.view(1, scale.shape[0], 1, 1) + + elif isinstance(layer, torch.nn.Linear): + scale = scale.view(1, scale.shape[0]) + + return scale + + +TUNERS = {} + + +def register_autotune(name): # pragma: no cover + """Class decorator to register a smoothquant auto-tune subclass. + + :return: the class of register + """ + + def register(auto_tune): + TUNERS[name] = auto_tune + return auto_tune + + return register + + +class Calibration: # pragma: no cover + def __init__(self, model, dataloder=None, q_func=None, device="cpu"): + self.model = model + self.dataloader = dataloder + self.q_func = q_func + self.device = device + + @torch.no_grad() + def _save_input_pc_hook(self, name): + """A forward hook to save input max of a module + :param name: the module name + :return: A hook function.""" + + def save_input_hook(module, inputs, outputs): + input = inputs[0] + ##TODO check input channel is correct + if len(module.weight.shape) == 4: ##conv3d or conv1d not supported now, need better way + input = input.permute(0, 2, 3, 1) + input = input.reshape(-1, input.shape[-1]) + max_tensor = torch.max(input, dim=0)[0] + min_tensor = torch.min(input, dim=0)[0] + if name not in self.input_maxes.keys(): + self.input_mins[name], self.input_maxes[name] = min_tensor, max_tensor + else: + self.input_mins[name] = torch.min(self.input_mins[name], min_tensor) + self.input_maxes[name] = torch.max(self.input_maxes[name], max_tensor) + + return save_input_hook + + @torch.no_grad() + def _add_min_max_observer(self, modules): + """ + :param modules: the modules which the observer will insert to + :return: + """ + self.hook_handles = [] + for key in modules.keys(): + hook_func = self._save_input_pc_hook(key) + hook_handle = modules[key].register_forward_hook(hook_func) + self.hook_handles.append(hook_handle) + + @torch.no_grad() + def _remove_observer(self): + """Remove the observer from the model + :return:""" + for hook_handle in self.hook_handles: + hook_handle.remove() + + @torch.no_grad() + def _dump_min_max(self, calib_iter=100): + """Dump min max per channel information, the min max value will be saved in input_maxes attribute + :param calibration_method: only support min_max currently + :param calib_iter: Sample size for calibration + :return:""" + logger.info("Calibrating...") + if self.q_func: + self.q_func(self.model) + else: + assert self.dataloader, "Please set dataloader for calibration." 
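The calibration hooks above collect per-channel input maxima, and cal_scale() earlier turns them into the SmoothQuant scale s_j = max|X_j|^alpha / max|W_j|^(1 - alpha). A self-contained sketch of that identity on a toy Linear layer is shown below (alpha fixed at 0.5, shapes and data made up); it also performs the kind of output-equivalence check the transform depends on, since dividing the activation by s and folding s into the weight leaves the product unchanged.

```python
# Illustrative only: SmoothQuant scale folding on a toy Linear layer.
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(16, 4, bias=False)
x = torch.randn(32, 16) * torch.linspace(0.1, 10.0, 16)  # skewed input channels

stats = {}
def record_abs_max(module, inputs, output):
    inp = inputs[0].reshape(-1, inputs[0].shape[-1])
    stats["act_max"] = inp.abs().amax(dim=0)  # per-input-channel abs max

handle = linear.register_forward_hook(record_abs_max)
ref = linear(x)          # "calibration" pass records act_max and the fp32 output
handle.remove()

alpha = 0.5
w_max = linear.weight.abs().amax(dim=0).clamp(min=1e-5)  # per input channel
scale = (stats["act_max"] ** alpha / w_max ** (1 - alpha)).clamp(min=1e-5)

with torch.no_grad():
    linear.weight.mul_(scale)  # fold s into the weight along input channels
out = linear(x / scale)        # divide the activation by the same s

print(torch.allclose(ref, out, atol=1e-4))  # True: outputs are equivalent
```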
+ model_forward(self.model, self.dataloader, calib_iter, self.device) + + @torch.no_grad() + def calibrate(self, calib_iter, op_types=[torch.nn.Conv2d, torch.nn.Linear]): ##TODO transformers.conv1d + """ + :param absorb_to_layer: A dict,key is the absorb layer, val is a list of the to be smoothed layer + :param calib_iter: Data size for calibration + :return: A dict that saved the layer name and the channel-wise max value info + """ + ##hook all the module + self.input_mins = {} + self.input_maxes = {} + + hook_modules = {} + for n, module in self.model.named_modules(): + if isinstance(module, tuple(op_types)): + hook_modules[n] = module + + self._add_min_max_observer(hook_modules) + + self._dump_min_max(calib_iter=calib_iter) + self._remove_observer() + return self.input_mins, self.input_maxes + + +class GraphTrace: # pragma: no cover + """""" + + def __init__(self): + self.supported_torch_module_to_aten = { + "Linear": "aten::linear", + "Conv2d": "aten::_convolution", + "ConvTranspose2d": "aten::_convolution", + "LayerNorm": "aten::layer_norm", + "BatchNorm2d": "aten::batch_norm", + "GroupNorm": "aten::group_norm", + "InstanceNorm2d": "aten::instance_norm", + "LlamaRMSNorm": "aten::mul", + "T5LayerNorm": "aten::mul", + "LPLayerNorm": "aten::layer_norm", ##mpt_chat + } + + ##TODO potential bug, need to check only have one bug + ##TODO, must satisfy af(x)=f(ax),current skip layer may be incomplete + self.skip_ops_to_find_absorb = ["aten::to", "aten::relu", "aten::leaky_relu", "aten::hardtanh"] + + self.could_absorb_layers = [ + "aten::layer_norm", + "aten::batch_norm", + "aten::linear", + "aten::_convolution", + "aten::group_norm", + "aten::instance_norm", + "aten::mul", + ] ##TODO,support more norm + + def trace(self, model, dummy_input): + traced_model = None + optimize_numerics = False + orig_device = str(next(model.parameters()).device) + if orig_device != "cpu" and orig_device != "meta": # pragma: no cover + model = model.to("cpu") + dummy_input = move_input_to_device(dummy_input, "cpu") + if isinstance(dummy_input, dict) or isinstance(dummy_input, UserDict): + try: + traced_model = torch.jit.trace( + model, example_kwarg_inputs=dict(dummy_input), strict=False, check_trace=False + ) + traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) + except Exception as e: + logger.warning(e) + logger.warning("Jit trace in GraphTrace failed, absorb layer detection is skipped") + else: + try: + traced_model = torch.jit.trace(model, dummy_input, strict=False) + traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) + except: + try: + traced_model = torch.jit.trace(model, dummy_input[0], strict=False) + traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) + except Exception as e: + logger.warning(e) + logger.warning("Jit trace in GraphTrace failed, absorb layer detection is skipped") + model = model.to(orig_device) + return traced_model + + def get_nodes(self, traced_model, op_types=["Linear"]): + if isinstance(op_types, str): + op_types = [op_types] + nodes = [] + for node in traced_model.graph.nodes(): + node_type = node.kind() + for op_type in op_types: + if node_type == op_type: + nodes.append((node, op_type)) + break + return nodes + + def get_prev_absorb_layer(self, nodes): + prev_absorb_layer = [] + for node in nodes: + parent = get_parent(node) + while 1: + if parent.kind() in self.skip_ops_to_find_absorb: + parent = get_parent(parent) + continue + if parent.kind() in 
self.could_absorb_layers: + parent_out_kinds = [] + for val_user in list(parent.outputs())[0].uses(): + next_node = val_user.user + parent_out_kinds.append(next_node.kind()) + parent_out_kinds = set(parent_out_kinds) + parent_out_kinds.discard("aten::size") + + if parent_out_kinds == parent_out_kinds.intersection(self.could_absorb_layers): + prev_absorb_layer.append(parent) + elif parent_out_kinds.intersection(self.skip_ops_to_find_absorb): + res = self.skip_op_absorb_helper(parent) + prev_absorb_layer.append(parent) if res else prev_absorb_layer.append(None) + else: # When parent to multiple ops, sq transformation could be wrong. + prev_absorb_layer.append(None) + else: + prev_absorb_layer.append(None) + break + return prev_absorb_layer + + def skip_op_absorb_helper(self, parent_node): + for val_user in list(parent_node.outputs())[0].uses(): + next_node = val_user.user + if next_node.kind() == "aten::size": + continue + elif next_node.kind() in self.could_absorb_layers: + continue + elif next_node.kind() in self.skip_ops_to_find_absorb: + node_res = self.skip_op_absorb_helper(next_node) + if not node_res: + return False + else: + return False + return True + + def mapping_torch_module_to_aten(self, op_types): + res = [] + for op in op_types: + if op not in self.supported_torch_module_to_aten.keys(): + logger.warning(f"{op} is not supported in smooth quant, ignoring...") + continue + res.append(self.supported_torch_module_to_aten[op]) + res = list(set(res)) + return res + + def _check_valid_conv(self, module): + """Remove group conv except depthwise conv + :param module: + + :return: + """ + if not isinstance(module, torch.nn.Conv2d): + return True + if module.groups > 1: + if module.in_channels == module.out_channels and module.groups == module.in_channels: + return True + else: + return False + return True + + def get_absorb_to_layer(self, model, example_input, op_types, skip_unsupported_layers=True): + traced_model = self.trace(model, example_input) + if traced_model is None: + return None, None + + aten_op_types = self.mapping_torch_module_to_aten(op_types) + nodes_types = self.get_nodes(traced_model, aten_op_types) + nodes = [node_type[0] for node_type in nodes_types] + nodes_prev_absorb = self.get_prev_absorb_layer(nodes) + absorb_to_layer = {} + no_absorb_layers = [] + for index, absorb in enumerate(nodes_prev_absorb): + if absorb is None: + no_absorb_layers.append(".".join(nodes[index].scopeName().split("/")[-1].split(".")[1:])) + continue + node = nodes[index] + layer_name = ".".join(node.scopeName().split("/")[-1].split(".")[1:]) + absorb_name = ".".join(absorb.scopeName().split("/")[-1].split(".")[1:]) + if layer_name == "" or absorb_name == "": + continue + if absorb_name in absorb_to_layer.keys(): + absorb_to_layer[absorb_name].append(layer_name) + else: + absorb_to_layer[absorb_name] = [layer_name] + if skip_unsupported_layers: + absorb_to_layer = self.remove_unsupported_layers(model, absorb_to_layer, no_absorb_layers) + return absorb_to_layer, no_absorb_layers + + def remove_unsupported_layers(self, model, absorb_to_layer, no_absorb_layers): + res = {} + for key in absorb_to_layer.keys(): + absorb_layer = get_module(model, key) + layer_type = absorb_layer.__class__.__name__ + if layer_type not in self.supported_torch_module_to_aten.keys(): + no_absorb_layers.extend(absorb_to_layer[key]) + continue + supported = True + for layer_name in absorb_to_layer[key]: + layer = get_module(model, layer_name) + layer_type = layer.__class__.__name__ + if (layer_type not in 
self.supported_torch_module_to_aten.keys()) or not self._check_valid_conv(layer): + supported = False + no_absorb_layers.extend(absorb_to_layer[key]) + break + if supported: + res[key] = absorb_to_layer[key] + return res + + +@register_autotune("version1") +class AutoAlpha: + def __init__( + self, + model, + dataloader, + absorb_to_layer, + op_types, + device, + q_func, + example_inputs, + weight_clip=True, + alpha_min=0.3, + alpha_max=0.7, + alpha_step=0.1, + shared_criterion="mean", + init_alpha=0.5, + folding=False, + do_blockwise=False, + n_samples=32, + ): + """Initialize the AutoAlpha tuner with necessary parameters and components.""" + + self.model = model.to("cpu") + self.model.eval() + self.dataloader = dataloader + self.alpha_min = alpha_min + self.alpha_max = alpha_max + self.alpha_step = alpha_step + self.shared_criterion = shared_criterion + self.init_alpha = init_alpha + self.loss_type = "blockwise" if do_blockwise else "model_wise" + self.calib_sample_num = n_samples if n_samples else 32 + self.op_types = op_types + self.absorb_to_layer = absorb_to_layer + self.weight_scale_dict = {} + self.q_func = q_func + self.folding = folding + self.example_inputs = example_inputs + self.max_value_info = {} # to record max values for alpha tune + self.weight_clip = weight_clip[0] if isinstance(weight_clip, tuple) else weight_clip + self.input_maxes = {} + self.input_mins = {} + self.input_maxes_abs = {} + self.device = device + + def tune(self): + """The main entry of auto_alpha + :return: Optimal alpha values and scales based on user-defined recipes.""" + calib = Calibration(self.model, self.dataloader, self.q_func, self.device) + calib_iter = 100 + self.input_mins, self.input_maxes = calib.calibrate(calib_iter, self.op_types) + for key in self.input_mins.keys(): + self.input_maxes_abs[key] = torch.max(torch.abs(self.input_mins[key]), torch.abs(self.input_maxes[key])) + + if not self.folding: + diff_modules = set(self.absorb_to_layer.keys()).difference(self.input_mins.keys()) + for d in diff_modules: + del self.absorb_to_layer[d] + + scale_memo_use = 0 + for key in self.absorb_to_layer: + layer_name = self.absorb_to_layer[key][0] + input_max = self.input_maxes_abs[layer_name] + scale_memo_use += 4 * input_max.shape[0] * len(self.absorb_to_layer[key]) + alpha_space_len = (self.alpha_max - self.alpha_min) / self.alpha_step + 1 + scale_memo_use *= alpha_space_len + self._save_scale = enough_memo_store_scale(self.device, scale_memo_use) + + if self.loss_type == "blockwise": + self.block_names = self.get_blocks() + logger.info("Blockwise auto-tuning will be performed") + module_names = self._get_sq_layer_names() + block_names, self.block_to_module = self.block_names, {} + for block in block_names: + self.block_to_module[block] = [] + for module in module_names: + checked = False + for block in block_names: + if block + "." 
in module: + self.block_to_module[block].append(module) + checked = True + if not checked: + self.block_to_module[module] = [module] + self.block_names = list(self.block_to_module.keys()) + logger.info(f"Blockwise auto-tuning: {len(self.block_names)} blocks found") + logger.debug(f"Blockwise auto-tuning blocks info: {self.block_to_module}") + return self._auto_tune_alpha_blockwise() + else: + return self._auto_tune_alpha() + + def get_blocks(self): + """Obtain a list of blocks in block-wise tuning mode.""" + block_names = [] + for n, m in self.model.named_modules(): + if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__: + for nn, mm in m.named_children(): + block_name = n + "." + nn + block_names.append(block_name) + break + return block_names + + def _add_blockwise_observer(self, block_modules): + """ + :param block_modules: the block modules which the observer will insert to + :return: + """ + self.blockwise_hook_handles = [] + for key in block_modules.keys(): + hook_func = self._save_blockwise_hook(key) + hook_handle = block_modules[key].register_forward_hook(hook_func) + self.blockwise_hook_handles.append(hook_handle) + + def _save_blockwise_hook(self, name): + """A forward hook to save inputs/outputs of a block + :param name: the block name + :return: A hook function.""" + + def save_blockwise_hook(module, inputs, outputs): + self.block_inputs[name] = inputs[0] + self.block_outputs[name] = outputs[0] + + return save_blockwise_hook + + def _get_all_hook_module_names(self): + """Obtain all the modules that could be hooked based on given op_types.""" + module_names = [] + for n, module in self.model.named_modules(): + if isinstance(module, tuple(self.op_types)): + module_names.append(n) + return module_names + + def _update_scales_for_auto(self, absorb_scales, weight_scales): + """Apply activation and weight scales to the model.""" + for key in self.absorb_to_layer.keys(): + layer_names = self.absorb_to_layer[key] + for layer_name in layer_names: + layer = get_module(self.model, layer_name) + input_scale = absorb_scales[key] + weight_scale = weight_scales[layer_name] + input_scale = reshape_scale_as_input(layer, input_scale) + weight_scale = reshape_scale_as_weight(layer, weight_scale) + layer.update_scale(input_scale, weight_scale) ##FIXME + + def _change_qdq_for_auto(self, enable=True): + """Change the option for qdq.""" + module_names = self._get_all_hook_module_names() + for name in module_names: + name = name.split(".orig_layer")[0] + module = get_module(self.model, name) + if not hasattr(module, "orig_layer"): # skip module if it's not used in calibration + continue + if enable: + module.enable_quant() + else: + module.disable_quant() + + def _qdq_model_wrapper_for_auto(self, save_q_input=False): + """Wrapper all the module with qdq + :return:""" + module_names = self._get_all_hook_module_names() + self.to_unwrap_module_names = module_names + for name in module_names: + if name not in self.input_mins: # skip module if it's not used in calibration + continue + module = get_module(self.model, name) + new_module = WrapperLayer(module, self.input_mins[name], self.input_maxes[name], save_q_input=save_q_input) + set_module(self.model, name, new_module) + + def _qdq_model_unwrapper_for_auto(self): + """Unwrapper all the module with qdq + :return:""" + module_names = self.to_unwrap_module_names + for name in module_names: + module = get_module(self.model, name) + if not hasattr(module, "orig_layer"): # skip module if it's not used in calibration + continue + 
set_module(self.model, name, module.orig_layer) + + def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5): + """Cal the adjust scales + :param absorb_to_layer: A dict mapping absorb layer to smooth quantized layer + :param input_maxes: The channel-wise input max info for layers + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict + :return:""" + absorb_to_input_maxes = {} + for key in absorb_to_layer.keys(): + layer_name = absorb_to_layer[key][0] + absorb_to_input_maxes[key] = input_maxes[layer_name] + + weight_scales_info = {} + absorb_scales_info = {} + for index, key in enumerate(absorb_to_layer.keys()): + alpha_tmp = alpha[key] if isinstance(alpha, dict) else alpha + if alpha_tmp < 0: + scale = torch.ones((1), device=self.device) + else: + input_max = absorb_to_input_maxes[key] + layer_names = absorb_to_layer[key] + weights = [] + for layer_name in layer_names: + weight = reshape_in_channel_to_last(layer_name, self.model) + weights.append(weight) + + weight_max_per_channel = torch.max(torch.abs(torch.cat(weights, dim=0)), dim=0)[0] + if self.weight_clip: + weight_max_per_channel = weight_max_per_channel.clamp(min=1e-5) + + if self._save_scale: + if key in self.weight_scale_dict and alpha_tmp in self.weight_scale_dict[key]: + scale = self.weight_scale_dict[key][alpha_tmp] + else: + scale = cal_scale(input_max, weights, alpha_tmp) + else: + scale = cal_scale(input_max, weights, alpha_tmp) + + absorb_scales_info[key] = 1.0 / scale + absorb_scales_info[key][scale == 0] = 0 + layer_names = absorb_to_layer[key] + for layer_name in layer_names: + ##self._scale_layer_weight(layer_name, scale) + weight_scales_info[layer_name] = scale + if self._save_scale: + if layer_name not in self.weight_scale_dict: + self.weight_scale_dict[layer_name] = {} + self.weight_scale_dict[layer_name][alpha_tmp] = scale + return absorb_scales_info, weight_scales_info + + def _get_auto_loss(self, output, output_q, loss_type="abs", loss_alpha=1.0): + """Get the loss for auto tuning + :param output: Fp32 output for one layer + :param output_q: Quant output for one layer + :param loss_type: The type of loss + :param loss_alpha: Loss alpha i for mean scale error + :return: A tensor of the loss.""" + if len(output.shape) <= 2: + max_value = torch.max(torch.abs(output)) + else: + output = output.reshape(output.shape[0], -1) + output_q = output_q.reshape(output_q.shape[0], -1) + max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1) + max_value = torch.clip(max_value, 1e-5) + output = output / max_value ##FIXME need copy not replace + output_q = output_q / max_value + if loss_type == "abs": + return torch.sum(torch.pow(torch.abs(output - output_q), 0.5)) + else: + return torch.sum((output - output_q) ** 2) + + def _get_sq_layer_names(self): + """Get all the layers that could be smooth quanted + :return: All the sq layer names.""" + ##TODO this may not fit for folding=False + module_names = [] + for key in self.absorb_to_layer: + module_names += self.absorb_to_layer[key] + return module_names + + def _get_best_alpha(self, absorb_to_layer, loss_alphas, shared_criterion): + """Obtain the optimal alpha values based on shared criterion and loss values recorded in auto-tuning step. + + :return: A dict of layerwise alpha values. 
+ """ + + def dict_to_list(dic): + res = [] + for key in dic.keys(): + res.append((key, dic[key])) + return res + + best_alpha = {} + for ln_name in absorb_to_layer.keys(): + layer_names = absorb_to_layer[ln_name] + cur_shared_criterion = shared_criterion + if len(layer_names) == 1: + cur_shared_criterion = "min" + if cur_shared_criterion == "mean": + loss_tmp = {} + for alpha in loss_alphas[layer_names[0]].keys(): + if alpha not in loss_tmp.keys(): + loss_tmp[alpha] = 0 + for layer_name in layer_names: + loss_tmp[alpha] += loss_alphas[layer_name][alpha] + res = dict_to_list(loss_tmp) + res.sort(key=lambda x: x[1]) + + best_alpha[ln_name] = float(res[0][0]) + + elif cur_shared_criterion == "min" or cur_shared_criterion == "max": + tmp_best_alpha = [] + for layer_name in layer_names: + res = dict_to_list(loss_alphas[layer_name]) + res.sort(key=lambda x: x[1]) + tmp_best_alpha.append(float(res[0][0])) + if cur_shared_criterion == "min": + best_alpha[ln_name] = min(tmp_best_alpha) + else: + best_alpha[ln_name] = max(tmp_best_alpha) + + else: + raise NotImplementedError + return best_alpha + + def _get_one_batch_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes): + """Calculate the losses for all alpha values given an input. + + :return: A dict of op-wise loss values with respect to alpha values. + """ + self._change_qdq_for_auto(enable=False) + module_names = self._get_sq_layer_names() + forward_wrapper(self.model, input, self.device) ##disable quant and get fp32 output + + fp32_output = {} + for name in module_names: + module = get_module(self.model, name) + fp32_output[name] = module.output + module.output = None + self._change_qdq_for_auto(enable=True) + absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, orig_best_alpha) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + forward_wrapper(self.model, input, self.device) ##save quant_input + for mod_name in module_names: # save fp32 values + mod = get_module(self.model, mod_name) + if mod_name in self.fp32_output_val: + self.fp32_output_val[mod_name].append(torch.norm(mod.output)) + else: + self.fp32_output_val[mod_name] = [torch.norm(mod.output)] + del mod + + loss_alphas = {} + for name in module_names: + module = get_module(self.model, name) + loss = self._get_auto_loss(fp32_output[name], module.output) + cur_alpha = orig_best_alpha + if isinstance(orig_best_alpha, dict): + cur_alpha = orig_best_alpha[name] + key_name = str(cur_alpha) + loss_alphas[name] = {key_name: loss} + # for name in module_names: + # loss_alphas[name]={} + for alpha in alpha_space: + absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, alpha) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + for name in module_names: + losses = loss_alphas[name] + if str(alpha) in losses.keys(): + continue + module = get_module(self.model, name) + output = module.q_dq_forward(module.q_input, module.input_scale, module.weight_scale) + loss = self._get_auto_loss(fp32_output[name], output) + loss_alphas[name][str(alpha)] = loss + return loss_alphas + + def _get_one_batch_auto_loss_blockwise(self, input, alpha_space, orig_best_alpha, input_maxes): + """Calculate the losses for all alpha values given an input in blockwise tuning mode. + + :return: A dict of blockwise-wise loss values with respect to alpha values. 
+ """ + self._change_qdq_for_auto(enable=False) + module_names = self._get_sq_layer_names() + + block_modules = {} + for key in self.block_names: + block_modules[key] = get_module(self.model, key) + self._add_blockwise_observer(block_modules) + + forward_wrapper(self.model, input, self.device) ##disable quant and get fp32 output + + fp32_output = {} + for block_name in self.block_names: + fp32_output[block_name] = self.block_outputs[block_name] + self._change_qdq_for_auto(enable=True) + absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, orig_best_alpha) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + forward_wrapper(self.model, input, self.device) ##save quant_input + for mod_name in module_names: # save fp32 values + mod = get_module(self.model, mod_name) + if mod_name in self.fp32_output_val: + self.fp32_output_val[mod_name].append(torch.norm(mod.output)) + else: + self.fp32_output_val[mod_name] = [torch.norm(mod.output)] + del mod + + loss_alphas = {} + + for block_name in self.block_names: + block = get_module(self.model, block_name) + loss = self._get_auto_loss(fp32_output[block_name], self.block_outputs[block_name]) + cur_alpha = orig_best_alpha + if isinstance(orig_best_alpha, dict): + cur_alpha = orig_best_alpha[self.block_to_module[block_name][0]] + key_name = str(cur_alpha) + loss_alphas[block_name] = {key_name: loss} + # for name in module_names: + # loss_alphas[name]={} + for alpha in alpha_space: + absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, alpha) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + + for block_name in self.block_names: + losses = loss_alphas[block_name] + if str(alpha) in losses.keys(): + continue + block = get_module(self.model, block_name) + block_copy = copy.deepcopy(block) + for name in self.block_to_module[block_name]: + if name == block_name and len(self.block_to_module[block_name]) == 1: + module, module_copy = block, block_copy + else: + module = get_module(block, name) + module_copy = copy.deepcopy(module) + if module.weight_scale is not None: + module_copy.orig_layer.weight *= module.weight_scale + q_dq_weight = quant_dequant_w_v1(module_copy.orig_layer) + module_copy.orig_layer.weight.data.copy_(q_dq_weight) + module_copy.do_blockwise = True + if not (name == block_name and len(self.block_to_module[block_name]) == 1): + set_module(block_copy, name, module_copy) + try: + output = block_copy(self.block_inputs[block_name])[0] + except: # Llama model decoder_layer forward requires position_id + position_ids = torch.arange(self.block_inputs[block_name].size()[1]) + position_ids = position_ids.view(self.block_inputs[block_name].size()[0], -1) + output = block_copy(self.block_inputs[block_name], position_ids=position_ids)[0] + loss = self._get_auto_loss(fp32_output[block_name], output) + loss_alphas[block_name][str(alpha)] = loss + del block_copy # release memory + return loss_alphas + + def opwise_rank(self, loss_alphas, best_alphas): + """Rank the final losses of ops based on their ratio with respect to op output norm. 
+ + :return: + """ + max_op, max_ratio, max_key = "", 0, "" + ratio_info = {} + for key in self.absorb_to_layer: + for op_name in self.absorb_to_layer[key]: + fp32_norm, loss_ = ( + torch.sum(torch.stack(self.fp32_output_val[op_name])), + loss_alphas[op_name][str(best_alphas[key])], + ) + ratio = loss_ / fp32_norm + max_op = op_name if ratio > max_ratio else max_op + max_key = key if ratio > max_ratio else max_key + max_ratio = max(ratio, max_ratio) + ratio_info[op_name] = ratio + logger.debug( + f"final loss: {op_name}: {loss_}; @alpha {best_alphas[key]}; \ + fp32_output norm: {fp32_norm}; ratio: {ratio}" + ) + import operator + + ratio_info = dict(sorted(ratio_info.items(), key=operator.itemgetter(1), reverse=True)) + for key in list(ratio_info.keys()): + logger.debug(f"sorted opname-ratio: {key}: {ratio_info[key]}") + if max_op != "": + logger.debug( + f"max loss: {max_op}: {loss_alphas[max_op][str(best_alphas[max_key])]} @alpha {best_alphas[max_key]}\ + fp32_output norm: {torch.sum(torch.stack(self.fp32_output_val[max_op]))}; ratio: {max_ratio}" + ) + return None + + def default_tune_setup(self): + """Setup default auto-tune settings. + + :return: A dict of op-wise loss values with respect to alpha values. + """ + round_num = max( # Initialize the alpha search space + len(str(self.alpha_min).split(".")[1]), + len(str(self.alpha_max).split(".")[1]), + len(str(self.alpha_step).split(".")[1]), + ) + self.alpha_space = numpy.round( + numpy.arange(self.alpha_min, self.alpha_max + self.alpha_step, self.alpha_step), round_num + ).tolist() + ##wrapper new module + self._qdq_model_wrapper_for_auto(save_q_input=True) + + absorb_input_scales, weight_scales = self._cal_scales( + self.absorb_to_layer, self.input_maxes_abs, self.init_alpha + ) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + return absorb_input_scales, weight_scales + + def _auto_tune_alpha(self): + """Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly.""" + logger.info("Start alpha tuning") + + absorb_input_scales, weight_scales = self.default_tune_setup() + + total_cnt, tmp_cnt = 0, 0 + alpha_update_iter, tune_cnt = 0, 4 + # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha + multiply_factor = ( + self.calib_sample_num // tune_cnt if self.calib_sample_num >= tune_cnt else self.calib_sample_num + ) + self.fp32_output_val = {} + best_alphas = self.init_alpha + + if not self.dataloader: + logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.") + self._qdq_model_unwrapper_for_auto() + return best_alphas + bar = tqdm(self.dataloader, total=self.calib_sample_num, desc="auto tune alpha") # pylint: disable=E1102 + for input in bar: + if isinstance(input, tuple) or isinstance(input, list): + if len(input) == 2: + input, _ = input # Extract input when both input and label are yielded by dataloader. 
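# --- Editor's illustrative sketch (not part of the patch) ---
# default_tune_setup() above derives the alpha search space from alpha_min/alpha_max/alpha_step.
# Reproduced standalone with the default auto_alpha_args values:
import numpy

alpha_min, alpha_max, alpha_step = 0.0, 1.0, 0.1
round_num = max(len(str(v).split(".")[1]) for v in (alpha_min, alpha_max, alpha_step))
alpha_space = numpy.round(numpy.arange(alpha_min, alpha_max + alpha_step, alpha_step), round_num).tolist()
# alpha_space is roughly [0.0, 0.1, ..., 1.0]; each candidate alpha is scored per calibration
# batch and the accumulated per-op losses feed _get_best_alpha() every calib_sample_num // 4 samples.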
+ loss_alphas = {} + best_alphas_per_module = best_alphas + if isinstance(best_alphas, dict): + for key in self.absorb_to_layer.keys(): + layer_names = self.absorb_to_layer[key] + for layer_name in layer_names: + best_alphas_per_module[layer_name] = best_alphas_per_module[key] + loss_tmp = self._get_one_batch_auto_loss( + input, self.alpha_space, best_alphas_per_module, self.input_maxes_abs + ) + if loss_alphas == {}: + loss_alphas = loss_tmp + else: + for key in loss_alphas.keys(): + cur_loss = loss_alphas[key] + for alpha_key in cur_loss.keys(): + cur_loss[alpha_key] += loss_tmp[key][alpha_key] + total_cnt += self.dataloader.batch_size + tmp_cnt += self.dataloader.batch_size + if tmp_cnt // multiply_factor >= 1: + alpha_update_iter += 1 + tmp_cnt = 0 + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, self.shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}") + absorb_input_scales, weight_scales = self._cal_scales( + self.absorb_to_layer, self.input_maxes_abs, best_alphas + ) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + # does not need to reset the weight_scale_dict, because use the weight of ori_layer, no change + # self.weight_scale_dict = {} + if total_cnt >= self.calib_sample_num: + break + + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, self.shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Final alpha {key}:{best_alphas[key]}") + + self.opwise_rank(loss_alphas, best_alphas) + self._qdq_model_unwrapper_for_auto() + logger.info("auto tuning done") + + return best_alphas + + def _auto_tune_alpha_blockwise(self): + """Perform blockwise-alpha-tuning to obtain optimal alpha values and adjust parameters accordingly.""" + logger.info("Start block-wise alpha tuning") + self.block_inputs, self.block_outputs = {}, {} + + absorb_input_scales, weight_scales = self.default_tune_setup() + + total_cnt, tmp_cnt = 0, 0 + alpha_update_iter, tune_cnt = 0, 4 + # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha + multiply_factor = ( + self.calib_sample_num // tune_cnt if self.calib_sample_num >= tune_cnt else self.calib_sample_num + ) + self.fp32_output_val = {} + best_alphas = self.init_alpha + + if not self.dataloader: + logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.") + self._qdq_model_unwrapper_for_auto() + return best_alphas + bar = tqdm(self.dataloader, total=self.calib_sample_num, desc="auto tune alpha") # pylint: disable=E1102 + for input in bar: + if isinstance(input, tuple): # Extract input when both input and label are yielded by dataloader. 
+ input = input[0] + loss_alphas = {} + best_alphas_per_module = best_alphas + if isinstance(best_alphas, dict): + for key in self.absorb_to_layer.keys(): + layer_names = self.absorb_to_layer[key] + for layer_name in layer_names: + best_alphas_per_module[layer_name] = best_alphas_per_module[key] + loss_tmp = self._get_one_batch_auto_loss_blockwise( + input, self.alpha_space, best_alphas_per_module, self.input_maxes_abs + ) + if loss_alphas == {}: + for block_name in self.block_names: + for key in self.block_to_module[block_name]: + loss_alphas[key] = loss_tmp[block_name] + else: + for block_name in self.block_names: + for key in self.block_to_module[block_name]: + cur_loss = loss_alphas[key] + for alpha_key in cur_loss.keys(): + cur_loss[alpha_key] += loss_tmp[block_name][alpha_key] + + total_cnt += self.dataloader.batch_size + tmp_cnt += self.dataloader.batch_size + if tmp_cnt // multiply_factor >= 1: + alpha_update_iter += 1 + tmp_cnt = 0 + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, self.shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}") + absorb_input_scales, weight_scales = self._cal_scales( + self.absorb_to_layer, self.input_maxes_abs, best_alphas + ) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + # does not need to reset the weight_scale_dict, because use the weight of ori_layer, no change + # self.weight_scale_dict = {} + if total_cnt >= self.calib_sample_num: + break + + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, self.shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Final alpha {key}:{best_alphas[key]}") + + self.opwise_rank(loss_alphas, best_alphas) + self._qdq_model_unwrapper_for_auto() + logger.info("block-wise auto tuning done") + + return best_alphas + + +class TorchSmoothQuant: + """Fake input channel quantization, for more details please refer to + [1] SmoothQuant: Accurate and Efficient + Post-Training Quantization for Large Language Models + [2] SPIQ: Data-Free Per-Channel Static Input Quantization + Currently, we only handle the layers whose smooth scale could be absorbed, we will support other layers later. + + We only support inplace mode which means the model weights will be changed, you can call recover function + to recover the weights if needed + """ + + def __init__( + self, + model, + dataloader=None, + example_inputs=None, + q_func=None, + traced_model=None, + scale_sharing=True, + record_max_info=False, + ): + """ + :param model: Torch model :param dataloader: Calibration dataloader :param traced_model: A specific model + shares the same architecture as the model and could be traced by torch.jit. If not supplied, we use model + instead. 
+ """ + self.model = model + if not isinstance(self.model, torch.nn.Module): + return + device, dtype = self._get_device() + self.model = self.model.to(device) + self.model.eval() + self.device = device + self.dtype = dtype + self.dataloader = dataloader + self.example_inputs = example_inputs + self.q_func = q_func + self.input_maxes = {} + self.input_mins = {} + self.input_maxes_abs = {} + self.traced_model = traced_model + if self.traced_model is None: + self.traced_model = self.model + self.weight_scale_info = {} + self.absorb_scales_info = {} + self.scale_sharing = scale_sharing + self.insert_mul = False + self.allow_absorb = True + self.record_max_info = record_max_info + self.max_value_info = {} # to record max values for alpha tune + self.absorb_to_layer = {} + self.weight_max_lb = 1e-5 ##weight max low bound + self.weight_scale_dict = {} + self.sq_scale_info = {} + self.max_value_info = {} + self.need_calibration = False + + def _get_device(self): + """Get the model device + :return:Model device.""" + for _, p in self.model.named_parameters(): + return p.data.device, p.data.dtype + + def _scale_layer_weight(self, layer_name, scale, alpha=0.5, input_minmax=None): ##input channel + """Scale the layer weights at input channel, depthwise conv output channel + :param layer_name: The layer name + :param scale: The scale to be multiplied + :param alpha: alpha for SQLinearWrapper + :param input_minmax: input_minmax for SQLinearWrapper + :return:""" + layer = get_module(self.model, layer_name) + if self.insert_mul: + layer = get_module(self.model, layer_name) + if isinstance(layer, SQLinearWrapper): + layer._recover_sq_linear() + set_module(self.model, layer_name, layer.sq_linear) ##recover + else: + new_module = SQLinearWrapper(layer, 1.0 / scale, input_minmax, alpha) + set_module(self.model, layer_name, new_module) + elif self.allow_absorb: + scale = reshape_scale_as_weight(layer, scale) + layer.weight = torch.nn.Parameter(layer.weight * scale) + return scale + + def _absorb_scales(self, layer_name, scale): ##output channel + """Absorb the scale to the layer at output channel + :param layer_name: The module name + :param scale: The scale to be absorbed + :param alpha_key: The alpha passed to SQLinearWrapper + :return:""" + if self.insert_mul or not self.allow_absorb: + return # absorb is updated in SQLinearWrapper in def _scale_layer_weight + + ##if self.allow absorb + layer = get_module(self.model, layer_name) + if layer.__class__.__name__ == "WrapperLayer": + layer = layer.orig_layer + if ( + isinstance(layer, torch.nn.BatchNorm2d) + or isinstance(layer, torch.nn.GroupNorm) + or isinstance(layer, torch.nn.InstanceNorm2d) + ): + if layer.affine: + layer.weight *= scale + layer.bias *= scale + else: + layer.affine = True + weight = torch.ones(layer.num_features, device=self.device, dtype=self.dtype) * scale + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + bias = torch.zeros(layer.num_features, device=self.device, dtype=self.dtype) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + elif isinstance(layer, torch.nn.LayerNorm): + if layer.elementwise_affine: + layer.weight *= scale + layer.bias *= scale + else: + layer.elementwise_affine = True + weight = torch.ones(layer.num_features, device=self.device, dtype=self.dtype) * scale + layer.weight = torch.nn.Parameter(torch.ones(weight, requires_grad=False)) + bias = torch.zeros(layer.num_features, device=self.device, dtype=self.dtype) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + + elif 
isinstance(layer, torch.nn.Conv2d): + ##the order could not be changed + if hasattr(layer, "bias") and (layer.bias is not None): + layer.bias *= scale + scale = scale.view(scale.shape[0], 1, 1, 1) + layer.weight *= scale + + elif isinstance(layer, torch.nn.Linear): + if hasattr(layer, "bias") and (layer.bias is not None): + layer.bias *= scale + scale = scale.view(scale.shape[0], 1) + layer.weight *= scale + + elif layer.__class__.__name__ == "LlamaRMSNorm" or layer.__class__.__name__ == "T5LayerNorm": ##quite tricky + layer.weight *= scale + + else: + logger.warning( + f"found unsupported layer {type(layer)}, try to multiply scale to " + f"weight and bias directly, this may introduce accuracy issue, please have a check " + ) + if hasattr(layer, "weight") and layer.weight is not None: + layer.weight *= scale + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias *= scale + + def _export_sq_info(self, absorb_to_layer, input_maxes, alpha=0.5): + absorb_to_input_maxes = {} + for key in absorb_to_layer.keys(): + layer_name = absorb_to_layer[key][0] + absorb_to_input_maxes[key] = input_maxes[layer_name] + + for index, key in enumerate(absorb_to_layer.keys()): + alpha_tmp = alpha[key] if isinstance(alpha, dict) else alpha + layer_names = absorb_to_layer[key] + weights = [] + for layer_name in layer_names: + weight = reshape_in_channel_to_last(layer_name, self.model) + weights.append(weight) + weight_max_per_channel = torch.max(torch.abs(torch.cat(weights, dim=0)), dim=0)[0] + + weight_max_per_channel = weight_max_per_channel.clamp(min=self.weight_max_lb) + + input_max = absorb_to_input_maxes[key] + layer_names = absorb_to_layer[key] + # weight_scale = cal_scale(input_max, weights, alpha_tmp) + input_minmax = [self.input_mins[layer_names[0]].to("cpu"), self.input_maxes[layer_names[0]].to("cpu")] + abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1])) + input_power = torch.pow(abs_input_max, alpha_tmp) + weight_power = torch.pow(weight_max_per_channel, 1 - alpha_tmp) + weight_scale = torch.clip(input_power / weight_power, min=1e-5) + + input_scale = 1.0 / weight_scale + + self.max_value_info[key] = { + "alpha": alpha_tmp, + "input_minmax": input_minmax, + "weight_max": weight_max_per_channel, + "absorbed_layer": layer_names, + } # max_value_info is used for pytorch backend and sq_scale_info is used for ipex backend. + # the input of layers with same absorb layer is the same. 
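# --- Editor's illustrative sketch (not part of the patch) ---
# The per-channel scale computed in _export_sq_info above, shown on toy tensors:
# weight_scale_j = clip(|x_j|_max**alpha / |w_j|_max**(1 - alpha), 1e-5) and input_scale_j = 1 / weight_scale_j.
import torch

alpha = 0.5
abs_input_max = torch.tensor([10.0, 0.5, 2.0])          # per-input-channel activation abs-max (made up)
weight_max_per_channel = torch.tensor([0.2, 1.0, 0.8])  # per-input-channel weight abs-max (made up)
weight_scale = torch.clip(torch.pow(abs_input_max, alpha) / torch.pow(weight_max_per_channel, 1 - alpha), min=1e-5)
input_scale = 1.0 / weight_scale
# Channels with large activations and small weights get a large weight_scale, shifting
# quantization difficulty away from activations and onto weights.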
+ for op_name in layer_names: + module = copy.deepcopy(get_module(self.model, op_name)) + new_module = SQLinearWrapper(module, 1.0 / weight_scale, input_minmax, alpha_tmp) + self.sq_scale_info[op_name] = {} + self.sq_scale_info[op_name] = { + "alpha": alpha_tmp, + "input_scale_for_mul": input_scale.to("cpu"), + "input_scale_after_mul": new_module.scale, + "input_zero_point_after_mul": new_module.zero_point, + "input_dtype": new_module.dtype, + "weight_scale_after_mul": new_module._get_weight_scale(), + } + + def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5): + """Cal the adjust scales + :param absorb_to_layer: A dict mapping absorb layer to smooth quantized layer + :param input_maxes: The channel-wise input max info for layers + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict + :return:""" + absorb_to_input_maxes = {} + for key in absorb_to_layer.keys(): + layer_name = absorb_to_layer[key][0] + absorb_to_input_maxes[key] = input_maxes[layer_name] + + weight_scales_info = {} + absorb_scales_info = {} + for index, key in enumerate(absorb_to_layer.keys()): + alpha_tmp = alpha[key] if isinstance(alpha, dict) else alpha + + input_max = absorb_to_input_maxes[key] + layer_names = absorb_to_layer[key] + weights = [] + for layer_name in layer_names: + weight = reshape_in_channel_to_last(layer_name, self.model) + weights.append(weight) + scale = cal_scale(input_max, weights, alpha_tmp) + absorb_scales_info[key] = 1.0 / scale + absorb_scales_info[key][scale == 0] = 0 + layer_names = absorb_to_layer[key] + for layer_name in layer_names: + ##self._scale_layer_weight(layer_name, scale) + weight_scales_info[layer_name] = scale + return absorb_scales_info, weight_scales_info + + def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5): + """Adjust the weights and biases + :param absorb_to_layer: A dict mapping absorb layer to smooth quantized layer + :param input_maxes: The channel-wise input max info for layers + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict + :return:""" + absorb_scales_info, weight_scales_info = self._cal_scales(absorb_to_layer, input_maxes, alpha) + if not absorb_scales_info or not weight_scales_info: + return weight_scales_info, absorb_scales_info + for index, key in enumerate(absorb_to_layer.keys()): + if isinstance(alpha, float): + alpha_tmp = alpha + elif isinstance(alpha, dict): + alpha_tmp = alpha[key] + absorb_scale = absorb_scales_info[key] + self._absorb_scales(key, absorb_scale) + layer_names = absorb_to_layer[key] + for layer_name in layer_names: + input_minmax = [self.input_mins[layer_names[0]], self.input_maxes[layer_names[0]]] + self._scale_layer_weight(layer_name, weight_scales_info[layer_name], alpha_tmp, input_minmax) + return weight_scales_info, absorb_scales_info + + def _check_need_calibration(self, alpha, percentile, op_types, scales_per_op, calib_iter): + """ + check need calibration or not + :param alpha: current alpha + :param percentile: current percentile + :param op_types: current op_types + :param scales_per_op: current scales_per_op + :param calib_iter:: current scales_per_op + :return: + """ + need_calib = True + from peft import PeftModel # pylint: disable=E0401 + + is_peft, is_auto = isinstance(self.model, PeftModel), alpha == "auto" + if len(self.input_maxes) == 0: ## the first time + need_calib = True + self.alpha = alpha + self.percentile = percentile + self.op_types = op_types + self.scales_per_op = 
scales_per_op + self.calib_iter = calib_iter + return False if (is_auto and not is_peft) else need_calib + + if ( + self.percentile == percentile + and self.op_types == op_types + and self.scales_per_op == scales_per_op + and self.calib_iter == calib_iter + ): + if isinstance(alpha, float) or self.alpha == "auto": + need_calib = False + + self.alpha, self.percentile, self.calib_iter = alpha, percentile, calib_iter + self.op_types, self.scales_per_op = op_types, scales_per_op + return need_calib + + @torch.no_grad() + def _parse_absorb_to_layers(self, op_types, folding): + str_op_types = [i.__name__ for i in op_types] + self_absorb_layers = {} + if self.insert_mul: + self_absorb_layers = self._get_all_layer_names(op_types) # TODO: only support linear now. + # fetch modules with the same input + group_modules = self._trace(str_op_types, skip_unsupported_layers=False) + if group_modules is not None: + # use one input for qkv + for k, v in group_modules.items(): + for i in v: + if i in self_absorb_layers: + self_absorb_layers.pop(i) + self_absorb_layers[v[0]] = v + logger.debug(f"self_absorb_layers:{self_absorb_layers}") + if self.allow_absorb: + self.absorb_to_layer, no_absorb_layers = self._trace(str_op_types) + if self.absorb_to_layer is None and no_absorb_layers is None: + return None + + # remove self.self_absorb_layers if it exists in self.absorb_to_layer + for k, v in self.absorb_to_layer.items(): + for i in v: + if i in self_absorb_layers: + self_absorb_layers.pop(i) + self.absorb_to_layer.update(self_absorb_layers) + + if self.absorb_to_layer is None and no_absorb_layers is None: + logger.warning( + "sorry, could not trace the model, smooth quant is ignored." + "If you are using huggingface model," + "you could set torchscript to True " + ) + return None + + # Check if input_maxes match self.absorb_to_layer + # (due to self._get_all_layer_names use layer tree instead of forward_path) + if not folding and self.need_calibration: + if len(self.input_mins) == 0: ##there are some modules not used in forward + calib = Calibration(self.model, self.dataloader, self.q_func, self.device) ## + input_mins, input_maxes = calib.calibrate( + 1, op_types + ) ##TODO if using qfunc for calibration, it will calibrate twice + # use qfunc to calibrate, the input min could be used for fixed alpha transformation + self.input_mins = input_mins + self.input_maxes = input_maxes + diff_modules = set(self.absorb_to_layer.keys()).difference(input_mins.keys()) + for d in diff_modules: + del self.absorb_to_layer[d] + return self.absorb_to_layer + + @torch.no_grad() + def transform( + self, + alpha=0.5, + folding=False, + percentile=100, + op_types=[torch.nn.Linear, torch.nn.Conv2d], + scales_per_op=False, + calib_iter=100, + weight_clip=True, + scale_sharing=True, + auto_alpha_args={ + "init_alpha": 0.5, + "alpha_min": 0.0, + "alpha_max": 1.0, + "alpha_step": 0.1, + "shared_criterion": "mean", + "n_samples": 32, ##512 for cuda, 128 for cpu? + }, + ): + """The main entry of smooth quant + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, please refer + to the paper for more details + :param folding: whether insert mul(False) or just allow foldable layers(True) for SmoothQuant + :param percentile: Not supported now + :param op_types: The op typed to be smooth quantized + :param scales_per_op: Not supported now + :param calib_iter: Data size for calibration + :param weight_clip: Whether to clip weight_max when calculating scales. 
+ 
+ :param auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning. 
+ By default, the search space is 0.0-1.0 with step_size 0.1. 
+ do_blockwise: Whether to do blockwise auto-tuning. 
+ :param init_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5. 
+ :return: A FP32 model with the same architecture as the orig model but with different weights, which will be 
+ beneficial to quantization. 
+ """ 
+ 
+ if not isinstance(self.model, torch.nn.Module): 
+ logger.warning("smoothquant is ignored since the model is not a torch module") 
+ return self.model 
+ 
+ if isinstance(alpha, float) and (alpha < 0): 
+ logger.warning("reset alpha to >=0") 
+ alpha = numpy.clip(alpha, 0.0, None) 
+ 
+ if folding: 
+ self.insert_mul, self.allow_absorb = False, True 
+ else: 
+ self.insert_mul, self.allow_absorb = True, False 
+ self.weight_clip = weight_clip 
+ 
+ self.revert() 
+ self.need_calibration = self._check_need_calibration(alpha, percentile, op_types, scales_per_op, calib_iter) 
+ with torch.no_grad(): 
+ str_op_types = [i.__name__ for i in op_types] 
+ input_maxes_abs = self.input_maxes_abs 
+ if self.need_calibration: ##avoid multiple calibration during tuning if the only difference is alpha 
+ if self.insert_mul: 
+ self.self_absorb_layers = self._get_all_layer_names(op_types) # TODO: only support linear now. 
+ if self.scale_sharing: 
+ # fetch modules with the same input 
+ group_modules = self._trace(str_op_types, skip_unsupported_layers=False) 
+ if group_modules is not None: 
+ # use one input for qkv 
+ for k, v in group_modules.items(): 
+ for i in v: 
+ if i in self.self_absorb_layers: 
+ self.self_absorb_layers.pop(i) 
+ self.self_absorb_layers[v[0]] = v 
+ logger.debug(f"self_absorb_layers:{self.self_absorb_layers}") 
+ 
+ self.absorb_to_layer = self._parse_absorb_to_layers( 
+ op_types, folding 
+ ) ##need to forward to check modules not used in forward 
+ if len(self.input_mins) != 0: ##this is from _parse_absorb_to_layers, ugly code to support q_func 
+ input_maxes_abs = {} 
+ for key in self.input_mins.keys(): 
+ input_maxes_abs[key] = torch.max(torch.abs(self.input_mins[key]), torch.abs(self.input_maxes[key])) 
+ if self.q_func: 
+ self.need_calibration = False # Avoid double-calibration in fixed-value alpha SQ.
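# --- Editor's illustrative usage sketch (not part of the patch) ---
# A minimal way transform() might be driven directly with a fixed alpha. The toy model and the
# list-based "dataloader" are stand-ins, not objects defined by this patch.
import torch
from neural_compressor.torch.algorithms.smooth_quant.utility import TorchSmoothQuant  # assumed import path

toy_model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
calib_data = [torch.randn(2, 8) for _ in range(4)]  # stand-in for a calibration DataLoader

sq = TorchSmoothQuant(toy_model, dataloader=calib_data)
smoothed = sq.transform(alpha=0.5, folding=False, calib_iter=4)
# Passing alpha="auto" instead routes through the layer-wise / block-wise alpha tuning shown earlier.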
+ + if self.absorb_to_layer is None: + logger.warning("empty absorb_to_layer, smoothquant is ignored ") + return self.model + example_inputs = self._get_example_input() + if alpha == "auto": ##TODO need to polish later + auto_alpha_version = "version1" + auto_alpha_tuner = TUNERS[auto_alpha_version]( + self.model, + self.dataloader, + self.absorb_to_layer, + op_types=op_types, + device=self.device, + q_func=self.q_func, + folding=folding, + example_inputs=self.example_inputs, + **auto_alpha_args, + ) + self.alpha = auto_alpha_tuner.tune() + input_maxes_abs = auto_alpha_tuner.input_maxes_abs + self.input_mins, self.input_maxes = auto_alpha_tuner.input_mins, auto_alpha_tuner.input_maxes + if auto_alpha_tuner.loss_type == "blockwise": + self.block_names = auto_alpha_tuner.block_names + + elif self.need_calibration: + calib = Calibration(self.model, self.dataloader, self.q_func, self.device) + self.input_mins, self.input_maxes = calib.calibrate(calib_iter, op_types) + input_maxes_abs = {} + for key in self.input_mins.keys(): + input_maxes_abs[key] = torch.max(torch.abs(self.input_mins[key]), torch.abs(self.input_maxes[key])) + + if example_inputs is not None: + out_pre_sq = model_forward_per_sample(self.model, example_inputs, self.device) + + if folding: + self._save_scale = False ##TODO remove it later + + if self.record_max_info: + self._export_sq_info(self.absorb_to_layer, input_maxes_abs, self.alpha) + # # max_info is recorded in self.max_value_info + # self._adjust_parameters(self.absorb_to_layer, input_maxes_abs, alpha) + self.model._smoothquant_optimized = False + return self.model + + self.weight_scale_info, self.absorb_scales_info = self._adjust_parameters( + self.absorb_to_layer, input_maxes_abs, self.alpha + ) + self.model._smoothquant_optimized = True + + if example_inputs is not None: + # Check mathematical equivalency + out_post_sq = model_forward_per_sample(self.model, example_inputs, self.device) + if not self.output_is_equal(out_post_sq, out_pre_sq): + logger.warning( + "Mathematical equivelancy of Smoothquant is not preserved. " + "Please kindly report this issue to https://github.com/intel/neural-compressor." + ) + else: + logger.warning(" Could not get example input, equivelancy check is skipped") + + return self.model + + def output_is_equal(self, out1, out2, atol=1e-04): + try: + if isinstance(out1, tuple): + return all(torch.all(torch.isclose(out1[i], out2[i], atol=atol)) for i in range(len(out1))) + elif isinstance(out1, dict): + return all(torch.all(torch.isclose(out1[k], out2[k], atol=atol)) for k in out1.keys()) + elif isinstance(out1, torch.Tensor): + return torch.all(torch.isclose(out1, out2, atol=atol)) + return False + except: + logger.warning( + "Automatically check failed, Please check equivelancy manually " + "between out_pre_sq and out_post_sq if necessary." + ) + return True + + @torch.no_grad() + def revert(self): + """Revert the model weights + :return:""" + for key in self.weight_scale_info: + self._scale_layer_weight(key, 1.0 / self.weight_scale_info[key]) + for key in self.absorb_scales_info: + self._absorb_scales(key, 1.0 / self.absorb_scales_info[key]) + self.weight_scale_info = {} ##clear the data + self.absorb_scales_info = {} + + def _get_all_layer_names(self, op_types=[torch.nn.Linear]): + """Try the model to find the layers which can be smooth quantized. 
+ + :param op_types: The op types to be smooth quantized + :return: + self_absorb_layer: A dict, absorb layer name (itself): layers to be smooth quantized + """ + self_absorb_layer = {} + op_types = [torch.nn.Linear] # TODOļ¼š only support SQLinearWrapper + for name, module in self.model.named_modules(): + if isinstance(module, tuple(op_types)): + self_absorb_layer[name] = [name] + return self_absorb_layer + + def _get_example_input(self): + if self.dataloader is None and self.example_inputs is None: + return None + if self.example_inputs is None: + try: + for idx, (input, label) in enumerate(self.dataloader): + self.example_inputs = input + break + except: + for idx, input in enumerate(self.dataloader): + self.example_inputs = input + break + + return self.example_inputs + + def _trace(self, op_types, skip_unsupported_layers=True): + """Try the model to find the layers which can be smooth quantized. + + :param op_types: The op types to be smooth quantized + :return: + absorb_to_layer: A dict, absorb layer name:layers to be smooth quantized + no_absorb_layers: A list saving the layers which could not find the absorb layer + """ + + tg = GraphTrace() + self._get_example_input() + absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer( + self.traced_model, + self.example_inputs, + op_types, + skip_unsupported_layers=skip_unsupported_layers, + ) + if not skip_unsupported_layers: + return absorb_to_layer + if absorb_to_layer is None and no_absorb_layers is None: + logger.warning( + "sorry, could not trace the model, smooth quant is skipped." + "If you are using huggingface model," + "you could set torchscript to True " + "when loading the model or set the return_dict to False" + ) + elif absorb_to_layer == {}: + logger.warning("could not find any layer to be absorbed") + else: + to_absorb_cnt = 0 + for key, item in absorb_to_layer.items(): + to_absorb_cnt += len(item) + logger.info( + f" {to_absorb_cnt} out of {to_absorb_cnt + len(no_absorb_layers)} " + f"layers could be absorbed in smooth quant" + ) + return absorb_to_layer, no_absorb_layers + + +class SQLinearWrapper(torch.nn.Module): + def __init__(self, module, input_scale, input_minmax, alpha=0.5, dtype=torch.quint8): + super().__init__() + self.register_buffer("input_scale", input_scale) + self.alpha = alpha + self.dtype = dtype + # calculate and only save scale, zero_point to avoid memory usage + self.scale, self.zero_point = self._calculate_qparams(input_scale, input_minmax, dtype) + self.add_module("sq_linear", module) + self._update_sq_linear() + self.ipex = False # a flag used for ipex inference + + @property + def weight(self): + return self.sq_linear.weight + + def forward(self, X): + if self.ipex: + X = self.sq_linear(X) + else: + X = torch.mul(X, self.input_scale) + X = self.sq_linear(X) + return X + + def _calculate_qparams(self, input_scale, input_minmax, dtype=torch.quint8): + # calculate scale and zero_point + if dtype == torch.quint8: + quant_min, quant_max = 0, 255 + min_val = torch.min(input_minmax[0] * input_scale) + max_val = torch.max(input_minmax[1] * input_scale) + # work when min_val bigger than zero. 
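# --- Editor's illustrative sketch (not part of the patch) ---
# The identity behind SQLinearWrapper: multiplying the input by input_scale and dividing the
# weight columns by the same factor (what _update_sq_linear below does) leaves the float
# output unchanged, while the activation range seen by quantization becomes smoother.
import torch

linear = torch.nn.Linear(4, 3, bias=False)
x = torch.randn(2, 4)
input_scale = torch.rand(4) + 0.5             # hypothetical per-input-channel scale
reference = linear(x)
with torch.no_grad():
    linear.weight /= input_scale.view(1, -1)  # mirrors _update_sq_linear
folded = linear(x * input_scale)              # mirrors forward() when self.ipex is False
assert torch.allclose(reference, folded, atol=1e-5)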
+ min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.max(scale, torch.tensor([torch.finfo(torch.float32).eps], device=scale.device)) + zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + return scale, zero_point + + def _get_weight_scale(self): + # get weight scale and zero_point + from torch.ao.quantization.observer import default_per_channel_weight_observer + + obs = default_per_channel_weight_observer() + obs(self.sq_linear.weight) + scale, _ = obs.calculate_qparams() + return scale + + def _update_sq_linear(self): + # remove mul and reset sq_linear for ipex inference + scale = self.input_scale.view(1, self.input_scale.shape[0]) + with torch.no_grad(): + self.sq_linear.weight /= scale + + def _recover_sq_linear(self): + # remove mul and reset sq_linear for ipex inference + scale = self.input_scale.view(1, self.input_scale.shape[0]) + with torch.no_grad(): + self.sq_linear.weight *= scale + + +class WrapperLayer(torch.nn.Module): + def __init__(self, layer, input_min, input_max, save_q_input=False): + super(WrapperLayer, self).__init__() + self.add_module("orig_layer", layer) # set orig_layer in get/set_module + self.quant = False + self.q_input = None + self.fp32_output = None + self.input_max = input_max + self.input_min = input_min + self.weight_scale = None + self.input_scale = None + self.save_q_input = save_q_input + self.do_blockwise = False + + def enable_quant(self): + self.quant = True + + def disable_quant(self): + self.quant = False + + def update_scale(self, input_scale, weight_scale): + self.input_scale = input_scale + self.weight_scale = weight_scale + + ##TODO better tradeoff performance and memory, currently it's too slow + def q_dq_forward(self, x, input_scale, weight_scale): + layer_copy = copy.deepcopy(self.orig_layer) + if weight_scale is not None: + layer_copy.weight *= weight_scale + q_dq_weight = quant_dequant_w_v1(layer_copy) + layer_copy.weight.data.copy_(q_dq_weight) + if input_scale is None: + x = quant_dequant_x_v1(x, self.input_min, self.input_max) + else: + x = input_scale * x + x = quant_dequant_x_v1(x, self.input_min * input_scale, self.input_max * input_scale) ##FIXME + output = layer_copy(x) + return output + + def q_dq_forward_blockwise(self, x, input_scale): + layer_copy = copy.deepcopy(self.orig_layer) + if input_scale is None: + x = quant_dequant_x_v1(x, self.input_min, self.input_max) + else: + x = input_scale * x + x = quant_dequant_x_v1(x, self.input_min * input_scale, self.input_max * input_scale) ##FIXME + output = layer_copy(x) + return output + + def forward(self, x): + if self.quant: + # self.q_input = x * scale ##save the q_input + if self.save_q_input: + self.q_input = x + if not self.do_blockwise: + output = self.q_dq_forward(x, self.input_scale, self.weight_scale) + else: + output = self.q_dq_forward_blockwise(x, self.input_scale) + + else: + output = self.orig_layer(x) + self.output = output + return output + + +class CpuInfo(object): # pragma: no cover + """Get CPU Info.""" + + def __init__(self): + """Get whether the cpu numerical format is bf16, the number of sockets, cores and cores per socket.""" + self._bf16 = False + self._vnni = False + info = cpuinfo.get_cpu_info() + if "arch" in info and "X86" in info["arch"]: + cpuid = cpuinfo.CPUID() + max_extension_support = 
cpuid.get_max_extension_support() + if max_extension_support >= 7: + ecx = cpuid._run_asm( + b"\x31\xC9", # xor ecx, ecx + b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\x89\xC8" b"\xC3", # mov eax, 7 # cpuid # mov ax, cx # ret + ) + self._vnni = bool(ecx & (1 << 11)) + eax = cpuid._run_asm( + b"\xB9\x01\x00\x00\x00", # mov ecx, 1 + b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\xC3", # mov eax, 7 # cpuid # ret + ) + self._bf16 = bool(eax & (1 << 5)) + if "arch" in info and "ARM" in info["arch"]: # pragma: no cover + self._sockets = 1 + else: + self._sockets = self.get_number_of_sockets() + self._cores = psutil.cpu_count(logical=False) + self._cores_per_socket = int(self._cores / self._sockets) + + @property + def bf16(self): + """Get whether it is bf16.""" + return self._bf16 + + @property + def vnni(self): + """Get whether it is vnni.""" + return self._vnni + + @property + def cores_per_socket(self): + """Get the cores per socket.""" + return self._cores_per_socket + + def get_number_of_sockets(self) -> int: + """Get number of sockets in platform.""" + cmd = "cat /proc/cpuinfo | grep 'physical id' | sort -u | wc -l" + if psutil.WINDOWS: + cmd = r'wmic cpu get DeviceID | C:\Windows\System32\find.exe /C "CPU"' + + with subprocess.Popen( + args=cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=False, + ) as proc: + proc.wait() + if proc.stdout: + for line in proc.stdout: + return int(line.decode("utf-8", errors="ignore").strip()) + return 0 diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py index 92aa34e9eb0..b3dccdafb00 100644 --- a/neural_compressor/torch/algorithms/static_quant/static_quant.py +++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py @@ -17,19 +17,19 @@ import json -from neural_compressor.torch.utils import get_ipex_version +import torch try: import intel_extension_for_pytorch as ipex except: assert False, "Please install IPEX for static quantization." 
-import torch from packaging.version import Version from .utility import ( cfg_to_qconfig, dump_model_op_stats, + get_ipex_version, get_quantizable_ops_recursively, ipex_config_path, simple_inference, diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py index fac2604786b..cdfd3cb72d0 100644 --- a/neural_compressor/torch/algorithms/static_quant/utility.py +++ b/neural_compressor/torch/algorithms/static_quant/utility.py @@ -18,13 +18,14 @@ import re from typing import Dict, List, Union +import torch +from packaging.version import Version + try: import intel_extension_for_pytorch as ipex import prettytable as pt except: pass -import torch -from packaging.version import Version from neural_compressor.common.utils import DEFAULT_WORKSPACE from neural_compressor.torch.utils import get_ipex_version, get_torch_version, logger @@ -50,7 +51,6 @@ "re": {" int: - """Query the depth of the dict.""" - if isinstance(d, dict): - return 1 + max(get_depth(v) for v in d.values()) - return 0 - - -def get_dict_at_depth(d, target_depth, result, depth=0): - """Get all sub-dicts that are at a specified depth in a nested dict.""" - if depth == target_depth: - result.append(d) - return - elif depth < target_depth and isinstance(d, dict): - for k, v in d.items(): - get_dict_at_depth(v, target_depth, result, depth=depth + 1) - - -def get_element_under_depth(d, ops_lst): - """Get all values in a nested dict.""" - if isinstance(d, dict): - for k, v in d.items(): - get_element_under_depth(v, ops_lst) - else: - ops_lst.append(d) - - -def paser_cfgs(cfgs): # pragma: no cover - """Parse configs. - - Args: - cfgs (dict): the input configs. - - - Returns: - ops_name (list): list of op names. - tune_cfg (dict): dictionary of quantization configuration. - op_infos_from_cfgs (dict): op infos from configs. - output_tensor_ids_op_name (dict): dictionary of output tensor op names. 
- """ - ops_name = [] - layer_output_infos_ids = [] - op_infos_from_cfgs = {} - # record input_tensor_id and op_name - # {"0": [(" ", "q_op_infos", "0"), (" ", "q_op_infos", "1")]} - input_tensor_ids_op_name = {} - output_tensor_ids_op_name = {} - for module_key in cfgs.keys(): - for state in cfgs[module_key]: - if state == "layer_output_infos": - for index, op_info in enumerate(cfgs[module_key][state]): - name = (module_key, state, index) - ops_name.append(name) - layer_output_infos_ids.append(op_info["id"]) - op_infos_from_cfgs[name] = op_info - continue - for op_cfg_id in cfgs[module_key][state].keys(): - op_info = cfgs[module_key][state][op_cfg_id] - name = (module_key, state, op_cfg_id) - if name not in ops_name: - ops_name.append(name) - else: - assert False, "Please check IPEX int8 configure json whether have the same name ops" - op_infos_from_cfgs[name] = op_info - input_tensors = op_info["input_tensor_infos"] - for input_tensor in input_tensors: - if "id" not in input_tensor.keys(): - continue - else: - input_tensor_id = input_tensor["id"] - if input_tensor_id not in input_tensor_ids_op_name.keys(): - input_tensor_ids_op_name[input_tensor_id] = [name] - else: - input_tensor_ids_op_name[input_tensor_id].append(name) - output_tensors = op_info["output_tensor_infos"] - for output_tensor in output_tensors: - if "id" not in output_tensor.keys(): - continue - else: - output_tensor_id = output_tensor["id"] - if output_tensor_id not in output_tensor_ids_op_name.keys(): - output_tensor_ids_op_name[output_tensor_id] = [name] - else: - output_tensor_ids_op_name[output_tensor_id].append(name) - return ops_name, op_infos_from_cfgs, input_tensor_ids_op_name, output_tensor_ids_op_name - - -def get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_ids_op_name): # pragma: no cover - """Get quantizable ops from configs, combine fused ops as one op. - - Args: - ops_name (list): list of op names. - op_infos_from_cfgs (dict): op infos from configs. - input_tensor_ids_op_name (dict): dictionary of input tensor op names. - - Returns: - cfgs (dict). - """ - quantizable_ops = [] - seen_ops = [] - for name in ops_name: - start = True - if name in seen_ops: - continue - elif name[1] not in ["q_op_infos"]: - continue - else: - # judge fuse ops the first op - op_info = op_infos_from_cfgs[name] - output_tensors = op_info["output_tensor_infos"] - input_tensors = op_info["input_tensor_infos"] - start = any( - [ - input_tensor["inf_dtype"] != "torch.float32" - for input_tensor in input_tensors - if "inf_dtype" in input_tensor.keys() - ] - ) - if not start: - continue - # add quantizable ops, include op and fuse ops. 
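# --- Editor's illustrative sketch (not part of the patch) ---
# Rough shape of the IPEX qconf-summary JSON that paser_cfgs() walks (this hunk moves the
# helper within the file rather than dropping it). All keys and values below are made up.
from neural_compressor.torch.algorithms.static_quant.utility import paser_cfgs  # assumed import path

toy_cfgs = {
    " ": {  # module key
        "q_op_infos": {
            "0": {
                "op_type": "<class 'torch.nn.modules.linear.Linear'>",
                "fqn": "model.fc1",
                "input_tensor_infos": [{"id": 0, "inf_dtype": "torch.quint8"}],
                "output_tensor_infos": [{"id": 1, "inf_dtype": "torch.qint8"}],
            }
        },
        "layer_output_infos": [{"id": 1}],
    }
}
ops_name, op_infos, input_map, output_map = paser_cfgs(toy_cfgs)
# ops_name   -> [(" ", "q_op_infos", "0"), (" ", "layer_output_infos", 0)]
# input_map  -> {0: [(" ", "q_op_infos", "0")]}   output_map -> {1: [(" ", "q_op_infos", "0")]}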
- q_ops, stack = [], [(name, [])] - while stack: - cur_name, cur = stack.pop() - seen_ops.append(cur_name) - if cur_name[1] not in ["q_op_infos"]: - q_ops.append(cur) - break - op_info = op_infos_from_cfgs[cur_name] - output_tensors = op_info["output_tensor_infos"] - for output_tensor in output_tensors: - if output_tensor["inf_dtype"] == "torch.qint8" or output_tensor["inf_dtype"] == "torch.quint8": - q_ops.append(cur + [cur_name]) - break - try: - next_op_names = input_tensor_ids_op_name[output_tensor["id"]] - for next_op_name in next_op_names: - stack.append((next_op_name, cur + [cur_name])) - except: - next_op_name = None - if next_op_name is None: - q_ops.append(cur + [cur_name]) - for q_op in q_ops: - quantizable_ops.append(q_op) - return quantizable_ops - - -def get_pattern(fallback_op, fuse_ops): # pragma: no cover - for fuse_pattern in fuse_ops: - if fuse_pattern[0] == fallback_op: - if fuse_pattern[1] in ["relu_", "add_"]: - return None - else: - return fuse_pattern[1] - return None - - -def simple_inference(q_model, example_inputs, iterations=1): - """The function is used for ipex warm-up inference.""" - for _ in range(iterations): - if isinstance(example_inputs, tuple) or isinstance(example_inputs, list): - q_model(*example_inputs) - elif isinstance(example_inputs, dict): - q_model(**example_inputs) - else: - q_model(example_inputs) - - def cfg_to_qconfig(tune_cfg, cfgs, default_cfgs, fuse_ops): # pragma: no cover assert cfgs is not None, "No configure for IPEX int8 model..." for key in tune_cfg["op"]: @@ -283,40 +116,144 @@ def cfg_to_qconfig(tune_cfg, cfgs, default_cfgs, fuse_ops): # pragma: no cover return torch.per_tensor_symmetric -def get_fuse_ops(default_cfgs): # pragma: no cover - elt_wise = ["relu", "sigmoid", "gelu"] - inplace_ops = ["relu_", "add_"] - op_patterns = [] - num_ops = len(default_cfgs) - for cur_id in range(num_ops): - cur_op = default_cfgs[cur_id]["name"] - if cur_op == "dropout": - continue - inputs = default_cfgs[cur_id]["inputs_flow"] - num_input = len(inputs) - pre_ops = {} - for i_num in range(num_input): - inp = inputs[i_num] - for pre_id in range(cur_id): - pre_op = default_cfgs[pre_id]["name"] - pre_out = default_cfgs[pre_id]["outputs_flow"] - num_out = len(pre_out) - for o_num in range(num_out): - if pre_out[o_num] == inp: - if cur_op in inplace_ops and (pre_op in ["conv2d", "conv3d", "linear"]): - op_patterns.append([(pre_id, pre_op), (cur_id, cur_op)]) - if cur_op in elt_wise and (pre_op in ["conv2d", "conv3d", "linear", "add"]): - op_patterns.append([(pre_id, pre_op), (cur_id, cur_op)]) - if cur_op == "add": - pre_ops[i_num] = [pre_id, pre_op] - if len(pre_ops) > 0: - for key, value in pre_ops.items(): - if ( - value[1] in ["conv2d", "conv3d", "linear"] - and default_cfgs[cur_id]["inputs_quantized"][key] is False - ): - op_patterns.append([(value[0], value[1]), (cur_id, cur_op)]) - return op_patterns +def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover + """Get all quantizable ops from model. + + Args: + model (object): input model + example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model. + Returns: + quantizable_ops (list): list of tuples of op_name and op_type. 
+ cfgs (dict): dict of configuration + """ + quantizable_ops = [] + # group ops by position for transform-based model + detector = TransformerBasedModelBlockPatternDetector(model) + detect_result = detector.detect_block() + attention_block = detect_result.get("attention_blocks", None) + ffn_blocks = detect_result.get("ffn_blocks", None) + logger.info(f"Attention Blocks: {len(attention_block)}") + logger.info(f"FFN Blocks: {len(ffn_blocks)}") + if not os.path.exists(ipex_config_path): + assert isinstance(model, torch.nn.Module), "The model passed in is not the instance of torch.nn.Module" + + if hasattr(model, "save_qconf_summary"): # pragma: no cover + os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True) + model.save_qconf_summary(qconf_summary=ipex_config_path) + else: + model.eval() + + # create a quantization config file for intel pytorch extension model + os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True) + assert example_inputs is not None, "IPEX need q_dataloader or example_inputs to prepare the model" + from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig + + if ipex_ver.release >= Version("2.1").release: + # HistogramObserver will cause a performance issue. + # static_qconfig = ipex.quantization.default_static_qconfig_mapping + qconfig = QConfig( + activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), + weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric), + ) + from torch.ao.quantization import QConfigMapping + + static_qconfig = QConfigMapping().set_global(qconfig) + else: + static_qconfig = QConfig( + activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), + weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric), + ) + + if isinstance(example_inputs, dict): + model = ipex.quantization.prepare(model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=True) + else: + model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=True) + simple_inference(model, example_inputs, iterations=1) + model.save_qconf_summary(qconf_summary=ipex_config_path) + + map_op_name_to_fqn = {} + with open(ipex_config_path, "r") as f: + cfgs = json.load(f) + default_cfgs = {} + fuse_ops = [] + if ipex_ver.release < Version("1.12.0").release: # pragma: no cover + default_cfgs = copy.deepcopy(cfgs) + fuse_ops = get_fuse_ops(cfgs) + for op_cfg in cfgs: + if op_cfg["name"] in unify_op_type_mapping_ipex: + quantizable_ops.append((op_cfg["id"], unify_op_type_mapping_ipex[op_cfg["name"]])) + else: + re_flag = False + for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items(): + if re.match(pattern, op_cfg["name"]): + re_flag = True + quantizable_ops.append((op_cfg["id"], unify_op_type)) + break + if not re_flag: + quantizable_ops.append((op_cfg["id"], op_cfg["name"])) + else: + ( + ops_name, + op_infos_from_cfgs, + input_tensor_id_op_name, + output_tensor_id_op_name, + ) = paser_cfgs(cfgs) + quantizable_op_names = get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_id_op_name) + for name in quantizable_op_names: + # name : list + if len(name) == 1: + module_key = name[0][0] + op_cfg_id = name[0][2] + ipex_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"] + module_fqn = cfgs[module_key]["q_op_infos"][op_cfg_id].get("fqn", None) + + if ipex_op_type in unify_op_type_mapping_ipex: + quantizable_ops.append((tuple(name), 
unify_op_type_mapping_ipex[ipex_op_type])) + map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn + else: + re_flag = False + for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items(): + if re.match(pattern, ipex_op_type): + re_flag = True + quantizable_ops.append((tuple(name), unify_op_type)) + map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn + break + if not re_flag: + quantizable_ops.append((tuple(name), ipex_op_type)) + map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn + else: + op_type = "" + for op_name in name: + module_key = op_name[0] + op_cfg_id = op_name[2] + single_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"] + if single_op_type in unify_op_type_mapping_ipex: + single_op_type = unify_op_type_mapping_ipex[single_op_type] + op_type += "&" + single_op_type if op_type else single_op_type + quantizable_ops.append((tuple(name), op_type)) + _module_key = name[0][0] + _op_cfg_id = name[0][2] + module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"] + map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn + + logger.debug("Map op name to fqn: ") + logger.debug(map_op_name_to_fqn) + logger.info("Attention Blocks : ") + logger.info(attention_block) + logger.info("FFN Blocks : ") + logger.info(ffn_blocks) + return quantizable_ops, cfgs, default_cfgs, fuse_ops + + +def simple_inference(q_model, example_inputs, iterations=1): + """The function is used for ipex warm-up inference.""" + for _ in range(iterations): + if isinstance(example_inputs, tuple) or isinstance(example_inputs, list): + q_model(*example_inputs) + elif isinstance(example_inputs, dict): + q_model(**example_inputs) + else: + q_model(example_inputs) def dump_model_op_stats(tune_cfg): @@ -372,141 +309,241 @@ def dump_model_op_stats(tune_cfg): ).print_stat() -def get_quantizable_ops_recursively(model, example_inputs): - """Get all quantizable ops from model. 
+def get_fuse_ops(default_cfgs): # pragma: no cover + elt_wise = ["relu", "sigmoid", "gelu"] + inplace_ops = ["relu_", "add_"] + op_patterns = [] + num_ops = len(default_cfgs) + for cur_id in range(num_ops): + cur_op = default_cfgs[cur_id]["name"] + if cur_op == "dropout": + continue + inputs = default_cfgs[cur_id]["inputs_flow"] + num_input = len(inputs) + pre_ops = {} + for i_num in range(num_input): + inp = inputs[i_num] + for pre_id in range(cur_id): + pre_op = default_cfgs[pre_id]["name"] + pre_out = default_cfgs[pre_id]["outputs_flow"] + num_out = len(pre_out) + for o_num in range(num_out): + if pre_out[o_num] == inp: + if cur_op in inplace_ops and (pre_op in ["conv2d", "conv3d", "linear"]): + op_patterns.append([(pre_id, pre_op), (cur_id, cur_op)]) + if cur_op in elt_wise and (pre_op in ["conv2d", "conv3d", "linear", "add"]): + op_patterns.append([(pre_id, pre_op), (cur_id, cur_op)]) + if cur_op == "add": + pre_ops[i_num] = [pre_id, pre_op] + if len(pre_ops) > 0: + for key, value in pre_ops.items(): + if ( + value[1] in ["conv2d", "conv3d", "linear"] + and default_cfgs[cur_id]["inputs_quantized"][key] is False + ): + op_patterns.append([(value[0], value[1]), (cur_id, cur_op)]) + return op_patterns + + +def get_depth(d) -> int: + """Query the depth of the dict.""" + if isinstance(d, dict): + return 1 + max(get_depth(v) for v in d.values()) + return 0 + + +def get_dict_at_depth(d, target_depth, result, depth=0): + """Get all sub-dicts that are at a specified depth in a nested dict.""" + if depth == target_depth: + result.append(d) + return + elif depth < target_depth and isinstance(d, dict): + for k, v in d.items(): + get_dict_at_depth(v, target_depth, result, depth=depth + 1) + + +def get_element_under_depth(d, ops_lst): + """Get all values in a nested dict.""" + if isinstance(d, dict): + for k, v in d.items(): + get_element_under_depth(v, ops_lst) + else: + ops_lst.append(d) + + +def paser_cfgs(cfgs): # pragma: no cover + """Parse configs. + + Args: + cfgs (dict): the input configs. + + + Returns: + ops_name (list): list of op names. + tune_cfg (dict): dictionary of quantization configuration. + op_infos_from_cfgs (dict): op infos from configs. + output_tensor_ids_op_name (dict): dictionary of output tensor op names. 
+ """ + ops_name = [] + layer_output_infos_ids = [] + op_infos_from_cfgs = {} + # record input_tensor_id and op_name + # {"0": [(" ", "q_op_infos", "0"), (" ", "q_op_infos", "1")]} + input_tensor_ids_op_name = {} + output_tensor_ids_op_name = {} + for module_key in cfgs.keys(): + for state in cfgs[module_key]: + if state == "layer_output_infos": + for index, op_info in enumerate(cfgs[module_key][state]): + name = (module_key, state, index) + ops_name.append(name) + layer_output_infos_ids.append(op_info["id"]) + op_infos_from_cfgs[name] = op_info + continue + for op_cfg_id in cfgs[module_key][state].keys(): + op_info = cfgs[module_key][state][op_cfg_id] + name = (module_key, state, op_cfg_id) + if name not in ops_name: + ops_name.append(name) + else: + assert False, "Please check IPEX int8 configure json whether have the same name ops" + op_infos_from_cfgs[name] = op_info + input_tensors = op_info["input_tensor_infos"] + for input_tensor in input_tensors: + if "id" not in input_tensor.keys(): + continue + else: + input_tensor_id = input_tensor["id"] + if input_tensor_id not in input_tensor_ids_op_name.keys(): + input_tensor_ids_op_name[input_tensor_id] = [name] + else: + input_tensor_ids_op_name[input_tensor_id].append(name) + output_tensors = op_info["output_tensor_infos"] + for output_tensor in output_tensors: + if "id" not in output_tensor.keys(): + continue + else: + output_tensor_id = output_tensor["id"] + if output_tensor_id not in output_tensor_ids_op_name.keys(): + output_tensor_ids_op_name[output_tensor_id] = [name] + else: + output_tensor_ids_op_name[output_tensor_id].append(name) + return ops_name, op_infos_from_cfgs, input_tensor_ids_op_name, output_tensor_ids_op_name + + +def get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_ids_op_name): # pragma: no cover + """Get quantizable ops from configs, combine fused ops as one op. Args: - model (object): input model - example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model. + ops_name (list): list of op names. + op_infos_from_cfgs (dict): op infos from configs. + input_tensor_ids_op_name (dict): dictionary of input tensor op names. + Returns: - quantizable_ops (list): list of tuples of op_name and op_type. - cfgs (dict): dict of configuration + cfgs (dict). """ quantizable_ops = [] - # group ops by position for transform-based model - from .utility import TransformerBasedModelBlockPatternDetector - - detector = TransformerBasedModelBlockPatternDetector(model) - detect_result = detector.detect_block() - attention_block = detect_result.get("attention_blocks", None) - ffn_blocks = detect_result.get("ffn_blocks", None) - logger.info(f"Attention Blocks: {len(attention_block)}") - logger.info(f"FFN Blocks: {len(ffn_blocks)}") - if not os.path.exists(ipex_config_path): - assert isinstance(model, torch.nn.Module), "The model passed in is not the instance of torch.nn.Module" - - if hasattr(model, "save_qconf_summary"): # pragma: no cover - os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True) - model.save_qconf_summary(qconf_summary=ipex_config_path) - else: - model.eval() - - # create a quantization config file for intel pytorch extension model - os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True) - assert example_inputs is not None, "IPEX need q_dataloader or example_inputs to prepare the model" - from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig - - if ipex_ver.release >= Version("2.1").release: - # HistogramObserver will cause a performance issue. 
- # static_qconfig = ipex.quantization.default_static_qconfig_mapping - qconfig = QConfig( - activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), - weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric), - ) - from torch.ao.quantization import QConfigMapping - - static_qconfig = QConfigMapping().set_global(qconfig) + seen_ops = [] + for name in ops_name: + start = True + if name in seen_ops: + continue + elif name[1] not in ["q_op_infos"]: + continue else: - static_qconfig = QConfig( - activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), - weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric), + # judge fuse ops the first op + op_info = op_infos_from_cfgs[name] + output_tensors = op_info["output_tensor_infos"] + input_tensors = op_info["input_tensor_infos"] + start = any( + [ + input_tensor["inf_dtype"] != "torch.float32" + for input_tensor in input_tensors + if "inf_dtype" in input_tensor.keys() + ] ) + if not start: + continue + # add quantizable ops, include op and fuse ops. + q_ops, stack = [], [(name, [])] + while stack: + cur_name, cur = stack.pop() + seen_ops.append(cur_name) + if cur_name[1] not in ["q_op_infos"]: + q_ops.append(cur) + break + op_info = op_infos_from_cfgs[cur_name] + output_tensors = op_info["output_tensor_infos"] + for output_tensor in output_tensors: + if output_tensor["inf_dtype"] == "torch.qint8" or output_tensor["inf_dtype"] == "torch.quint8": + q_ops.append(cur + [cur_name]) + break + try: + next_op_names = input_tensor_ids_op_name[output_tensor["id"]] + for next_op_name in next_op_names: + stack.append((next_op_name, cur + [cur_name])) + except: + next_op_name = None + if next_op_name is None: + q_ops.append(cur + [cur_name]) + for q_op in q_ops: + quantizable_ops.append(q_op) + return quantizable_ops - if isinstance(example_inputs, dict): - model = ipex.quantization.prepare(model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=True) - else: - model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=True) - simple_inference(model, example_inputs, iterations=1) - model.save_qconf_summary(qconf_summary=ipex_config_path) - map_op_name_to_fqn = {} - with open(ipex_config_path, "r") as f: - cfgs = json.load(f) +def get_pattern(fallback_op, fuse_ops): # pragma: no cover + for fuse_pattern in fuse_ops: + if fuse_pattern[0] == fallback_op: + if fuse_pattern[1] in ["relu_", "add_"]: + return None + else: + return fuse_pattern[1] + return None - from .utility import unify_op_type_mapping_ipex - default_cfgs = {} - fuse_ops = [] - if ipex_ver.release < Version("1.12.0").release: # pragma: no cover - default_cfgs = copy.deepcopy(cfgs) - fuse_ops = get_fuse_ops(cfgs) - for op_cfg in cfgs: - if op_cfg["name"] in unify_op_type_mapping_ipex: - quantizable_ops.append((op_cfg["id"], unify_op_type_mapping_ipex[op_cfg["name"]])) - else: - re_flag = False - for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items(): - if re.match(pattern, op_cfg["name"]): - re_flag = True - quantizable_ops.append((op_cfg["id"], unify_op_type)) - break - if not re_flag: - quantizable_ops.append((op_cfg["id"], op_cfg["name"])) - else: - ( - ops_name, - op_infos_from_cfgs, - input_tensor_id_op_name, - output_tensor_id_op_name, - ) = paser_cfgs(cfgs) - quantizable_op_names = get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_id_op_name) - for name in 
quantizable_op_names: - # name : list - if len(name) == 1: - module_key = name[0][0] - op_cfg_id = name[0][2] - ipex_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"] - module_fqn = cfgs[module_key]["q_op_infos"][op_cfg_id].get("fqn", None) +class Statistics: # pragma: no cover + """The statistics printer.""" - if ipex_op_type in unify_op_type_mapping_ipex: - quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type])) - map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn - else: - re_flag = False - for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items(): - if re.match(pattern, ipex_op_type): - re_flag = True - quantizable_ops.append((tuple(name), unify_op_type)) - map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn - break - if not re_flag: - quantizable_ops.append((tuple(name), ipex_op_type)) - map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn - else: - op_type = "" - for op_name in name: - module_key = op_name[0] - op_cfg_id = op_name[2] - single_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"] - if single_op_type in unify_op_type_mapping_ipex: - single_op_type = unify_op_type_mapping_ipex[single_op_type] - op_type += "&" + single_op_type if op_type else single_op_type - quantizable_ops.append((tuple(name), op_type)) - _module_key = name[0][0] - _op_cfg_id = name[0][2] - module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"] - map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn + def __init__(self, data, header, field_names, output_handle=logger.info): + """Init a Statistics object. - logger.debug("Map op name to fqn: ") - logger.debug(map_op_name_to_fqn) - logger.info("Attention Blocks : ") - logger.info(attention_block) - logger.info("FFN Blocks : ") - logger.info(ffn_blocks) - return quantizable_ops, cfgs, default_cfgs, fuse_ops + Args: + data: The statistics data + header: The table header + field_names: The field names + output_handle: The output logging method + """ + self.field_names = field_names + self.header = header + self.data = data + self.output_handle = output_handle + self.tb = pt.PrettyTable(min_table_width=40) + def print_stat(self): + """Print the statistics.""" + valid_field_names = [] + for index, value in enumerate(self.field_names): + if index < 2: + valid_field_names.append(value) + continue + + if any(i[index] for i in self.data): + valid_field_names.append(value) + self.tb.field_names = valid_field_names + for i in self.data: + tmp_data = [] + for index, value in enumerate(i): + if self.field_names[index] in valid_field_names: + tmp_data.append(value) + if any(tmp_data[1:]): + self.tb.add_row(tmp_data) + lines = self.tb.get_string().split("\n") + self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|") + for i in lines: + self.output_handle(i) -class TransformerBasedModelBlockPatternDetector: + +class TransformerBasedModelBlockPatternDetector: # pragma: no cover """Detect the attention block and FFN block in transformer-based model.""" def __init__(self, model: torch.nn.Module, pattern_lst: List[List[Union[str, int]]] = BLOCK_PATTERNS) -> None: @@ -634,45 +671,3 @@ def _group_block(detect_result): if ffn_block: ffn_block_lst.append(ffn_block) return attention_block_lst, ffn_block_lst - - -class Statistics: - """The statistics printer.""" - - def __init__(self, data, header, field_names, output_handle=logger.info): - """Init a Statistics object. 
- - Args: - data: The statistics data - header: The table header - field_names: The field names - output_handle: The output logging method - """ - self.field_names = field_names - self.header = header - self.data = data - self.output_handle = output_handle - self.tb = pt.PrettyTable(min_table_width=40) - - def print_stat(self): - """Print the statistics.""" - valid_field_names = [] - for index, value in enumerate(self.field_names): - if index < 2: - valid_field_names.append(value) - continue - - if any(i[index] for i in self.data): - valid_field_names.append(value) - self.tb.field_names = valid_field_names - for i in self.data: - tmp_data = [] - for index, value in enumerate(i): - if self.field_names[index] in valid_field_names: - tmp_data.append(value) - if any(tmp_data[1:]): - self.tb.add_row(tmp_data) - lines = self.tb.get_string().split("\n") - self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|") - for i in lines: - self.output_handle(i) diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 0b5bddd9146..3ba86ea9d2c 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -18,13 +18,14 @@ import torch -from neural_compressor.common.utils import AWQ, FP8_QUANT, GPTQ, HQQ, RTN, STATIC_QUANT, TEQ +from neural_compressor.common.utils import AWQ, FP8_QUANT, GPTQ, HQQ, RTN, SMOOTH_QUANT, STATIC_QUANT, TEQ from neural_compressor.torch.quantization import ( AWQConfig, FP8Config, GPTQConfig, HQQConfig, RTNConfig, + SmoothQuantConfig, StaticQuantConfig, TEQConfig, ) @@ -121,7 +122,7 @@ def static_quant_entry( logger.info("Quantize model with the static quant algorithm.") from neural_compressor.torch.algorithms.static_quant import static_quantize - # rebuild tune_cfg for static_quantize function + # convert the user config into internal format quant_config_mapping = {} cfgs = deepcopy(configs_mapping) quant_config_mapping["op"] = cfgs @@ -158,6 +159,63 @@ def static_quant_entry( return q_model +###################### Smooth Quant Algo Entry ################################## +@register_algo(name=SMOOTH_QUANT) +@torch.no_grad() +def smooth_quant_entry( + model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], SmoothQuantConfig], *args, **kwargs +) -> torch.nn.Module: + logger.info("Quantize model with the smooth quant algorithm.") + from neural_compressor.torch.algorithms.smooth_quant import smooth_quantize + + # convert the user config into internal format + quant_config_mapping = {} + cfgs = deepcopy(configs_mapping) + quant_config_mapping["op"] = cfgs + for (op_name, op_type), cfg in cfgs.items(): + quant_config_mapping["op"][(op_name, op_type)] = { + "weight": { + "dtype": cfg.w_dtype, + "scheme": "sym", + "granularity": cfg.w_granularity, + "algorithm": cfg.w_algo, + }, + "activation": { + "dtype": cfg.act_dtype, + "scheme": "sym" if cfg.act_sym else "asym", + "granularity": cfg.act_granularity, + "algorithm": cfg.act_algo, + }, + } + quant_config_mapping["recipe_cfgs"] = { + "smooth_quant": True, + "smooth_quant_args": { + "alpha": cfg.alpha, + "folding": cfg.folding, + "scale_sharing": cfg.scale_sharing, + "auto_alpha_args": cfg.auto_alpha_args if cfg.auto_alpha_args is not None else {}, + }, + "layer_wise_quant_args": {}, + "first_conv_or_matmul_quantization": True, + "last_conv_or_matmul_quantization": True, + "pre_post_process_quantization": True, + } + + run_fn = kwargs.get("run_fn", None) + 
example_inputs = kwargs.get("example_inputs", None) + inplace = kwargs.get("inplace", True) + assert example_inputs is not None, "Please provide example_inputs for smooth quantization." + q_model = smooth_quantize( + model=model, + tune_cfg=quant_config_mapping, + run_fn=run_fn, + example_inputs=example_inputs, + inplace=inplace, + ) + logger.info("Smooth quantization done.") + return q_model + + ###################### AWQ Algo Entry ################################## @register_algo(name=AWQ) @torch.no_grad() diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 567a24a6d16..6fa16824440 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -728,7 +728,7 @@ def __init__( alpha_max: float = 1.0, alpha_step: float = 0.1, shared_criterion: str = "max", - enable_blockwise_loss: bool = False, + do_blockwise: bool = False, auto_alpha_args: dict = None, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, ): @@ -751,14 +751,14 @@ def __init__( self.alpha_max = alpha_max self.alpha_step = alpha_step self.shared_criterion = shared_criterion - self.enable_blockwise_loss = enable_blockwise_loss + self.do_blockwise = do_blockwise self.auto_alpha_args = { "init_alpha": self.init_alpha, "alpha_min": self.alpha_min, "alpha_max": self.alpha_max, "alpha_step": self.alpha_step, "shared_criterion": self.shared_criterion, - "enable_blockwise_loss": self.enable_blockwise_loss, + "do_blockwise": self.do_blockwise, } self._post_init() @@ -772,20 +772,15 @@ def register_supported_configs(cls) -> List[OperatorConfig]: cls.supported_configs = supported_configs @staticmethod - def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: - white_list = (torch.nn.Linear,) - filter_result = [] - for op_name, module in model.named_modules(): - if isinstance(module, white_list): - pair = (op_name, type(module).__name__) - filter_result.append(pair) - logger.debug(f"Get model info: {filter_result}") - return filter_result + def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]: + from neural_compressor.torch.algorithms.smooth_quant import get_quantizable_ops_recursively + + model_info, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs) + return model_info @classmethod def get_config_set_for_tuning(cls) -> Union[None, "SmoothQuantConfig", List["SmoothQuantConfig"]]: - # TODO fwk owner needs to update it. 
- return SmoothQuantConfig(alpha=[0.1, 0.5]) + return SmoothQuantConfig(alpha=[0.1, 0.5], folding=[True, False], scale_sharing=[True, False]) def get_default_sq_config() -> SmoothQuantConfig: diff --git a/requirements_pt.txt b/requirements_pt.txt index 4cc182d4c85..2b465f2582b 100644 --- a/requirements_pt.txt +++ b/requirements_pt.txt @@ -1,2 +1,5 @@ +intel_extension_for_pytorch +peft +py-cpuinfo pydantic torch diff --git a/test/3x/torch/quantization/test_smooth_quant.py b/test/3x/torch/quantization/test_smooth_quant.py new file mode 100644 index 00000000000..1e2b581dcc6 --- /dev/null +++ b/test/3x/torch/quantization/test_smooth_quant.py @@ -0,0 +1,84 @@ +import copy + +import pytest +import torch + +from neural_compressor.torch.quantization import SmoothQuantConfig, get_default_sq_config, quantize +from neural_compressor.torch.utils import is_ipex_available + +if is_ipex_available(): + import intel_extension_for_pytorch as ipex + + +class Model(torch.nn.Module): + device = torch.device("cpu") + + def __init__(self): + super(Model, self).__init__() + self.fc1 = torch.nn.Linear(3, 4) + self.fc2 = torch.nn.Linear(4, 3) + + def forward(self, x): + out = self.fc1(x) + out = self.fc2(out) + return out + + +model = Model() + + +def run_fn(model): + model(torch.randn([1, 3])) + + +class TestSmoothQuant: + @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + def test_smooth_quant_default(self): + fp32_model = copy.deepcopy(model) + quant_config = get_default_sq_config() + example_inputs = torch.randn([1, 3]) + q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) + assert q_model is not None, "Quantization failed!" + + @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + def test_smooth_quant_auto(self): + fp32_model = copy.deepcopy(model) + auto_alpha_args = { + "alpha_min": 0.45, + "alpha_max": 0.55, + "alpha_step": 0.01, + "shared_criterion": "mean", + "do_blockwise": True, + } + quant_config = SmoothQuantConfig(alpha="auto", auto_alpha_args=auto_alpha_args, folding=False) + example_inputs = torch.randn([1, 3]) + q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) + assert q_model is not None, "Quantization failed!" + + @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + @pytest.mark.parametrize( + "act_sym, act_algo, alpha, folding, scale_sharing", + [ + (True, "kl", 0.1, True, True), + (True, "minmax", 0.1, False, False), + (False, "kl", 0.5, True, False), + (False, "minmax", 0.5, False, True), + (True, "minmax", 0.1, False, True), + (False, "kl", 0.5, True, False), + ], + ) + def test_sq_linear_params(self, act_sym, act_algo, alpha, folding, scale_sharing): + fp32_model = copy.deepcopy(model) + quant_config = SmoothQuantConfig( + act_sym=act_sym, act_algo=act_algo, alpha=alpha, folding=folding, scale_sharing=scale_sharing + ) + example_inputs = torch.zeros([1, 3]) + + def run_fn(model): + model(example_inputs) + + q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) + assert q_model is not None, "Quantization failed!" + output1 = fp32_model(example_inputs) + output2 = q_model(example_inputs) + assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check." 
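Reviewer note on usage: the flow exercised by test_smooth_quant.py condenses to the sketch below. It relies only on symbols introduced or re-exported by this patch (SmoothQuantConfig, get_default_sq_config, quantize) and assumes intel_extension_for_pytorch is installed (it is added to requirements_pt.txt here); the two-layer toy model and the random calibration batches are illustrative stand-ins, not part of the patch.

    import torch
    from neural_compressor.torch.quantization import SmoothQuantConfig, get_default_sq_config, quantize

    # Illustrative model; any torch.nn.Module with quantizable linear layers follows the same flow.
    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(3, 4)
            self.fc2 = torch.nn.Linear(4, 3)

        def forward(self, x):
            return self.fc2(self.fc1(x))

    example_inputs = torch.randn(1, 3)

    def calib_fn(model):
        # Calibration callback: run a few representative batches through the model.
        for _ in range(4):
            model(torch.randn(1, 3))

    # Default fixed-alpha SmoothQuant. The smooth quant entry defaults inplace to True,
    # so pass a fresh or deep-copied model per experiment (the tests use copy.deepcopy).
    q_model = quantize(ToyModel(), quant_config=get_default_sq_config(),
                       run_fn=calib_fn, example_inputs=example_inputs)

    # Auto alpha tuning, mirroring test_smooth_quant_auto.
    auto_cfg = SmoothQuantConfig(
        alpha="auto",
        folding=False,
        auto_alpha_args={
            "alpha_min": 0.45,
            "alpha_max": 0.55,
            "alpha_step": 0.01,
            "shared_criterion": "mean",
            "do_blockwise": True,
        },
    )
    q_model_auto = quantize(ToyModel(), quant_config=auto_cfg,
                            run_fn=calib_fn, example_inputs=example_inputs)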
diff --git a/test/3x/torch/quantization/test_static_quant.py b/test/3x/torch/quantization/test_static_quant.py index 3223e65f9f5..4d739041694 100644 --- a/test/3x/torch/quantization/test_static_quant.py +++ b/test/3x/torch/quantization/test_static_quant.py @@ -1,11 +1,13 @@ import copy -import intel_extension_for_pytorch as ipex import pytest import torch from neural_compressor.torch.quantization import StaticQuantConfig, get_default_static_config, quantize -from neural_compressor.torch.utils import get_model_info, is_ipex_available, logger +from neural_compressor.torch.utils import is_ipex_available + +if is_ipex_available(): + import intel_extension_for_pytorch as ipex def build_simple_torch_model(): @@ -63,3 +65,30 @@ def test_static_quant_params(self, act_sym, act_algo): example_inputs = self.input q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) assert q_model is not None, "Quantization failed!" + + @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + def test_static_quant_accuracy(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2, False) + + def forward(self, x): + x = self.linear(x) + x = x + x + return x + + model = M() + + def run_fn(model): + model(torch.randn(3, 2)) + + fp32_model = copy.deepcopy(model) + fp32_model.linear.weight = torch.nn.Parameter(torch.tensor([[0.0, 1.0], [1.0, 0.0]])) + example_inputs = torch.zeros(3, 2) + quant_config = StaticQuantConfig(act_sym=True, act_algo="kl") + q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) + output1 = fp32_model(example_inputs) + output2 = q_model(example_inputs) + # set a big atol to avoid random issue + assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."
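
A closing note on an API change that is easy to miss in review: SmoothQuantConfig.get_model_info now takes example_inputs and delegates to get_quantizable_ops_recursively instead of filtering torch.nn.Linear modules by name, so any direct caller needs a trace input and an IPEX installation. A minimal sketch under those assumptions (the toy model is illustrative, and the direct call shown here is only for illustration; in normal use the tuning framework invokes this method):

    import torch
    from neural_compressor.torch.quantization import SmoothQuantConfig

    model = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.Linear(4, 3))
    example_inputs = torch.randn(1, 3)

    # Before this patch the 3.x method was get_model_info(model); it now requires
    # example_inputs because op discovery needs example inputs to trace the model.
    model_info = SmoothQuantConfig.get_model_info(model, example_inputs=example_inputs)
    # model_info pairs each quantizable op identifier with its unified op type.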