Skip to content

Commit

Permalink
add more ut to smooth_quant for 3.x API (#1657)
Browse files Browse the repository at this point in the history
Signed-off-by: Cheng, Zixuan <[email protected]>
  • Loading branch information
violetch24 authored Mar 11, 2024
1 parent c4de198 commit c214f90
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ def qdq_quantize(
if ipex_ver.release > Version("2.1.0").release:
update_sq_scale(ipex_config_path, smoothquant_scale_info)
model.load_qconf_summary(qconf_summary=ipex_config_path)
_ipex_post_quant_process(model, example_inputs, inplace=inplace)
model.save_qconf_summary(qconf_summary=ipex_config_path)
model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)

with open(ipex_config_path, "r") as f:
model.tune_cfg = json.load(f)
Expand Down
36 changes: 36 additions & 0 deletions test/3x/torch/quantization/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,39 @@ def run_fn(model):
output1 = fp32_model(example_inputs)
output2 = q_model(example_inputs)
assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."

@pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
def test_sq_ipex_save_load(self):
    """End-to-end SmoothQuant check for the 3.x API against a manual IPEX flow.

    Builds two quantized versions of the module-level ``model``:
      1. a reference path using IPEX's own prepare/convert/trace/freeze APIs, and
      2. the INC 3.x path via ``quantize(...)`` with the default SmoothQuant config,
    then asserts the INC output matches the IPEX reference and that the INC model
    survives a TorchScript save/load round trip unchanged.

    NOTE(review): relies on module-level ``model``, ``ipex``, ``quantize``,
    ``get_default_sq_config`` defined elsewhere in this test file.
    """
    from intel_extension_for_pytorch.quantization import convert, prepare

    # Single dummy batch used for calibration, tracing, and inference alike.
    example_inputs = torch.zeros([1, 3])

    # --- Reference path: quantize a copy manually with the IPEX SmoothQuant API.
    qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
    user_model = copy.deepcopy(model)
    user_model = prepare(user_model.eval(), qconfig, example_inputs=example_inputs, inplace=True)

    def run_fn(model):
        # One-batch calibration pass (also reused as the INC calibration hook below).
        model(example_inputs)

    run_fn(user_model)
    with torch.no_grad():
        user_model = convert(user_model.eval(), inplace=True).eval()
        user_model(example_inputs)
        user_model = torch.jit.trace(user_model.eval(), example_inputs, strict=False)
        user_model = torch.jit.freeze(user_model.eval())
        # Warm-up calls so the frozen TorchScript graph completes its JIT
        # optimization passes before the reference output is captured.
        user_model(example_inputs)
        user_model(example_inputs)
        ipex_out = user_model(example_inputs)

    # --- INC path: quantize a fresh FP32 copy through the 3.x entry point.
    fp32_model = copy.deepcopy(model)
    quant_config = get_default_sq_config()
    q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
    assert q_model is not None, "Quantization failed!"
    inc_out = q_model(example_inputs)
    q_model.save("saved")

    # load
    # Round-trip through TorchScript serialization and re-run inference.
    loaded_model = torch.jit.load("saved")
    loaded_out = loaded_model(example_inputs)
    # INC result must match the hand-rolled IPEX reference...
    assert torch.allclose(inc_out, ipex_out, atol=1e-05), "Unexpected result. Please double check."

    # ...and must be unchanged by the save/load round trip.
    assert torch.allclose(inc_out, loaded_out, atol=1e-05), "Unexpected result. Please double check."

0 comments on commit c214f90

Please sign in to comment.