
Commit

Add save&load API to SmoothQuant ipex model (#1673)
Signed-off-by: Cheng, Zixuan <[email protected]>
violetch24 authored Mar 19, 2024
1 parent e81a2dd commit 9c6102b
Showing 7 changed files with 207 additions and 7 deletions.
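In short, this commit gives SmoothQuant and static-quant IPEX models produced by the 3.x quantize() API a save/load pair, plus a recover_model_from_json helper for SmoothQuant. A minimal sketch of the intended flow, pieced together from the tests below (the toy model and calibration function are illustrative stand-ins, and the neural_compressor.torch.quantization import path is assumed rather than shown in this diff):

import torch
from neural_compressor.torch.quantization import quantize, get_default_sq_config
from neural_compressor.torch.algorithms.smooth_quant import load, recover_model_from_json


class ToyModel(torch.nn.Module):  # stand-in for a real fp32 model
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.fc(x)


fp32_model = ToyModel()
example_inputs = torch.zeros([1, 3])


def run_fn(model):  # calibration loop
    model(example_inputs)


q_model = quantize(fp32_model, quant_config=get_default_sq_config(), run_fn=run_fn, example_inputs=example_inputs)
q_model.save("saved_results")  # writes the traced JIT model plus qconfig.json

loaded = load("saved_results")  # returns a frozen torch.jit.ScriptModule
recovered = recover_model_from_json(ToyModel(), "saved_results/qconfig.json", example_inputs=example_inputs)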
@@ -15,3 +15,4 @@

from .utility import *
from .smooth_quant import smooth_quantize
from .save_load import save, load, recover_model_from_json
74 changes: 74 additions & 0 deletions neural_compressor/torch/algorithms/smooth_quant/save_load.py
@@ -0,0 +1,74 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=import-error
import torch

try:
    import intel_extension_for_pytorch as ipex
except:
    assert False, "Please install IPEX for smooth quantization."

from neural_compressor.torch.algorithms.static_quant import load, save


def recover_model_from_json(model, json_file_path, example_inputs):  # pragma: no cover
    """Recover an IPEX model from a JSON configuration file.

    Args:
        model (object): fp32 model to be quantized.
        json_file_path (str): path to the IPEX configuration JSON file.
        example_inputs (tuple or torch.Tensor or dict): example inputs that will be passed to the ipex function.

    Returns:
        (object): quantized model
    """
    from torch.ao.quantization.observer import MinMaxObserver

    if ipex.__version__ >= "2.1.100":
        qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver)
    else:
        qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver())
    if isinstance(example_inputs, dict):
        model = ipex.quantization.prepare(model, qconfig, example_kwarg_inputs=example_inputs, inplace=True)
    else:
        model = ipex.quantization.prepare(model, qconfig, example_inputs=example_inputs, inplace=True)

    model.load_qconf_summary(qconf_summary=json_file_path)
    model = ipex.quantization.convert(model, inplace=True)
    with torch.no_grad():
        try:
            if isinstance(example_inputs, dict):
                # pylint: disable=E1120,E1123
                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
            else:
                model = torch.jit.trace(model, example_inputs)
            model = torch.jit.freeze(model.eval())
        except:
            if isinstance(example_inputs, dict):
                # pylint: disable=E1120,E1123
                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
            else:
                model = torch.jit.trace(model, example_inputs, strict=False)
            model = torch.jit.freeze(model.eval())
        # run the traced model twice to trigger JIT graph optimizations before returning it
        if isinstance(example_inputs, dict):
            model(**example_inputs)
            model(**example_inputs)
        elif isinstance(example_inputs, tuple) or isinstance(example_inputs, list):
            model(*example_inputs)
            model(*example_inputs)
        else:
            model(example_inputs)
            model(example_inputs)
    return model
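For reference, a small usage sketch of recover_model_from_json, mirroring the accuracy test later in this commit: a qconf summary written with IPEX's save_qconf_summary can be replayed onto a fresh fp32 copy of the model (the model object and the "ipex.json" file name come from that test; everything else is illustrative):

import copy
import torch
from neural_compressor.torch.algorithms.smooth_quant import recover_model_from_json

example_inputs = torch.zeros([1, 3])
fp32_model = copy.deepcopy(model)  # `model` is the user's fp32 nn.Module
# "ipex.json" was produced earlier via user_model.save_qconf_summary(qconf_summary="ipex.json")
ipex_model = recover_model_from_json(fp32_model, "ipex.json", example_inputs=example_inputs)
out = ipex_model(example_inputs)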
@@ -15,3 +15,4 @@

from .utility import *
from .static_quant import static_quantize
from .save_load import save, load
48 changes: 48 additions & 0 deletions neural_compressor/torch/algorithms/static_quant/save_load.py
@@ -0,0 +1,48 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=import-error
import json
import os

import torch

try:
    import intel_extension_for_pytorch as ipex
except:
    assert False, "Please install IPEX for static quantization."

from neural_compressor.torch.utils import QCONFIG_NAME, WEIGHT_NAME, logger


def save(model, output_dir="./saved_results"):
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
    qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
    model.ori_save(qmodel_file_path)
    with open(qconfig_file_path, "w") as f:
        json.dump(model.tune_cfg, f, indent=4)

    logger.info("Save quantized model to {}.".format(qmodel_file_path))
    logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))


def load(output_dir="./saved_results"):
    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
    model = torch.jit.load(qmodel_file_path)
    model = torch.jit.freeze(model.eval())
    logger.info("Quantized model loading successful.")
    return model
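A matching usage sketch for this static-quant pair: save is bound onto the quantized model as a method in algorithm_entry.py (next file), and load rebuilds a frozen ScriptModule from the saved directory (the directory name follows the tests; q_model is assumed to come from quantize()):

import torch
from neural_compressor.torch.algorithms.static_quant import load

q_model.save("saved_results")  # writes the JIT model (WEIGHT_NAME) and the tuning config (QCONFIG_NAME)
loaded_model = load("saved_results")
assert isinstance(loaded_model, torch.jit.ScriptModule)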
8 changes: 6 additions & 2 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -121,7 +121,7 @@ def static_quant_entry(
    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], StaticQuantConfig], *args, **kwargs
) -> torch.nn.Module:
    logger.info("Quantize model with the static quant algorithm.")
    from neural_compressor.torch.algorithms.static_quant import static_quantize
    from neural_compressor.torch.algorithms.static_quant import save, static_quantize

    # convert the user config into internal format
    quant_config_mapping = {}
@@ -157,6 +157,8 @@ def static_quant_entry(
        inplace=inplace,
    )
    logger.info("Static quantization done.")
    q_model.ori_save = q_model.save
    q_model.save = MethodType(save, q_model)
    return q_model


@@ -167,7 +169,7 @@ def smooth_quant_entry(
    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], SmoothQuantConfig], *args, **kwargs
) -> torch.nn.Module:
    logger.info("Quantize model with the smooth quant algorithm.")
    from neural_compressor.torch.algorithms.smooth_quant import smooth_quantize
    from neural_compressor.torch.algorithms.smooth_quant import save, smooth_quantize

    # convert the user config into internal format
    quant_config_mapping = {}
@@ -214,6 +216,8 @@ def smooth_quant_entry(
        inplace=inplace,
    )
    logger.info("Smooth quantization done.")
    q_model.ori_save = q_model.save
    q_model.save = MethodType(save, q_model)
    return q_model
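The ori_save/MethodType lines above keep the traced IPEX model's original save reachable while exposing the new directory-based save as an instance method. A standalone illustration of the pattern (class, function body, and file names here are generic, not from the repository):

from types import MethodType


class TracedModel:
    def save(self, path):  # stands in for torch.jit.ScriptModule.save
        print(f"original save -> {path}")


def save(model, output_dir="./saved_results"):  # free function, like the one added in this commit
    model.ori_save(output_dir + "/model.pt")  # delegate to the original bound method
    print(f"config also dumped under {output_dir}")


m = TracedModel()
m.ori_save = m.save           # capture the original bound method first
m.save = MethodType(save, m)  # now m.save("dir") calls save(m, "dir")
m.save("saved_results")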


37 changes: 32 additions & 5 deletions test/3x/torch/quantization/test_smooth_quant.py
@@ -84,7 +84,7 @@ def run_fn(model):
        assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."

    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
    def test_sq_ipex_save_load(self):
    def test_sq_ipex_accuracy(self):
        from intel_extension_for_pytorch.quantization import convert, prepare

        example_inputs = torch.zeros([1, 3])
@@ -96,6 +96,7 @@ def run_fn(model):
            model(example_inputs)

        run_fn(user_model)
        user_model.save_qconf_summary(qconf_summary="ipex.json")
        with torch.no_grad():
            user_model = convert(user_model.eval(), inplace=True).eval()
            user_model(example_inputs)
@@ -109,12 +110,38 @@ def run_fn(model):
        quant_config = get_default_sq_config()
        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
        assert q_model is not None, "Quantization failed!"
        q_model.save("saved_results")

        inc_out = q_model(example_inputs)
        # set a big atol to avoid random issue
        assert torch.allclose(inc_out, ipex_out, atol=2e-02), "Unexpected result. Please double check."

        from neural_compressor.torch.algorithms.smooth_quant import recover_model_from_json

        fp32_model = copy.deepcopy(model)
        ipex_model = recover_model_from_json(fp32_model, "ipex.json", example_inputs=example_inputs)
        ipex_out = ipex_model(example_inputs)
        assert torch.allclose(inc_out, ipex_out, atol=2e-02), "Unexpected result. Please double check."

    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
    def test_sq_save_load(self):
        fp32_model = copy.deepcopy(model)
        quant_config = get_default_sq_config()
        example_inputs = torch.zeros([1, 3])
        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
        assert q_model is not None, "Quantization failed!"
        q_model.save("saved_results")
        inc_out = q_model(example_inputs)
        q_model.save("saved")

        # load
        loaded_model = torch.jit.load("saved")
        from neural_compressor.torch.algorithms.smooth_quant import load, recover_model_from_json

        # load using saved model
        loaded_model = load("saved_results")
        loaded_out = loaded_model(example_inputs)
        assert torch.allclose(inc_out, ipex_out, atol=1e-05), "Unexpected result. Please double check."
        # set a big atol to avoid random issue
        assert torch.allclose(inc_out, loaded_out, atol=2e-02), "Unexpected result. Please double check."

        # compare saved json file
        loaded_model = recover_model_from_json(fp32_model, "saved_results/qconfig.json", example_inputs=example_inputs)
        loaded_out = loaded_model(example_inputs)
        assert torch.allclose(inc_out, loaded_out, atol=1e-05), "Unexpected result. Please double check."
45 changes: 45 additions & 0 deletions test/3x/torch/quantization/test_static_quant.py
@@ -92,3 +92,48 @@ def run_fn(model):
        output2 = q_model(example_inputs)
        # set a big atol to avoid random issue
        assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."

    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
    def test_static_quant_save_load(self):
        from intel_extension_for_pytorch.quantization import convert, prepare

        example_inputs = torch.zeros(1, 30)
        try:
            qconfig = ipex.quantization.default_static_qconfig_mapping
        except:
            from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

            qconfig = QConfig(
                activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
                weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
            )
        user_model = copy.deepcopy(self.fp32_model)
        user_model = prepare(user_model.eval(), qconfig, example_inputs=example_inputs, inplace=True)

        def run_fn(model):
            model(example_inputs)

        run_fn(user_model)
        with torch.no_grad():
            user_model = convert(user_model.eval(), inplace=True).eval()
            user_model(example_inputs)
            user_model = torch.jit.trace(user_model.eval(), example_inputs, strict=False)
            user_model = torch.jit.freeze(user_model.eval())
            user_model(example_inputs)
            user_model(example_inputs)
        ipex_out = user_model(example_inputs)

        fp32_model = copy.deepcopy(self.fp32_model)
        quant_config = get_default_static_config()
        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
        assert q_model is not None, "Quantization failed!"
        inc_out = q_model(example_inputs)
        # set a big atol to avoid random issue
        assert torch.allclose(inc_out, ipex_out, atol=2e-02), "Unexpected result. Please double check."
        q_model.save("saved_results")

        from neural_compressor.torch.algorithms.static_quant import load

        # load
        loaded_model = load("saved_results")
        assert isinstance(loaded_model, torch.jit.ScriptModule)
