fix 3x ipex static quant regression (#1864)
Description
Fix a 3.x IPEX static quant regression with two symptoms:
- fallback by op type name (e.g. "Linear") did not take effect
- op stats were dumped with the wrong op types (fused entries such as "Linear&ReLU" were missing)
---------

Signed-off-by: Cheng, Zixuan <[email protected]>
violetch24 authored Jun 14, 2024
1 parent 4e45f8f commit 70a1d50
Showing 3 changed files with 81 additions and 40 deletions.
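
For orientation, a minimal sketch of the user-facing path this commit repairs, assuming the 3.x torch API exercised in the test file below (the toy model, tensor shapes, and import path are taken from or modeled on that test; they are illustrative, not prescriptive):

    import torch
    from neural_compressor.torch.quantization import (
        StaticQuantConfig,
        convert,
        get_default_static_config,
        prepare,
    )

    # Fall back every Linear op to fp32 by op *type* name -- before this fix,
    # "Linear" could not be matched against IPEX's raw type string
    # "<class 'torch.nn.modules.linear.Linear'>", so the fallback was ignored.
    quant_config = get_default_static_config()
    quant_config.set_local("Linear", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))

    fp32_model = torch.nn.Sequential(torch.nn.Linear(30, 30), torch.nn.ReLU())
    example_inputs = torch.randn(8, 30)
    prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
    prepared_model(example_inputs)  # one calibration pass
    q_model = convert(prepared_model)  # op stats ("Mixed Precision Statistics") are dumped here
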
55 changes: 54 additions & 1 deletion neural_compressor/torch/algorithms/smooth_quant/utility.py
@@ -26,8 +26,8 @@

 from neural_compressor.torch.algorithms.static_quant import (
     CpuInfo,
+    Statistics,
     TransformerBasedModelBlockPatternDetector,
-    dump_model_op_stats,
     generate_activation_observer,
     get_quantizable_ops_from_cfgs,
     ipex_config_path,
@@ -251,6 +251,59 @@ def cfg_to_qconfig(
         return None


+def dump_model_op_stats(user_cfg):
+    """Dump the quantizable ops of the model to the user.
+    Args:
+        user_cfg (dict): quantization config
+    Returns:
+        None
+    """
+    res = dict()
+    for k, v in user_cfg.items():
+        op_type_list = k[-1].split("><")
+        op_type = ""
+        for op in op_type_list:
+            if "class" in op:
+                op_type = (
+                    op[op.rfind(".") + 1 : op.rfind("'")]
+                    if op_type == ""
+                    else op_type + "&" + op[op.rfind(".") + 1 : op.rfind("'")]
+                )
+            elif "method" in op:
+                start = op.find("'") + 1
+                if start > 1:
+                    op_type = (
+                        op[start : op.find("'", start)]
+                        if op_type == ""
+                        else op_type + "&" + op[start : op.find("'", start)]
+                    )
+                else:
+                    start = op.find("method") + 7
+                    op_type = (
+                        op[start : op.find(" ", start)]
+                        if op_type == ""
+                        else op_type + "&" + op[start : op.find(" ", start)]
+                    )
+            else:
+                op_type = op if op_type == "" else op_type + "&" + op
+        if op_type not in res.keys():
+            res[op_type] = {"INT8": 0, "BF16": 0, "FP32": 0}
+        if v["weight"]["dtype"] == "int8":
+            res[op_type]["INT8"] += 1
+        elif v["weight"]["dtype"] == "fp32":
+            res[op_type]["FP32"] += 1
+
+    output_data = [
+        [op_type, sum(res[op_type].values()), res[op_type]["INT8"], res[op_type]["BF16"], res[op_type]["FP32"]]
+        for op_type in res.keys()
+    ]
+
+    Statistics(
+        output_data, header="Mixed Precision Statistics", field_names=["Op Type", "Total", "INT8", "BF16", "FP32"]
+    ).print_stat()
+
+
 def get_parent(node, all_parents=False):  # pragma: no cover
     if node.inputs() is None:
         return None
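
The class-string parsing in the new smooth-quant dump_model_op_stats can be seen in isolation with a minimal sketch (the "><"-joined key below is a hypothetical example of the op-type strings the smooth-quant config carries; _parse_fused_op_type is an illustrative helper, not a function in the repo):

    def _parse_fused_op_type(raw):
        # Reduced re-implementation of the "class" branch above: keep the text
        # between the last "." and the closing quote of each class string, and
        # join the pieces with "&".
        parts = raw.split("><")
        names = [p[p.rfind(".") + 1 : p.rfind("'")] for p in parts if "class" in p]
        return "&".join(names)

    raw = "<class 'torch.nn.modules.linear.Linear'>><class 'torch.nn.modules.activation.ReLU'>"
    print(_parse_fused_op_type(raw))  # Linear&ReLU -- the fused op type the old code failed to report
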
48 changes: 15 additions & 33 deletions neural_compressor/torch/algorithms/static_quant/utility.py
@@ -43,13 +43,19 @@
     "<class 'torch.nn.modules.conv.Conv2d'>": "Conv2d",
     "<class 'torch.nn.modules.conv.Conv3d'>": "Conv3d",
     "<class 'torch.nn.modules.activation.ReLU'>": "ReLU",
     "<class 'torch.nn.modules.sparse.EmbeddingBag'>": "EmbeddingBag",
     "<method 'add' of 'torch._C._TensorBase' objects>": "add",  # for IPEX < 2.2
     "<method 'add' of 'torch._C.TensorBase' objects>": "add",  # for IPEX >= 2.2
     "<class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>": "AdaptiveAvgPool2d",
+    "Linear_Relu": "Linear",
+    "Linear_add": "Linear",
     "<class 'torch.nn.modules.linear.Linear'>": "Linear",
     "<class 'torch.nn.modules.pooling.MaxPool2d'>": "MaxPool2d",
-    "re": {"<built-in method matmul of type object at": "matmul"},
+    "re": {
+        "<built-in method matmul of type object at": "matmul",
+        "<built-in method add of type object at": "add",
+        "<built-in method bmm of type object at": "bmm",
+    },
 }

 BLOCK_PATTERNS = [
@@ -85,6 +91,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
     Returns:
         cfgs (dict): updated configs.
     """
+    ori_user_cfg = copy.deepcopy(user_cfg)
     tmp_user_cfg = OrderedDict()
     for op in user_cfg:  # map ipex op_name to pt op_name
         for i, op_name in enumerate(op):
@@ -94,9 +101,9 @@
                     ori_op = (tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]])
                     tmp_user_cfg[((ori_op[0],), ori_op[1])] = user_cfg[op]
                     break
-    user_cfg = tmp_user_cfg
-    for op_name in user_cfg:
-        inc_op_cfg = user_cfg[op_name]
+
+    for op_name in tmp_user_cfg:
+        inc_op_cfg = tmp_user_cfg[op_name]
         for i, name in enumerate(op_name[0]):
             # to int8
             ipex_op_cfg = op_infos_from_cfgs[name]
@@ -154,7 +161,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
             else:
                 pass
             cfgs[name[0]][name[1]][name[2]] = ipex_op_cfg
-    return cfgs, user_cfg
+    return cfgs, ori_user_cfg


 def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False):  # pragma: no cover
@@ -333,8 +340,8 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
         elif "method" in ipex_op_type:  # "<method 'add' of 'torch._C._TensorBase' objects>"
             method = ipex_op_type.split("'")[1]
             op_name_info.append((module_fqn, method))
-        elif "Convolution" in ipex_op_type:  # "Convolution_Relu"
-            op_name_info.append((module_fqn, "Conv2d"))
+        elif "_" in ipex_op_type:  # "Convolution_Relu", "Linear_Relu"
+            op_name_info.append((module_fqn, ipex_op_type.split("_")[0]))
         else:
             re_flag = False
             for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
@@ -394,32 +401,7 @@ def dump_model_op_stats(user_cfg):
     """
     res = dict()
     for k, v in user_cfg.items():
-        op_type_list = k[-1].split("><")
-        op_type = ""
-        for op in op_type_list:
-            if "class" in op:
-                op_type = (
-                    op[op.rfind(".") + 1 : op.rfind("'")]
-                    if op_type == ""
-                    else op_type + "&" + op[op.rfind(".") + 1 : op.rfind("'")]
-                )
-            elif "method" in op:
-                start = op.find("'") + 1
-                if start > 1:
-                    op_type = (
-                        op[start : op.find("'", start)]
-                        if op_type == ""
-                        else op_type + "&" + op[start : op.find("'", start)]
-                    )
-                else:
-                    start = op.find("method") + 7
-                    op_type = (
-                        op[start : op.find(" ", start)]
-                        if op_type == ""
-                        else op_type + "&" + op[start : op.find(" ", start)]
-                    )
-            else:
-                op_type = op if op_type == "" else op_type + "&" + op
+        op_type = k[1]
         if op_type not in res.keys():
             res[op_type] = {"INT8": 0, "BF16": 0, "FP32": 0}
         if v["weight"]["dtype"] == "int8":
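
Two behaviors above are worth spelling out. First, fused patterns reported by IPEX (e.g. "Linear_Relu", "Linear_add") now reduce to their base op type, so fallback by op type name can match them; a minimal standalone sketch of that branch (base_op_type is an illustrative name, not a repo function):

    def base_op_type(ipex_op_type):
        # Mirrors the new elif branch in get_quantizable_ops_recursively:
        # a fused "<Base>_<Post>" pattern reduces to its base op type.
        if "_" in ipex_op_type:
            return ipex_op_type.split("_")[0]
        return ipex_op_type

    print(base_op_type("Linear_Relu"))       # Linear
    print(base_op_type("Convolution_Relu"))  # Convolution (no longer hard-coded to Conv2d)

Second, the static-quant dump_model_op_stats no longer re-parses raw class strings: check_cfg_and_qconfig now returns the untouched user config, whose keys appear to already carry the unified op type as their second element, so the dump reads it directly with op_type = k[1].
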
18 changes: 12 additions & 6 deletions test/3x/torch/quantization/test_static_quant.py
@@ -22,13 +22,18 @@ class Model(torch.nn.Module):
     def __init__(self):
         super(Model, self).__init__()
         self.fc1 = torch.nn.Linear(30, 50)
-        self.fc2 = torch.nn.Linear(50, 30)
-        self.fc3 = torch.nn.Linear(30, 5)
+        self.fc2 = torch.nn.Linear(50, 50)
+        self.fc3 = torch.nn.Linear(50, 30)
+        self.fc4 = torch.nn.Linear(30, 5)
+        self.relu = torch.nn.ReLU()

     def forward(self, x):
         out = self.fc1(x)
         out = self.fc2(out)
+        out = self.relu(out)
         out = self.fc3(out)
+        out = out + x
+        out = self.fc4(out)
         return out

 model = Model()
@@ -78,21 +83,22 @@ def test_static_quant_fallback(self):
         assert q_model is not None, "Quantization failed!"

         for op, op_info in q_model.tune_cfg[" "]["q_op_infos"].items():
-            if op_info["op_type"] == "<class 'torch.nn.modules.linear.Linear'>":
+            if op_info["op_type"] == "Linear":
                 dtype = q_model.tune_cfg[" "]["q_op_infos"][op]["input_tensor_infos"][0]["force_dtype"]
                 assert dtype == "torch.float32", "Failed to fallback linear op, please check!"

         # fallback by op_name
-        quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        quant_config = get_default_static_config()
+        quant_config.set_local("fc2", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
         q_model = convert(prepared_model)
         assert q_model is not None, "Quantization failed!"

         for op, op_info in q_model.tune_cfg[" "]["q_op_infos"].items():
-            if op_info["fqn"] == "fc1":
+            if op_info["fqn"] == "fc2":
                 dtype = q_model.tune_cfg[" "]["q_op_infos"][op]["input_tensor_infos"][0]["force_dtype"]
-                assert dtype == "torch.float32", "Failed to fallback fc1 layer, please check!"
+                assert dtype == "torch.float32", "Failed to fallback fc2 layer, please check!"

     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
     @pytest.mark.parametrize(
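
The test model rework presumably exists to exercise both fused patterns touched by the fix: fc2 followed by relu should fuse into "Linear_Relu", and fc3 followed by the residual add (out = out + x) into "Linear_add" under IPEX. A hedged sketch of inspecting the result the way the test's assertions do (tune_cfg layout taken from the test above; q_model as produced in the first sketch near the top of this page):

    # Every Linear op, including the fused ones, should carry fp32 forced inputs.
    for op, op_info in q_model.tune_cfg[" "]["q_op_infos"].items():
        if op_info["op_type"] == "Linear":
            force_dtype = op_info["input_tensor_infos"][0]["force_dtype"]
            assert force_dtype == "torch.float32"
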
