Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support msev2 for onnxrt adaptor #865

Merged
merged 5 commits into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions neural_compressor/adaptor/onnxrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,6 +1315,131 @@ def save(self, model, path):
"""
model.save(os.path.join(path, "best_model.onnx"))

def get_output_op_names(self, qmodel):
    """Collect the names of the nodes that produce the model outputs.

    Args:
        qmodel: model object exposing ``output()`` (the output names) and
            an ``output_name_to_node`` mapping from an output name to a
            node object with a ``name`` attribute.

    Returns:
        A list with one node name per entry returned by ``qmodel.output()``,
        in the same order.
    """
    output_op_names = [
        qmodel.output_name_to_node[tensor_name].name
        for tensor_name in qmodel.output()
    ]
    logger.debug(f"output op names: {output_op_names}")
    return output_op_names

def calculate_op_sensitivity(self, model, dataloader, tune_cfg, output_op_names,
                             confidence_batches, fallback=True, requantize_cfgs=None):
    """Rank ops by their MSE sensitivity.

    Every candidate op gets its config replaced, one op at a time, in a deep
    copy of ``tune_cfg``; ``self._get_mse_order`` then returns one MSE value
    per candidate and the op names are returned sorted by that value.

    1. Pick the candidate ops and their replacement configs:
       - fallback stage (``fallback == True``): every op whose activation
         ``quant_mode`` is 'static' or 'dynamic' is replaced by an fp32 config;
       - re-quantize stage: every op whose activation ``quant_mode`` is 'fp32'
         and that has an entry in ``requantize_cfgs`` is replaced by that entry.
    2. Compute the MSE of each candidate.
    3. Sort the op names by MSE, smallest first.

    Args:
        model: the model passed to ``self._get_mse_order`` as the fp32 model.
        dataloader: the dataloader forwarded to ``self._get_mse_order``.
        tune_cfg: tuning config; its ``'op'`` dict maps op -> op config. It is
            deep-copied before being handed on, so the caller's dict is not
            modified by the MSE computation.
        output_op_names: forwarded unchanged to ``self._get_mse_order``.
        confidence_batches: forwarded unchanged to ``self._get_mse_order``.
        fallback: True for the fallback stage, False for the re-quantize stage.
        requantize_cfgs: dict of tuning configs for all re-quantizable ops.
            ``None`` is treated as an empty dict (no op is re-quantizable).

    Returns:
        A list of op names, sorted by ascending MSE.
    """
    from copy import deepcopy

    fp32_op_cfg = {'activation': {'dtype': 'fp32', 'quant_mode': 'fp32'},
                   'weight': {'dtype': 'fp32'}}

    if fallback:
        ops_list = [op for op, config in tune_cfg['op'].items()
                    if config['activation']['quant_mode'] in ('static', 'dynamic')]
        replace_cfgs = {op: fp32_op_cfg for op in tune_cfg['op']}
    else:
        # The default ``None`` used to crash on ``op in requantize_cfgs``.
        if requantize_cfgs is None:
            requantize_cfgs = {}
        ops_list = [op for op, config in tune_cfg['op'].items()
                    if config['activation']['quant_mode'] == 'fp32' and op in requantize_cfgs]
        replace_cfgs = requantize_cfgs

    # Step2. compute mse
    mse_result = self._get_mse_order(
        model, deepcopy(tune_cfg), replace_cfgs, ops_list, dataloader,
        output_op_names, confidence_batches)

    # Step3. sort (stable: ops with equal MSE keep their insertion order)
    mse_order = sorted(mse_result, key=mse_result.get)
    logger.debug("Dump MSE order:")
    for op in mse_order:
        logger.debug(f"{op}: {mse_result[op]}")
    return mse_order

def _get_mse_order(self, fp32_model, tune_cfg, replace_cfgs, ops_lst, dataloader,
output_op_names, confidence_batches):
"""Compute MSE."""
op_cfg = tune_cfg['op']
mse_result = {}

fp32_output = self._inference_model_on_batches(
fp32_model, tune_cfg, dataloader, output_op_names, confidence_batches)

for op in ops_lst:
# backup and set replace tuning config
backup_cfg = op_cfg[op]
op_cfg[op] = replace_cfgs[op]

# quantize and inference the model
q_model = self.quantize(tune_cfg, fp32_model, dataloader)
q_output = self._inference_model_on_batches(
q_model, tune_cfg, dataloader, output_op_names, confidence_batches)

mse_result[op] = self._calculate_mse(fp32_output, q_output)

# recover tune_cfg
op_cfg[op] = backup_cfg

return mse_result

def _calculate_mse(self, fp32_output, q_output):
"""MSE calculation."""
result = []
for i, j in zip(fp32_output, q_output):
result.append(np.square(i - j).mean())
return np.array(result).mean()

def _inference_model_on_batches(self, model, tune_cfg, dataloader,
                                output_op_name, iterations):
    """Run the model on the first ``iterations`` batches and gather the outputs.

    An ONNX Runtime session is created from ``self.work_space + 'eval.onnx'``
    when ``model.is_large_model`` is truthy, otherwise from the serialized
    ``model.model``. Each batch yielded by ``dataloader`` is unpacked as
    ``(inputs, _)``; ``inputs`` is mapped onto the session's input names and
    every output of ``session.run(None, ...)`` is appended to the result.

    NOTE(review): ``tune_cfg`` and ``output_op_name`` are accepted but not
    used in this body, so all session outputs are collected.
    NOTE(review): the large-model branch reads 'eval.onnx' from the workspace
    without writing it here -- confirm the file matches ``model``.

    Returns:
        A flat list of the outputs of all executed batches.
    """
    if model.is_large_model:
        session = ort.InferenceSession(self.work_space + 'eval.onnx',
                                       providers=[self.backend])
    else:
        session = ort.InferenceSession(model.model.SerializeToString(),
                                       providers=[self.backend])

    input_names = [graph_input.name for graph_input in session.get_inputs()]
    input_count = len(input_names)

    feed = {}  # reused across batches; each batch overwrites its entries
    collected = []
    for batch_idx, (inputs, _) in enumerate(dataloader):
        if batch_idx >= iterations:
            break

        if input_count == 1:
            if isinstance(inputs, dict):
                feed.update(inputs)
            else:
                feed[input_names[0]] = inputs
        else:
            assert input_count == len(inputs), \
                'number of input tensors must align with graph inputs'

            if isinstance(inputs, dict):  # pragma: no cover
                feed.update(inputs)
            else:
                for position, input_name in enumerate(input_names):
                    value = inputs[position]
                    # in case dataloader contains non-array input
                    if isinstance(value, np.ndarray):
                        feed[input_name] = value
                    else:
                        feed[input_name] = np.array(value)

        collected.extend(session.run(None, feed))
    return collected

@adaptor_registry
class ONNXRT_QLinearOpsAdaptor(ONNXRUNTIMEAdaptor):
Expand Down
2 changes: 1 addition & 1 deletion neural_compressor/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def eval_func(model):

if strategy_name == "mse_v2":
if not (conf.framework.startswith("tensorflow")\
or conf.framework == 'pytorch_fx'): # pragma: no cover
or conf.framework in ['pytorch_fx', 'onnxruntime']): # pragma: no cover
strategy_name = "basic"
logger.warning(f"MSE_v2 does not support {conf.framework} now, use basic instead.")
logger.warning("Only tensorflow, pytorch_fx is supported by MSE_v2 currently.")
Expand Down
85 changes: 85 additions & 0 deletions test/strategy/test_mse_v2_2.x.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,52 @@
import tensorflow as tf
import numpy as np
import torchvision
import torch
import onnx
from onnx import onnx_pb as onnx_proto
from onnx import helper, TensorProto, numpy_helper

def build_ox_model():
    """Build a three-node ONNX test model (MatMul -> Add -> Add) with one input.

    Graph: ``C = A @ B``, ``D = C + E``, ``H = D + F``, where ``A`` is the only
    graph input, ``H`` the only graph output, and ``B``, ``E``, ``F`` are
    initializers filled with random 0/1 values. The model declares opset 13
    for the default domain.

    Returns:
        The ``ModelProto`` produced by ``helper.make_model``.
    """
    A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 5, 5])
    H = helper.make_tensor_value_info('H', TensorProto.FLOAT, [1, 5, 2])

    e_value = np.random.randint(2, size=(10)).astype(np.float32)
    B_init = helper.make_tensor('B', TensorProto.FLOAT, [5, 2], e_value.reshape(10).tolist())
    E_init = helper.make_tensor('E', TensorProto.FLOAT, [1, 5, 2], e_value.reshape(10).tolist())

    matmul_node = onnx.helper.make_node('MatMul', ['A', 'B'], ['C'], name='Matmul')
    add = onnx.helper.make_node('Add', ['C', 'E'], ['D'], name='add')

    # F gets its own random values; it was previously filled from e_value
    # while f_value was generated and never used.
    f_value = np.random.randint(2, size=(10)).astype(np.float32)
    F_init = helper.make_tensor('F', TensorProto.FLOAT, [1, 5, 2], f_value.reshape(10).tolist())
    add2 = onnx.helper.make_node('Add', ['D', 'F'], ['H'], name='add2')

    graph = helper.make_graph([matmul_node, add, add2], 'test_graph_1', [A], [H], [B_init, E_init, F_init])
    return helper.make_model(graph, opset_imports=[helper.make_opsetid('', 13)])

def build_ox_model2():
    """Build a three-node ONNX test model (MatMul -> Add -> Add) with two inputs.

    Graph: ``C = A @ B``, ``D = C + E``, ``H = D + F``, where ``A`` and ``F``
    are graph inputs, ``H`` is the only graph output, and ``B`` and ``E`` are
    initializers sharing the same random 0/1 values. The model declares
    opset 13 for the default domain.

    Returns:
        The ``ModelProto`` produced by ``helper.make_model``.
    """
    A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 5, 5])
    H = helper.make_tensor_value_info('H', TensorProto.FLOAT, [1, 5, 2])
    F = helper.make_tensor_value_info('F', TensorProto.FLOAT, [1, 5, 2])

    e_value = np.random.randint(2, size=(10)).astype(np.float32)
    B_init = helper.make_tensor('B', TensorProto.FLOAT, [5, 2], e_value.reshape(10).tolist())
    E_init = helper.make_tensor('E', TensorProto.FLOAT, [1, 5, 2], e_value.reshape(10).tolist())

    matmul_node = onnx.helper.make_node('MatMul', ['A', 'B'], ['C'], name='Matmul')
    add = onnx.helper.make_node('Add', ['C', 'E'], ['D'], name='add')
    add2 = onnx.helper.make_node('Add', ['D', 'F'], ['H'], name='add2')

    graph = helper.make_graph([matmul_node, add, add2], 'test_graph_1', [A, F], [H], [B_init, E_init])
    # Built once, with the opset pinned; the earlier default-opset build was
    # discarded immediately and has been removed.
    return helper.make_model(graph, opset_imports=[helper.make_opsetid('', 13)])

def build_fake_model():
try:
Expand Down Expand Up @@ -47,6 +93,8 @@ class Test_MSEV2Strategy(unittest.TestCase):
def setUpClass(self):
    """Build the model fixtures and store them as attributes.

    NOTE(review): the decorator for this method is outside the visible
    hunk; presumably it is a ``@classmethod`` -- confirm.
    """
    self.tf_model = build_fake_model()
    self.torch_model = torchvision.models.resnet18()
    # ONNX fixtures: a single-input and a two-input model.
    self.onnx_model = build_ox_model()
    self.onnx_model2 = build_ox_model2()

@classmethod
def tearDownClass(self):
Expand Down Expand Up @@ -137,5 +185,42 @@ def fake_eval_func(model):
eval_func=fake_eval_func)
self.assertIsNotNone(q_model)

def test_mse_v2_saved_onnx(self):
    """Run ``fit`` with the mse_v2 strategy on both ONNX fixture models.

    The evaluation function is scripted: the k-th call returns entry k of a
    fixed accuracy list (the counter is incremented before indexing, so
    entry 0 is never returned). The counter restarts for each model.
    """
    from neural_compressor.quantization import fit
    from neural_compressor.config import TuningCriterion, PostTrainingQuantConfig
    from neural_compressor.data import Datasets, DATALOADERS

    call_count = [0]

    def fake_eval_func(model):
        scripted_acc = [1, 1, 0, 0, 0, 0, 1, 1.1, 1.5, 1.1]
        call_count[0] += 1
        return scripted_acc[call_count[0]]

    conf = PostTrainingQuantConfig(
        approach="static",
        quant_level=1,
        tuning_criterion=TuningCriterion(strategy="mse_v2", max_trials=9))

    # (model, positional arguments for the "dummy_v2" dataset)
    cases = (
        (self.onnx_model, ((5,5), (5,1))),
        (self.onnx_model2, ([(5,5), (5,2)], [(5,1), (5,1)])),
    )
    for onnx_model, dataset_args in cases:
        call_count[0] = 0
        dataset = Datasets("onnxrt_qdq")["dummy_v2"](*dataset_args)
        dataloader = DATALOADERS["onnxrt_qdq"](dataset)
        q_model = fit(
            model=onnx_model,
            conf=conf,
            calib_dataloader=dataloader,
            eval_dataloader=dataloader,
            eval_func=fake_eval_func)
        self.assertIsNotNone(q_model)

# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()