Support msev2 for onnxrt adaptor (#865)
Signed-off-by: Mengni Wang <[email protected]>
mengniwang95 authored May 22, 2023
1 parent ad7a2ae commit 62122dd
Showing 3 changed files with 211 additions and 1 deletion.
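
For context, a minimal usage sketch of what this change enables: selecting the mse_v2 tuning strategy for an ONNX model through fit(). It mirrors the new test added below; the "model.onnx" path and the toy eval_func are illustrative placeholders, not part of this commit.

from neural_compressor.quantization import fit
from neural_compressor.config import TuningCriterion, PostTrainingQuantConfig
from neural_compressor.data import Datasets, DATALOADERS

# Dummy calibration/evaluation data shaped like the test model's input (illustrative only).
dataset = Datasets("onnxrt_qdq")["dummy_v2"]((5, 5), (5, 1))
dataloader = DATALOADERS["onnxrt_qdq"](dataset)

def eval_func(model):
    # Placeholder accuracy function; a real one would evaluate the model on a validation set.
    return 1.0

conf = PostTrainingQuantConfig(
    approach="static",
    quant_level=1,
    tuning_criterion=TuningCriterion(strategy="mse_v2", max_trials=9))

q_model = fit(model="model.onnx",  # hypothetical path to an ONNX model
              conf=conf,
              calib_dataloader=dataloader,
              eval_dataloader=dataloader,
              eval_func=eval_func)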
125 changes: 125 additions & 0 deletions neural_compressor/adaptor/onnxrt.py
@@ -1340,6 +1340,131 @@ def save(self, model, path):
"""
model.save(os.path.join(path, "best_model.onnx"))

def get_output_op_names(self, qmodel):
"""Get the ouput ops' names."""
outputs = qmodel.output()
output_op_names = []
for output in outputs:
output_op_names.append(qmodel.output_name_to_node[output].name)
logger.debug(f"output op names: {output_op_names}")
return output_op_names

def calculate_op_sensitivity(self, model, dataloader, tune_cfg, output_op_names,
confidence_batches, fallback=True, requantize_cfgs=None):
"""Compute the op sensitivity.
The sensitivity metric is the mse between the output of the last quantized op of
the quantized model and the output of its corresponding op in the fp32 model.
1. Backup the tune cfg
2. Fallback each int8 op and compute its mse if use fallback (with 'fallback == True'),
or re-quantize each fp32 op(fallen back in the previous stage) and compute its MSE if not.
3. Sorted op name list according to its MSE
Args:
fp32_model: The fp32 model.
dataloader: the dataloader with full dataset.
tune_cfg: tuning config
fallback: denote fallback stage or re-quantize stage
requantize_cfgs: the dict of tuning configs for all re-quantizable ops
Returns:
A list of op names, sorted by its MSE sensitivity.
"""
from copy import deepcopy

fp32_op_cfg = {'activation': {'dtype': 'fp32', 'quant_mode': 'fp32'},
'weight': {'dtype': 'fp32'}}

if fallback:
ops_list = [op for op, config in tune_cfg['op'].items()
if config['activation']['quant_mode'] in ('static', 'dynamic')]
replace_cfgs = {op : fp32_op_cfg for op in tune_cfg['op']}
else:
ops_list = [op for op, config in tune_cfg['op'].items()
if config['activation']['quant_mode'] == 'fp32' and op in requantize_cfgs]
replace_cfgs = requantize_cfgs

# Step 2. compute MSE
mse_result = self._get_mse_order(
model, deepcopy(tune_cfg), replace_cfgs, ops_list, dataloader,
output_op_names, confidence_batches)

# Step 3. sort
mse_order = [op for op, _ in sorted(mse_result.items(), key=lambda i: i[1])]
logger.debug("Dump MSE order:")
for op in mse_order:
logger.debug(f"{op}: {mse_result[op]}")
return mse_order

def _get_mse_order(self, fp32_model, tune_cfg, replace_cfgs, ops_lst, dataloader,
output_op_names, confidence_batches):
"""Compute MSE."""
op_cfg = tune_cfg['op']
mse_result = {}

fp32_output = self._inference_model_on_batches(
fp32_model, tune_cfg, dataloader, output_op_names, confidence_batches)

for op in ops_lst:
# backup and set replace tuning config
backup_cfg = op_cfg[op]
op_cfg[op] = replace_cfgs[op]

# quantize and inference the model
q_model = self.quantize(tune_cfg, fp32_model, dataloader)
q_output = self._inference_model_on_batches(
q_model, tune_cfg, dataloader, output_op_names, confidence_batches)

mse_result[op] = self._calculate_mse(fp32_output, q_output)

# recover tune_cfg
op_cfg[op] = backup_cfg

return mse_result

def _calculate_mse(self, fp32_output, q_output):
"""MSE calculation."""
result = []
for i, j in zip(fp32_output, q_output):
result.append(np.square(i - j).mean())
return np.array(result).mean()

def _inference_model_on_batches(self, model, tune_cfg, dataloader,
output_op_names, iterations):
"""Run inference on the model for the given number of batches."""
ort_inputs = {}
predictions = []

session = ort.InferenceSession(self.work_space + 'eval.onnx',
providers=[self.backend]) if model.is_large_model else \
ort.InferenceSession(model.model.SerializeToString(),
providers=[self.backend])
inputs_names = [i.name for i in session.get_inputs()]
len_inputs = len(session.get_inputs())
for idx, (inputs, _) in enumerate(dataloader):
if idx + 1 > iterations:
break
if len_inputs == 1:
ort_inputs.update(
inputs if isinstance(inputs, dict) else {inputs_names[0]: inputs}
)
else:
assert len_inputs == len(inputs), \
'number of input tensors must align with graph inputs'

if isinstance(inputs, dict): # pragma: no cover
ort_inputs.update(inputs)
else:
for i in range(len_inputs):
# in case dataloader contains non-array input
if not isinstance(inputs[i], np.ndarray):
ort_inputs.update({inputs_names[i]: np.array(inputs[i])})
else:
ort_inputs.update({inputs_names[i]: inputs[i]})

predictions.extend(session.run(None, ort_inputs))
return predictions

@adaptor_registry
class ONNXRT_QLinearOpsAdaptor(ONNXRUNTIMEAdaptor):
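Taken together, _calculate_mse and the sort in calculate_op_sensitivity above reduce to the small pattern below; a standalone sketch with made-up op names and dummy outputs, shown only to illustrate how the per-op MSE ordering is derived.

import numpy as np

# Dummy per-batch outputs standing in for fp32_output / q_output (illustrative only).
fp32_output = [np.ones((1, 5, 2)), np.ones((1, 5, 2))]
q_outputs_per_op = {
    "Matmul": [np.full((1, 5, 2), 1.10), np.full((1, 5, 2), 1.10)],
    "add": [np.full((1, 5, 2), 1.01), np.full((1, 5, 2), 1.01)],
}

def calculate_mse(fp32_output, q_output):
    # Mean of per-batch MSEs, matching _calculate_mse above.
    return np.array([np.square(i - j).mean() for i, j in zip(fp32_output, q_output)]).mean()

mse_result = {op: calculate_mse(fp32_output, q_output)
              for op, q_output in q_outputs_per_op.items()}

# Least-sensitive ops first, matching the sort in calculate_op_sensitivity.
mse_order = [op for op, _ in sorted(mse_result.items(), key=lambda item: item[1])]
print(mse_order)  # ['add', 'Matmul']
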
2 changes: 1 addition & 1 deletion neural_compressor/quantization.py
@@ -160,7 +160,7 @@ def eval_func(model):

if strategy_name == "mse_v2":
if not (conf.framework.startswith("tensorflow")\
- or conf.framework == 'pytorch_fx'): # pragma: no cover
+ or conf.framework in ['pytorch_fx', 'onnxruntime']): # pragma: no cover
strategy_name = "basic"
logger.warning(f"MSE_v2 does not support {conf.framework} now, use basic instead.")
logger.warning("Only tensorflow, pytorch_fx is supported by MSE_v2 currently.")
85 changes: 85 additions & 0 deletions test/strategy/test_mse_v2_2.x.py
@@ -5,6 +5,52 @@
import tensorflow as tf
import numpy as np
import torchvision
import torch
import onnx
from onnx import onnx_pb as onnx_proto
from onnx import helper, TensorProto, numpy_helper

def build_ox_model():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 5, 5])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 5, 2])
D = helper.make_tensor_value_info('D', TensorProto.FLOAT, [1, 5, 2])
H = helper.make_tensor_value_info('H', TensorProto.FLOAT, [1, 5, 2])

e_value = np.random.randint(2, size=(10)).astype(np.float32)
B_init = helper.make_tensor('B', TensorProto.FLOAT, [5, 2], e_value.reshape(10).tolist())
E_init = helper.make_tensor('E', TensorProto.FLOAT, [1, 5, 2], e_value.reshape(10).tolist())

matmul_node = onnx.helper.make_node('MatMul', ['A', 'B'], ['C'], name='Matmul')
add = onnx.helper.make_node('Add', ['C', 'E'], ['D'], name='add')

f_value = np.random.randint(2, size=(10)).astype(np.float32)
F_init = helper.make_tensor('F', TensorProto.FLOAT, [1, 5, 2], f_value.reshape(10).tolist())
add2 = onnx.helper.make_node('Add', ['D', 'F'], ['H'], name='add2')

graph = helper.make_graph([matmul_node, add, add2], 'test_graph_1', [A], [H], [B_init, E_init, F_init])
model = helper.make_model(graph, **{'opset_imports': [helper.make_opsetid('', 13)]})
return model

def build_ox_model2():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 5, 5])
D = helper.make_tensor_value_info('D', TensorProto.FLOAT, [1, 5, 2])
H = helper.make_tensor_value_info('H', TensorProto.FLOAT, [1, 5, 2])
F = helper.make_tensor_value_info('F', TensorProto.FLOAT, [1, 5, 2])

e_value = np.random.randint(2, size=(10)).astype(np.float32)
B_init = helper.make_tensor('B', TensorProto.FLOAT, [5, 2], e_value.reshape(10).tolist())
E_init = helper.make_tensor('E', TensorProto.FLOAT, [1, 5, 2], e_value.reshape(10).tolist())

matmul_node = onnx.helper.make_node('MatMul', ['A', 'B'], ['C'], name='Matmul')
add = onnx.helper.make_node('Add', ['C', 'E'], ['D'], name='add')

add2 = onnx.helper.make_node('Add', ['D', 'F'], ['H'], name='add2')

graph = helper.make_graph([matmul_node, add, add2], 'test_graph_1', [A, F], [H], [B_init, E_init])
model = helper.make_model(graph, **{'opset_imports': [helper.make_opsetid('', 13)]})
return model

def build_fake_model():
try:
@@ -47,6 +93,8 @@ class Test_MSEV2Strategy(unittest.TestCase):
def setUpClass(self):
self.tf_model = build_fake_model()
self.torch_model = torchvision.models.resnet18()
self.onnx_model = build_ox_model()
self.onnx_model2 = build_ox_model2()

@classmethod
def tearDownClass(self):
@@ -137,5 +185,42 @@ def fake_eval_func(model):
eval_func=fake_eval_func)
self.assertIsNotNone(q_model)

def test_mse_v2_saved_onnx(self):
i = [0]
def fake_eval_func(model):
acc_lst = [1, 1, 0, 0, 0, 0, 1, 1.1, 1.5, 1.1]
i[0] += 1
return acc_lst[i[0]]

from neural_compressor.quantization import fit
from neural_compressor.config import TuningCriterion, PostTrainingQuantConfig
from neural_compressor.data import Datasets, DATALOADERS
dataset = Datasets("onnxrt_qdq")["dummy_v2"]((5,5), (5,1))
dataloader = DATALOADERS["onnxrt_qdq"](dataset)

conf = PostTrainingQuantConfig(
approach="static",
quant_level=1,
tuning_criterion=TuningCriterion(strategy="mse_v2", max_trials=9))

q_model = fit(
model=self.onnx_model,
conf=conf,
calib_dataloader=dataloader,
eval_dataloader=dataloader,
eval_func=fake_eval_func)
self.assertIsNotNone(q_model)

i = [0]
dataset = Datasets("onnxrt_qdq")["dummy_v2"]([(5,5), (5,2)], [(5,1), (5,1)])
dataloader = DATALOADERS["onnxrt_qdq"](dataset)
q_model = fit(
model=self.onnx_model2,
conf=conf,
calib_dataloader=dataloader,
eval_dataloader=dataloader,
eval_func=fake_eval_func)
self.assertIsNotNone(q_model)

if __name__ == "__main__":
unittest.main()
