From 53e6ee6b75d476bae0382c7d6fb9aa1348c2ab5e Mon Sep 17 00:00:00 2001
From: Zixuan Cheng <110808245+violetch24@users.noreply.github.com>
Date: Wed, 17 Jul 2024 20:35:03 +0800
Subject: [PATCH] Support xpu for ipex static quant (#1916)

Signed-off-by: violetch24
---
 .../algorithms/static_quant/save_load.py     |  13 ++-
 .../algorithms/static_quant/static_quant.py  | 108 +++++++++++-------
 .../torch/algorithms/static_quant/utility.py |  41 +++++++
 .../torch/quantization/config.py             |  26 ++++-
 .../torch/quantization/test_static_quant.py  |  73 ++++++++++--
 5 files changed, 206 insertions(+), 55 deletions(-)

diff --git a/neural_compressor/torch/algorithms/static_quant/save_load.py b/neural_compressor/torch/algorithms/static_quant/save_load.py
index 557c1577728..9a7808c17eb 100644
--- a/neural_compressor/torch/algorithms/static_quant/save_load.py
+++ b/neural_compressor/torch/algorithms/static_quant/save_load.py
@@ -32,9 +32,16 @@ def save(model, output_dir="./saved_results"):
 
     qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
     qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
-    model.ori_save(qmodel_file_path)
-    with open(qconfig_file_path, "w") as f:
-        json.dump(model.tune_cfg, f, indent=4)
+    device = next(model.parameters(), None).device.type if next(model.parameters(), None) is not None else "cpu"
+    if device == "cpu":
+        model.ori_save(qmodel_file_path)
+        with open(qconfig_file_path, "w") as f:
+            json.dump(model.tune_cfg, f, indent=4)
+    else:  # pragma: no cover
+        from neural_compressor.common.utils import save_config_mapping
+
+        torch.jit.save(model, qmodel_file_path)
+        save_config_mapping(model.qconfig, qconfig_file_path)
 
     logger.info("Save quantized model to {}.".format(qmodel_file_path))
     logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))
diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py
index efd1880666c..08dc5a1035f 100644
--- a/neural_compressor/torch/algorithms/static_quant/static_quant.py
+++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py
@@ -33,11 +33,13 @@
 
 from neural_compressor.torch.algorithms import Quantizer
 from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
 from .utility import (
     CpuInfo,
     cfg_to_qconfig,
     dump_model_op_stats,
+    generate_xpu_qconfig,
     get_ipex_version,
     get_quantizable_ops_recursively,
     ipex_config_path,
@@ -56,6 +58,7 @@ def __init__(self, quant_config: OrderedDict = {}):
         """
         super().__init__(quant_config)
         self.user_cfg = OrderedDict()
+        self.device = auto_detect_accelerator().current_device()
 
     def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         """Prepares a given model for quantization.
@@ -70,43 +73,61 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         """
         assert example_inputs is not None, "Please provide example_inputs for static quantization."
-        _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
-            model, example_inputs
-        )
-        # update json file in ipex_config_path; map ipex op_name to pt op_name
-        self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
-        model.eval()
+        if self.device == "cpu":
+            _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
+                model, example_inputs
+            )
+            # update json file in ipex_config_path; map ipex op_name to pt op_name
+            self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
+        else:  # pragma: no cover
+            model = model.to("xpu")
 
-        use_bf16 = self.quant_config.get("use_bf16", None)
+        model.eval()
 
         # Check save_qconf_summary part is a workaround for IPEX bug.
-        # Sometimes the prepared model from get_op_capablitiy loss this attribute
-        if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
-            from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
-
-            if ipex_ver.release >= Version("2.1").release:
-                # HistogramObserver will cause a performance issue.
-                # static_qconfig = ipex.quantization.default_static_qconfig_mapping
-                qconfig = QConfig(
-                    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
-                    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
-                )
-                from torch.ao.quantization import QConfigMapping
-
-                static_qconfig = QConfigMapping().set_global(qconfig)
-            else:
-                static_qconfig = QConfig(
-                    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
-                    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
-                )
-            if isinstance(example_inputs, dict):
-                model = ipex.quantization.prepare(
-                    model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
-                )
+        # Sometimes the prepared model from get_op_capability loses these attributes
+        if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):  # pragma: no cover
+            from torch.ao.quantization import HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, QConfig
+
+            if self.device != "cpu":  # pragma: no cover
+                from torch.quantization.quantize_jit import prepare_jit
+
+                with torch.no_grad():
+                    modelJit = torch.jit.trace(model, example_inputs)
+                qconfig = generate_xpu_qconfig(self.quant_config)
+                model = prepare_jit(modelJit, qconfig, inplace)
             else:
-                model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)
+                if ipex_ver.release >= Version("2.1").release:
+                    # HistogramObserver will cause a performance issue.
+                    # static_qconfig = ipex.quantization.default_static_qconfig_mapping
+                    qconfig = QConfig(
+                        activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+                        weight=PerChannelMinMaxObserver.with_args(
+                            dtype=torch.qint8, qscheme=torch.per_channel_symmetric
+                        ),
+                    )
+                    from torch.ao.quantization import QConfigMapping
+
+                    static_qconfig = QConfigMapping().set_global(qconfig)
+                else:  # pragma: no cover
+                    static_qconfig = QConfig(
+                        activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+                        weight=PerChannelMinMaxObserver.with_args(
+                            dtype=torch.qint8, qscheme=torch.per_channel_symmetric
+                        ),
+                    )
+                if isinstance(example_inputs, dict):
+                    model = ipex.quantization.prepare(
+                        model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
+                    )
+                else:
+                    model = ipex.quantization.prepare(
+                        model, static_qconfig, example_inputs=example_inputs, inplace=inplace
+                    )
+
+        if self.device == "cpu":
+            model.load_qconf_summary(qconf_summary=ipex_config_path)
 
-        model.load_qconf_summary(qconf_summary=ipex_config_path)
         return model
 
     def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
@@ -124,18 +145,27 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
 
         from neural_compressor.torch.algorithms.static_quant import save
 
-        model.save_qconf_summary(qconf_summary=ipex_config_path)
-        model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)
+        if self.device != "cpu":  # pragma: no cover
+            from torch.quantization.quantize_jit import convert_jit
 
-        with open(ipex_config_path, "r") as f:
-            model.tune_cfg = json.load(f)
-        model.ipex_config_path = ipex_config_path
+            model = convert_jit(model, inplace)
+            simple_inference(model, example_inputs, iterations=2)
+            model.qconfig = self.quant_config["op"]
+            dump_model_op_stats(model.qconfig)
+        else:
+            model.save_qconf_summary(qconf_summary=ipex_config_path)
+            model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)
 
-        dump_model_op_stats(self.user_cfg)
+            with open(ipex_config_path, "r") as f:
+                model.tune_cfg = json.load(f)
+            model.ipex_config_path = ipex_config_path
+
+            dump_model_op_stats(self.user_cfg)
 
-        logger.info("Static quantization done.")
         model.ori_save = model.save
         model.save = MethodType(save, model)
+
+        logger.info("Static quantization done.")
         return model
diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py
index 23ac16630a4..f4930a22ddd 100644
--- a/neural_compressor/torch/algorithms/static_quant/utility.py
+++ b/neural_compressor/torch/algorithms/static_quant/utility.py
@@ -163,6 +163,47 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
     return cfgs, ori_user_cfg
 
 
+def generate_xpu_qconfig(tune_cfg):  # pragma: no cover
+    # qconfig observer & config constants for ipex-xpu
+    from torch.ao.quantization import HistogramObserver, MinMaxObserver, QConfig
+
+    act_observer_minmax_asym = MinMaxObserver.with_args(quant_min=0, quant_max=127)
+    act_observer_minmax_sym = MinMaxObserver.with_args(
+        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127
+    )
+    act_observer_kl_asym = HistogramObserver.with_args(quant_min=0, quant_max=127)
+    act_observer_kl_sym = HistogramObserver.with_args(
+        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127
+    )
+    # no tuning for granularity due to tuning space
+    weight_observer_minmax_sym = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
+
+    qconfig = {}
+    user_cfg = copy.deepcopy(tune_cfg["op"])
+    for _, cfg in user_cfg.items():
+        act_algo = cfg["activation"]["algorithm"]
+        act_sym = cfg["activation"]["scheme"]
+        break
+
+    if act_algo == "minmax":
+        if act_sym == "sym":
+            activation = act_observer_minmax_sym
+        else:
+            activation = act_observer_minmax_asym
+    else:
+        if act_sym == "sym":
+            activation = act_observer_kl_sym
+        else:
+            activation = act_observer_kl_asym
+
+    qconfig[""] = QConfig(activation=activation, weight=weight_observer_minmax_sym)
+
+    for (op_name, op_type), cfg in user_cfg.items():
+        if cfg["weight"]["dtype"] == "fp32":
+            qconfig[op_name] = None
+    return qconfig
+
+
 def generate_activation_observer(
     scheme, algorithm, smooth_quant=False, smooth_quant_enable=False, alpha=0.5
 ):  # pragma: no cover
diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index 49238ec2ee5..2c43f1e59c1 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -1095,6 +1095,7 @@ def __init__(
         act_algo: str = "minmax",
         excluded_precisions: list = [],
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+        model_info: Optional[List[Tuple[str, Callable]]] = None,
     ):
         """Init Static Quant Configs."""
         super().__init__(white_list=white_list)
@@ -1107,6 +1108,7 @@ def __init__(
         self.act_granularity = act_granularity
         self.act_algo = act_algo
         self.excluded_precisions = excluded_precisions
+        self.model_info = model_info
         self._post_init()
 
     @classmethod
@@ -1124,10 +1126,28 @@ def get_model_info_for_ipex(model: torch.nn.Module, example_inputs) -> List[Tupl
         _, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
         return model_info
 
-    @staticmethod
-    def get_model_info(model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]:
+    def get_model_info_for_ipex_xpu(self, model: torch.nn.Module) -> List[Tuple[str, Callable]]:  # pragma: no cover
+        if self.model_info:
+            return self.model_info
+        else:
+            white_list = torch.quantization.quantization_mappings.get_default_qconfig_propagation_list()
+            filter_result = []
+            for op_name, module in model.named_modules():
+                if type(module) in white_list:
+                    pair = (op_name, type(module).__name__)
+                    filter_result.append(pair)
+            logger.debug(f"Get model info: {filter_result}")
+            self.model_info = filter_result
+            return filter_result
+
+    def get_model_info(self, model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]:
+        from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
+
         if is_ipex_imported():
-            return StaticQuantConfig.get_model_info_for_ipex(model, example_inputs)
+            if auto_detect_accelerator().current_device() == "cpu":
+                return StaticQuantConfig.get_model_info_for_ipex(model, example_inputs)
+            else:
+                return StaticQuantConfig.get_model_info_for_ipex_xpu(self, model)
 
     def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
diff --git a/test/3x/torch/quantization/test_static_quant.py b/test/3x/torch/quantization/test_static_quant.py
index cae22ff79ab..4aecd29eecf 100644
--- a/test/3x/torch/quantization/test_static_quant.py
+++ b/test/3x/torch/quantization/test_static_quant.py
@@ -4,6 +4,14 @@
 import pytest
 import torch
 
+try:
+    import intel_extension_for_pytorch as ipex
+
+    is_ipex_available = True
+except:  # pragma: no cover
+    is_ipex_available = False
+    assert False, "Please install IPEX for static quantization."
assert False, "Please install IPEX for static quantization." + from neural_compressor.torch.quantization import ( StaticQuantConfig, convert, @@ -11,10 +19,9 @@ prepare, quantize, ) -from neural_compressor.torch.utils import is_ipex_available +from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator -if is_ipex_available(): - import intel_extension_for_pytorch as ipex +device = auto_detect_accelerator().current_device() def build_simple_torch_model(): @@ -53,7 +60,7 @@ def setup_class(self): def teardown_class(self): shutil.rmtree("saved_results", ignore_errors=True) - @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + @pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device") def test_static_quant_default(self): fp32_model = copy.deepcopy(self.fp32_model) quant_config = get_default_static_config() @@ -70,7 +77,7 @@ def test_static_quant_default(self): q_model = convert(prepared_model) assert q_model is not None, "Quantization failed!" - @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + @pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device") def test_static_quant_fallback(self): fp32_model = copy.deepcopy(self.fp32_model) quant_config = get_default_static_config() @@ -100,7 +107,7 @@ def test_static_quant_fallback(self): dtype = q_model.tune_cfg[" "]["q_op_infos"][op]["input_tensor_infos"][0]["force_dtype"] assert dtype == "torch.float32", "Failed to fallback fc2 layer, please check!" - @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + @pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device") @pytest.mark.parametrize( "act_sym, act_algo", [ @@ -119,7 +126,7 @@ def test_static_quant_params(self, act_sym, act_algo): q_model = convert(prepared_model) assert q_model is not None, "Quantization failed!" - @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + @pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device") def test_static_quant_accuracy(self): class M(torch.nn.Module): def __init__(self): @@ -148,7 +155,7 @@ def run_fn(model): # set a big atol to avoid random issue assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check." - @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + @pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device") def test_static_quant_save_load(self): from intel_extension_for_pytorch.quantization import convert as ipex_convert from intel_extension_for_pytorch.quantization import prepare as ipex_prepare @@ -196,7 +203,7 @@ def run_fn(model): loaded_model = load("saved_results") assert isinstance(loaded_model, torch.jit.ScriptModule) - @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + @pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device") def test_static_quant_with_quantize_API(self): # quantize API fp32_model = copy.deepcopy(self.fp32_model) @@ -205,7 +212,7 @@ def test_static_quant_with_quantize_API(self): q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) assert q_model is not None, "Quantization failed!" 
-    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
+    @pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device")
     def test_static_quant_mixed_precision(self):
         fp32_model = copy.deepcopy(self.fp32_model)
         example_inputs = self.input
@@ -227,3 +234,49 @@ def test_static_quant_mixed_precision(self):
         run_fn(prepared_model)
         q_model = convert(prepared_model)
         assert q_model is not None, "Quantization failed!"
+
+    @pytest.mark.skipif(not is_ipex_available or device == "cpu", reason="Requires IPEX on XPU device")
+    @pytest.mark.parametrize(
+        "act_sym, act_algo",
+        [
+            (True, "kl"),
+            (True, "minmax"),
+            (False, "kl"),
+            (False, "minmax"),
+        ],
+    )
+    def test_static_quant_xpu(self, act_sym, act_algo):
+        import torchvision.models as models
+
+        model = models.resnet50(pretrained=True)
+        fp32_model = copy.deepcopy(model)
+        data = torch.rand(1, 3, 224, 224)
+        example_inputs = data.to("xpu")
+
+        def run_fn(model):
+            model(example_inputs)
+
+        quant_config = StaticQuantConfig(act_sym=act_sym, act_algo=act_algo, excluded_precisions=["bf16"])
+        # fallback by op_name
+        quant_config.set_local("conv1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
+        run_fn(prepared_model)
+        q_model = convert(prepared_model)
+        run_fn(q_model)
+        assert q_model is not None, "Quantization failed!"
+
+        quant_config = StaticQuantConfig(act_sym=act_sym, act_algo=act_algo, excluded_precisions=["bf16"])
+        # fallback by op_type
+        quant_config.set_local("Conv2d", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
+        run_fn(prepared_model)
+        q_model = convert(prepared_model)
+        run_fn(q_model)
+        assert q_model is not None, "Quantization failed!"
+
+        q_model.save("saved_results")
+        from neural_compressor.torch.quantization import load
+
+        # load
+        loaded_model = load("saved_results")
+        assert isinstance(loaded_model, torch.jit.ScriptModule), "Loading failed!"
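
Usage sketch (editorial note, not part of the patch): the snippet below shows how the XPU static-quant path added above is expected to be driven from user code, mirroring the new test_static_quant_xpu test. The ResNet-50 model, the calibration call, and the "xpu" device string are illustrative assumptions taken from that test, and an XPU-enabled intel_extension_for_pytorch build is required.

    # Hypothetical end-to-end use of the XPU path added in this patch (names mirror the test above).
    import torch
    import torchvision.models as models
    import intel_extension_for_pytorch as ipex  # noqa: F401  # provides the XPU backend used below

    from neural_compressor.torch.quantization import StaticQuantConfig, convert, prepare

    model = models.resnet50(pretrained=True)
    example_inputs = torch.rand(1, 3, 224, 224).to("xpu")

    # On a non-CPU device, prepare() traces the model with torch.jit.trace and applies prepare_jit;
    # convert() then calls convert_jit, so the quantized result is a TorchScript module.
    quant_config = StaticQuantConfig(act_sym=True, act_algo="minmax", excluded_precisions=["bf16"])
    prepared_model = prepare(model, quant_config=quant_config, example_inputs=example_inputs)
    prepared_model(example_inputs)  # calibration pass
    q_model = convert(prepared_model)

    q_model.save("saved_results")  # on XPU this goes through torch.jit.save (see save_load.py above)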