
Commit

Allow using a dictionary for ONNX export (#944)
Signed-off-by: Ella Charlaix <[email protected]>
echarlaix authored Jun 7, 2023
1 parent ee59a32 commit 17b6425
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions neural_compressor/experimental/export/torch2onnx.py
@@ -30,16 +30,14 @@
 ort = LazyImport('onnxruntime')
 ortq = LazyImport('onnxruntime.quantization')
 
-def _prepare_intputs(pt_model, input_names, example_inputs):
+def _prepare_inputs(pt_model, input_names, example_inputs):
     """Prepare input_names and example_inputs."""
-    if input_names is None and \
-        (isinstance(example_inputs, dict) or isinstance(example_inputs, UserDict)): # pragma: no cover
-        input_names = list(example_inputs.keys())
-        example_inputs = list(example_inputs.values())
-    elif isinstance(example_inputs, dict) or isinstance(example_inputs, UserDict):
-        example_inputs = list(example_inputs.values())
+    if isinstance(example_inputs, dict) or isinstance(example_inputs, UserDict):
+        input_names = input_names or list(example_inputs.keys())
+        if isinstance(example_inputs, UserDict):
+            example_inputs = dict(example_inputs)
     # match input_names with inspected input_order, especially for BERT in huggingface.
-    if input_names and len(input_names) > 1:
+    elif input_names and len(input_names) > 1:
         import inspect
         input_order = inspect.signature(pt_model.forward).parameters.keys()
         flag = [name in input_order for name in input_names] # whether should be checked
@@ -53,6 +51,7 @@ def _prepare_intputs(pt_model, input_names, example_inputs):
                     new_example_inputs.append(example_inputs[id])
             input_names = new_input_names
             example_inputs = new_example_inputs
+    example_inputs = input2tuple(example_inputs)
     return input_names, example_inputs
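
The net effect of the rework: a dict (or a transformers-style UserDict) passed as example_inputs is no longer flattened into a list of values. Its keys now supply any missing input_names, a UserDict is converted to a plain dict, and the final input2tuple call normalizes whatever remains. A minimal sketch of the new behavior, using a hypothetical TinyModel and dummy tensors (illustrative only; the helper itself is private and not part of the public API):

    import torch
    from collections import UserDict
    from neural_compressor.experimental.export.torch2onnx import _prepare_inputs

    class TinyModel(torch.nn.Module):
        def forward(self, input_ids, attention_mask):
            return input_ids + attention_mask

    inputs = UserDict({"input_ids": torch.ones(1, 8, dtype=torch.long),
                       "attention_mask": torch.ones(1, 8, dtype=torch.long)})
    # With input_names=None, the names are now derived from the dict keys.
    names, prepared = _prepare_inputs(TinyModel(), None, inputs)
    print(names)  # ['input_ids', 'attention_mask']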


@@ -86,12 +85,12 @@ def torch_to_fp32_onnx(
     assert is_int8_model(pt_model) == False, "The fp32 model is replaced during quantization. " + \
         "please customize an eval_func when quantizing, if not, such as `lambda x: 1`."
 
-    input_names, example_inputs = _prepare_intputs(pt_model, input_names, example_inputs)
+    input_names, example_inputs = _prepare_inputs(pt_model, input_names, example_inputs)
 
     with torch.no_grad():
         torch.onnx.export(
             pt_model,
-            input2tuple(example_inputs),
+            example_inputs,
             save_path,
             opset_version=opset_version,
             input_names=input_names,
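
Because the tuple conversion now happens inside _prepare_inputs, torch.onnx.export receives example_inputs already normalized, so the public entry point can be handed a dict directly. A hedged usage sketch, reusing TinyModel from the sketch above and assuming the positional order (pt_model, save_path, example_inputs); the exact signature may differ by version:

    import torch
    from neural_compressor.experimental.export.torch2onnx import torch_to_fp32_onnx

    # Input names default to the dict keys when input_names is not given.
    torch_to_fp32_onnx(
        TinyModel(),
        "tiny_fp32.onnx",
        {"input_ids": torch.ones(1, 8, dtype=torch.long),
         "attention_mask": torch.ones(1, 8, dtype=torch.long)},
    )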
@@ -148,7 +147,7 @@ def torch_to_int8_onnx(
         "1. export FP32 PyTorch model to FP32 ONNX model. " \
         "2. use FP32 ONNX model as the input model for post training dynamic quantization."
 
-    input_names, example_inputs = _prepare_intputs(pt_model, input_names, example_inputs)
+    input_names, example_inputs = _prepare_inputs(pt_model, input_names, example_inputs)
 
     def model_wrapper(model_fn):
         # export doesn't support a dictionary output, so manually turn it into a tuple
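
For context, model_wrapper exists because torch.onnx.export cannot trace a dictionary output (common for huggingface models), so the forward function is wrapped to unpack the dict into a tuple. A rough reconstruction of that wrapping idea, not the library's exact code:

    def model_wrapper(model_fn):
        def wrapper(*args, **kwargs):
            output = model_fn(*args, **kwargs)
            if isinstance(output, dict):
                # Drop None entries and return the values as a plain tuple.
                return tuple(v for v in output.values() if v is not None)
            return output
        return wrapper

    # Applied as: pt_model.forward = model_wrapper(pt_model.forward)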
@@ -193,4 +192,4 @@ def wrapper(*args, **kwargs):
     info = "The INT8 ONNX Model exported to path: {0}".format(save_path)
     logger.info("*"*len(info))
     logger.info(info)
-    logger.info("*"*len(info))
+    logger.info("*"*len(info))
