Support xpu for ipex static quant #1916

Merged
14 commits merged on Jul 17, 2024
14 changes: 11 additions & 3 deletions neural_compressor/torch/algorithms/static_quant/save_load.py
@@ -32,9 +32,17 @@ def save(model, output_dir="./saved_results"):

qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
model.ori_save(qmodel_file_path)
with open(qconfig_file_path, "w") as f:
json.dump(model.tune_cfg, f, indent=4)
if next(model.parameters()).device.type == "cpu":
model.ori_save(qmodel_file_path)
with open(qconfig_file_path, "w") as f:
json.dump(model.tune_cfg, f, indent=4)
else:
from neural_compressor.common.utils import save_config_mapping

save_config_mapping(model.qconfig, qconfig_file_path)
# MethodType 'save' is not part of the state_dict
del model.save
torch.save(model.state_dict(), qmodel_file_path)

logger.info("Save quantized model to {}.".format(qmodel_file_path))
logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))
109 changes: 69 additions & 40 deletions neural_compressor/torch/algorithms/static_quant/static_quant.py
@@ -33,11 +33,13 @@

from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import logger
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .utility import (
CpuInfo,
cfg_to_qconfig,
dump_model_op_stats,
generate_xpu_qconfig,
get_ipex_version,
get_quantizable_ops_recursively,
ipex_config_path,
@@ -56,6 +58,7 @@ def __init__(self, quant_config: OrderedDict = {}):
"""
super().__init__(quant_config)
self.user_cfg = OrderedDict()
self.device = auto_detect_accelerator().current_device()

def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
"""Prepares a given model for quantization.
@@ -70,43 +73,61 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
"""
assert example_inputs is not None, "Please provide example_inputs for static quantization."

_, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
model, example_inputs
)
# update json file in ipex_config_path; map ipex op_name to pt op_name
self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
model.eval()
if self.device == "cpu":
_, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
model, example_inputs
)
# update json file in ipex_config_path; map ipex op_name to pt op_name
self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
else:
model = model.to("xpu")

use_bf16 = self.quant_config.get("use_bf16", None)
model.eval()

# Checking save_qconf_summary is a workaround for an IPEX bug.
# Sometimes the prepared model returned by get_op_capability is missing this attribute
if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

if ipex_ver.release >= Version("2.1").release:
# HistogramObserver will cause a performance issue.
# static_qconfig = ipex.quantization.default_static_qconfig_mapping
qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)
from torch.ao.quantization import QConfigMapping

static_qconfig = QConfigMapping().set_global(qconfig)
else:
static_qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)
if isinstance(example_inputs, dict):
model = ipex.quantization.prepare(
model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
)
# Sometimes the prepared model returned by get_op_capability is missing these attributes
if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): # pragma: no cover
from torch.ao.quantization import HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, QConfig

if self.device != "cpu": # pragma: no cover
from torch.quantization.quantize_jit import prepare_jit

with torch.no_grad():
modelJit = torch.jit.trace(model, example_inputs)
qconfig = generate_xpu_qconfig(self.quant_config)
model = prepare_jit(modelJit, qconfig, inplace)
else:
model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)
if ipex_ver.release >= Version("2.1").release:
# HistogramObserver will cause a performance issue.
# static_qconfig = ipex.quantization.default_static_qconfig_mapping
qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=PerChannelMinMaxObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_channel_symmetric
),
)
from torch.ao.quantization import QConfigMapping

static_qconfig = QConfigMapping().set_global(qconfig)
else:
static_qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=PerChannelMinMaxObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_channel_symmetric
),
)
if isinstance(example_inputs, dict):
model = ipex.quantization.prepare(
model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
)
else:
model = ipex.quantization.prepare(
model, static_qconfig, example_inputs=example_inputs, inplace=inplace
)

if self.device == "cpu":
model.load_qconf_summary(qconf_summary=ipex_config_path)

model.load_qconf_summary(qconf_summary=ipex_config_path)
return model
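For context, the XPU branch of prepare() above skips ipex.quantization.prepare entirely and goes through TorchScript instead. A condensed sketch of just that branch, using the same calls as the diff (the helper name is hypothetical; error handling and the CPU path are omitted):

import torch
from torch.quantization.quantize_jit import prepare_jit

from neural_compressor.torch.algorithms.static_quant.utility import generate_xpu_qconfig

def prepare_on_xpu(model, example_inputs, quant_config, inplace=True):
    # Trace the eval-mode model on the XPU device, then attach observers with the
    # qconfig dict built by generate_xpu_qconfig (added in utility.py below).
    model = model.to("xpu").eval()
    with torch.no_grad():
        model_jit = torch.jit.trace(model, example_inputs)
    return prepare_jit(model_jit, generate_xpu_qconfig(quant_config), inplace)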

def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
@@ -124,18 +145,26 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):

from neural_compressor.torch.algorithms.static_quant import save

model.save_qconf_summary(qconf_summary=ipex_config_path)
model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)
if self.device != "cpu": # pragma: no cover
from torch.quantization.quantize_jit import convert_jit

model = convert_jit(model, inplace)
simple_inference(model, example_inputs, iterations=2)
dump_model_op_stats(self.quant_config["op"])
else:
model.save_qconf_summary(qconf_summary=ipex_config_path)
model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)

with open(ipex_config_path, "r") as f:
model.tune_cfg = json.load(f)
model.ipex_config_path = ipex_config_path

with open(ipex_config_path, "r") as f:
model.tune_cfg = json.load(f)
model.ipex_config_path = ipex_config_path
dump_model_op_stats(self.user_cfg)

dump_model_op_stats(self.user_cfg)
model.ori_save = model.save
model.save = MethodType(save, model)

logger.info("Static quantization done.")
model.ori_save = model.save
model.save = MethodType(save, model)
return model
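Taken together, the intended user-facing flow on XPU is roughly the following. This is a sketch, assuming example_inputs already live on the XPU device and using the same public prepare/convert APIs exercised by the new test below; the config values and the calibration loop are illustrative.

import copy

from neural_compressor.torch.quantization import StaticQuantConfig, convert, prepare

def quantize_on_xpu(fp32_model, example_inputs):
    quant_config = StaticQuantConfig(act_sym=True, act_algo="minmax", excluded_precisions=["bf16"])
    prepared = prepare(copy.deepcopy(fp32_model), quant_config=quant_config, example_inputs=example_inputs)
    prepared(example_inputs)     # calibration pass
    q_model = convert(prepared)  # convert_jit, then two warm-up inferences via simple_inference
    return q_model               # the jitted, converted model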


41 changes: 41 additions & 0 deletions neural_compressor/torch/algorithms/static_quant/utility.py
@@ -163,6 +163,47 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
return cfgs, ori_user_cfg


def generate_xpu_qconfig(tune_cfg):
# qconfig observer & config constants for ipex-xpu
from torch.ao.quantization import HistogramObserver, MinMaxObserver, QConfig

act_observer_minmax_asym = MinMaxObserver.with_args(quant_min=0, quant_max=127)
act_observer_minmax_sym = MinMaxObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127
)
act_observer_kl_asym = HistogramObserver.with_args(quant_min=0, quant_max=127)
act_observer_kl_sym = HistogramObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127
)
# no tuning for granularity due to tuning space
weight_observer_minmax_sym = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)

qconfig = {}
user_cfg = copy.deepcopy(tune_cfg["op"])
for _, cfg in user_cfg.items():
act_algo = cfg["activation"]["algorithm"]
act_sym = cfg["activation"]["scheme"]
break

if act_algo == "minmax":
if act_sym == "sym":
activation = act_observer_minmax_sym
else:
activation = act_observer_minmax_asym
else:
if act_sym == "sym":
activation = act_observer_kl_sym
else:
activation = act_observer_kl_asym

qconfig[""] = QConfig(activation=activation, weight=weight_observer_minmax_sym)

for (op_name, op_type), cfg in user_cfg.items():
if cfg["weight"]["dtype"] == "fp32":
qconfig[op_name] = None
return qconfig
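The dict returned above follows the layout expected by prepare_jit: the empty-string key holds the global QConfig (chosen from the first op's activation algorithm and scheme), and per-op keys set to None keep fp32 ops unquantized. A small self-contained illustration with a hypothetical tune_cfg (the op names, types, and dtypes are made up):

from neural_compressor.torch.algorithms.static_quant.utility import generate_xpu_qconfig

# Hypothetical minimal tune_cfg: one quantized Linear and one op left in fp32.
tune_cfg = {
    "op": {
        ("fc1", "Linear"): {
            "activation": {"algorithm": "minmax", "scheme": "sym"},
            "weight": {"dtype": "int8"},
        },
        ("conv1", "Conv2d"): {
            "activation": {"algorithm": "minmax", "scheme": "sym"},
            "weight": {"dtype": "fp32"},
        },
    }
}

qconfig_dict = generate_xpu_qconfig(tune_cfg)
# qconfig_dict[""]      -> global QConfig with symmetric MinMax observers
# qconfig_dict["conv1"] -> None, so prepare_jit leaves that op in fp32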


def generate_activation_observer(
scheme, algorithm, smooth_quant=False, smooth_quant_enable=False, alpha=0.5
): # pragma: no cover
26 changes: 23 additions & 3 deletions neural_compressor/torch/quantization/config.py
@@ -1043,6 +1043,7 @@ def __init__(
act_algo: str = "minmax",
excluded_precisions: list = [],
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
model_info: Optional[List[Tuple[str, Callable]]] = None,
):
"""Init Static Quant Configs."""
super().__init__(white_list=white_list)
@@ -1055,6 +1056,7 @@ def __init__(
self.act_granularity = act_granularity
self.act_algo = act_algo
self.excluded_precisions = excluded_precisions
self.model_info = model_info
self._post_init()

@classmethod
@@ -1072,10 +1074,28 @@ def get_model_info_for_ipex(model: torch.nn.Module, example_inputs) -> List[Tupl
_, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
return model_info

@staticmethod
def get_model_info(model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]:
def get_model_info_for_ipex_xpu(self, model: torch.nn.Module) -> List[Tuple[str, Callable]]:
if self.model_info:
return self.model_info
else:
white_list = torch.quantization.quantization_mappings.get_default_qconfig_propagation_list()
filter_result = []
for op_name, module in model.named_modules():
if type(module) in white_list:
pair = (op_name, type(module).__name__)
filter_result.append(pair)
logger.debug(f"Get model info: {filter_result}")
self.model_info = filter_result
return filter_result

def get_model_info(self, model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]:
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

if is_ipex_imported():
return StaticQuantConfig.get_model_info_for_ipex(model, example_inputs)
if auto_detect_accelerator().current_device() == "cpu":
return StaticQuantConfig.get_model_info_for_ipex(model, example_inputs)
else:
return StaticQuantConfig.get_model_info_for_ipex_xpu(self, model)
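For reference, the non-CPU branch boils down to filtering named_modules against torch's default qconfig propagation list and caching the result on the config. A standalone sketch with a toy model (the helper and class names are hypothetical; the torch API is the one used above):

import torch
from torch.quantization.quantization_mappings import get_default_qconfig_propagation_list

def xpu_model_info(model: torch.nn.Module):
    # Same filter as get_model_info_for_ipex_xpu: keep (name, type name) pairs for
    # module types that torch would propagate a qconfig to by default.
    white_list = get_default_qconfig_propagation_list()
    return [(name, type(m).__name__) for name, m in model.named_modules() if type(m) in white_list]

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)
        self.fc2 = torch.nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(self.fc1(x))

print(xpu_model_info(Toy()))  # e.g. [("fc1", "Linear"), ("fc2", "Linear")]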

def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
57 changes: 47 additions & 10 deletions test/3x/torch/quantization/test_static_quant.py
@@ -4,17 +4,24 @@
import pytest
import torch

try:
import intel_extension_for_pytorch as ipex

is_ipex_available = True
except: # pragma: no cover
is_ipex_available = False
assert False, "Please install IPEX for static quantization."

from neural_compressor.torch.quantization import (
StaticQuantConfig,
convert,
get_default_static_config,
prepare,
quantize,
)
from neural_compressor.torch.utils import is_ipex_available
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

if is_ipex_available():
import intel_extension_for_pytorch as ipex
device = auto_detect_accelerator().current_device()


def build_simple_torch_model():
@@ -53,7 +60,7 @@ def setup_class(self):
def teardown_class(self):
shutil.rmtree("saved_results", ignore_errors=True)

@pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
@pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device")
def test_static_quant_default(self):
fp32_model = copy.deepcopy(self.fp32_model)
quant_config = get_default_static_config()
@@ -70,7 +77,7 @@ def test_static_quant_default(self):
q_model = convert(prepared_model)
assert q_model is not None, "Quantization failed!"

@pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
@pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device")
def test_static_quant_fallback(self):
fp32_model = copy.deepcopy(self.fp32_model)
quant_config = get_default_static_config()
@@ -100,7 +107,7 @@ def test_static_quant_fallback(self):
dtype = q_model.tune_cfg[" "]["q_op_infos"][op]["input_tensor_infos"][0]["force_dtype"]
assert dtype == "torch.float32", "Failed to fallback fc2 layer, please check!"

@pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
@pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device")
@pytest.mark.parametrize(
"act_sym, act_algo",
[
@@ -119,7 +126,7 @@ def test_static_quant_params(self, act_sym, act_algo):
q_model = convert(prepared_model)
assert q_model is not None, "Quantization failed!"

@pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
@pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device")
def test_static_quant_accuracy(self):
class M(torch.nn.Module):
def __init__(self):
@@ -148,7 +155,7 @@ def run_fn(model):
# set a big atol to avoid random issue
assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."

@pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
@pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device")
def test_static_quant_save_load(self):
from intel_extension_for_pytorch.quantization import convert as ipex_convert
from intel_extension_for_pytorch.quantization import prepare as ipex_prepare
@@ -196,7 +203,7 @@ def run_fn(model):
loaded_model = load("saved_results")
assert isinstance(loaded_model, torch.jit.ScriptModule)

@pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
@pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device")
def test_static_quant_with_quantize_API(self):
# quantize API
fp32_model = copy.deepcopy(self.fp32_model)
@@ -205,7 +212,7 @@ def test_static_quant_with_quantize_API(self):
q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
assert q_model is not None, "Quantization failed!"

@pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
@pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device")
def test_static_quant_mixed_precision(self):
fp32_model = copy.deepcopy(self.fp32_model)
example_inputs = self.input
@@ -227,3 +234,33 @@ def test_static_quant_mixed_precision(self):
run_fn(prepared_model)
q_model = convert(prepared_model)
assert q_model is not None, "Quantization failed!"

@pytest.mark.skipif(not is_ipex_available or device == "cpu", reason="Requires IPEX on XPU device")
@pytest.mark.parametrize(
"act_sym, act_algo",
[
(True, "kl"),
(True, "minmax"),
(False, "kl"),
(False, "minmax"),
],
)
def test_static_quant_params_xpu(self, act_sym, act_algo):
import torchvision.models as models

model = models.resnet50(pretrained=True)
fp32_model = copy.deepcopy(model)
data = torch.rand(1, 3, 224, 224)
example_inputs = data.to("xpu")

def run_fn(model):
model(example_inputs)

quant_config = StaticQuantConfig(act_sym=act_sym, act_algo=act_algo, excluded_precisions=["bf16"])
# fallback by op_name
quant_config.set_local("conv1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(prepared_model)
q_model = convert(prepared_model)
run_fn(q_model)
assert q_model is not None, "Quantization failed!"