Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Softmax Detect conversion errors in axis and identify the axis with the smallest possible error and replace it. #152

Merged
merged 4 commits into from
Jan 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/test-models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ jobs:
pip install tensorflow==2.10.0
pip install nvidia-pyindex
pip install onnx-graphsurgeon
pip install onnxruntime
pip install protobuf==3.20.3
pip install onnxsim
pip install sng4onnx
pip install -e .
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.5.32
ghcr.io/pinto0309/onnx2tf:1.5.33

or

Expand Down
54 changes: 0 additions & 54 deletions json_samples/replace_MobileFormer-e9.json
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_126",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_129",
"param_target": "attributes",
Expand All @@ -73,12 +67,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_297",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_300",
"param_target": "attributes",
Expand All @@ -91,12 +79,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_468",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_471",
"param_target": "attributes",
Expand All @@ -109,12 +91,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_638",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_641",
"param_target": "attributes",
Expand All @@ -127,12 +103,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_809",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_812",
"param_target": "attributes",
Expand All @@ -145,12 +115,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_979",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_982",
"param_target": "attributes",
Expand All @@ -163,12 +127,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_1148",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_1151",
"param_target": "attributes",
Expand All @@ -181,12 +139,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_1319",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_1322",
"param_target": "attributes",
Expand All @@ -199,12 +151,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_1449",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_1452",
"param_target": "attributes",
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.5.32'
__version__ = '1.5.33'
24 changes: 11 additions & 13 deletions onnx2tf/onnx2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import sys
sys.setrecursionlimit(10000)
import ast
import copy
import json
import logging
import warnings
Expand Down Expand Up @@ -45,6 +44,8 @@
onnx_tf_tensor_validation,
weights_export,
download_test_image_data,
get_tf_model_inputs,
get_tf_model_outputs,
)
from onnx2tf.utils.colors import Color
from sng4onnx import generate as op_name_auto_generate
Expand Down Expand Up @@ -599,7 +600,7 @@ def convert(
output_op_name \
for output_op_name in output_names_to_interrupt_model_conversion
]
onnx_graph = extraction(
onnx_graph: onnx.ModelProto = extraction(
input_op_names=[graph_input.name for graph_input in graph.inputs],
output_op_names=output_names,
onnx_graph=onnx_graph,
Expand Down Expand Up @@ -720,16 +721,13 @@ def convert(
)

# List "optype"="Input"
inputs = [
layer_info['op'] \
for layer_info in tf_layers_dict.values() \
if layer_info['optype'] == 'Input'
]
outputs = [
layer_info['tf_node'] \
for opname, layer_info in tf_layers_dict.items() \
if opname in output_names
]
inputs = get_tf_model_inputs(
tf_layers_dict=tf_layers_dict,
)
outputs = get_tf_model_outputs(
tf_layers_dict=tf_layers_dict,
output_names=output_names,
)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
if not non_verbose:
Expand Down Expand Up @@ -1177,7 +1175,7 @@ def representative_dataset_gen():
equal_nan=True,
)

check_results: Dict[str, List[np.ndarray, bool]]
check_results: Dict[str, List[np.ndarray, int, float|int]]
{
onnx_output_name: [
onnx_tensor,
Expand Down
76 changes: 75 additions & 1 deletion onnx2tf/ops/Softmax.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import sys
import random
random.seed(0)
import numpy as np
np.random.seed(0)
import onnx
import tensorflow as tf
import onnx_graphsurgeon as gs
from onnx2tf.utils.common_functions import (
Expand All @@ -14,7 +16,12 @@
get_replacement_parameter,
pre_process_transpose,
post_process_transpose,
dummy_onnx_inference,
dummy_tf_inference,
get_tf_model_inputs,
onnx_tf_tensor_validation,
)
from typing import List, Any, Dict


@print_node_info
Expand Down Expand Up @@ -98,11 +105,78 @@ def make_node(
and before_trans_shape != after_trans_shape:
tf_layers_dict[graph_node_output.name].pop('nhwc')

# Detect conversion errors in axis and identify the axis
# with the smallest possible error and replace it.
# ONNX dummy inference
min_abs_err = sys.maxsize
min_abs_err_axis: int = axis
try:
onnx_graph: onnx.ModelProto = kwargs['onnx_graph']
check_axes = reversed([idx for idx in range(tensor_rank)])
dummy_onnx_outputs: List[np.ndarray] = dummy_onnx_inference(
onnx_graph=onnx_graph,
output_names=[graph_node_output.name],
)
del onnx_graph
# Search for the axis with the smallest error
tf_model_inputs = get_tf_model_inputs(
tf_layers_dict=tf_layers_dict,
)
for check_axis in check_axes:
# TF dummy inference
val_model = tf.keras.Model(
inputs=tf_model_inputs,
outputs=[
tf.nn.softmax(
logits=input_tensor,
axis=check_axis,
name=graph_node.name,
)
],
)
tf_tensor_infos: Dict[Any] = dummy_tf_inference(
model=val_model,
inputs=tf_model_inputs,
)
del val_model
# Validation
onnx_tensor_infos = {
output_name: dummy_onnx_output \
for output_name, dummy_onnx_output in zip([graph_node_output.name], dummy_onnx_outputs)
}
onnx_tf_output_pairs = {
(oi[0], ti[0]): (oi[1], ti[1]) \
for oi, ti in zip(onnx_tensor_infos.items(), tf_tensor_infos.items())
}
"""
check_results: Dict[str, List[np.ndarray, int, float|int]]
{
onnx_output_name: [
onnx_tensor,
matched_flg, <--- 0: Unmatched, 1: Matched, 2: Skipped (Deleted or Shape Unmatched)
max_abs_err,
]
}
"""
check_results = onnx_tf_tensor_validation(
output_pairs=onnx_tf_output_pairs,
rtol=0.0,
atol=0.0,
)
result_err = sum([val[2] for val in check_results.values()])
if result_err < min_abs_err:
min_abs_err = result_err
min_abs_err_axis = check_axis
if min_abs_err < 1e-3:
break
except tf.errors.InvalidArgumentError as ex:
pass

# Generation of TF OP
tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.nn.softmax(
logits=input_tensor,
axis=axis,
axis=min_abs_err_axis,
name=graph_node.name,
)

Expand Down
59 changes: 55 additions & 4 deletions onnx2tf/utils/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2812,7 +2812,7 @@ def dummy_onnx_inference(
# reduce all axes except batch axis
gs_graph.nodes[i].attrs['axes'] = [
i for i in range(1, len(gs_graph.nodes[i].inputs[0].shape))
]
] if len(gs_graph.nodes[i].inputs[0].shape) > 1 else [0]

# instead, modify onnx graph manually
gs_graph.outputs = []
Expand Down Expand Up @@ -2965,7 +2965,7 @@ def onnx_tf_tensor_validation(

Returns
----------
check_results: Dict[str, List[np.ndarray, int]]
check_results: Dict[str, List[np.ndarray, int, float|int]]
Tensor Comparison Results
{
onnx_output_name: [
Expand Down Expand Up @@ -3273,8 +3273,7 @@ def y_tile(


def calc_tf_pooling_pads(input_shape, kernel, strides, func):
"""
Calculate how much padding is needed for tensorflow mode 'SAME'
"""Calculate how much padding is needed for tensorflow mode 'SAME'.

Parameters
----------
Expand Down Expand Up @@ -3316,3 +3315,55 @@ def calc_tf_pooling_pads(input_shape, kernel, strides, func):
same_pads.extend(same_pads_end)

return same_pads


def get_tf_model_inputs(
*,
tf_layers_dict: dict,
) -> List[Any]:
"""Get a list of input OPs for a TensorFlow model.

Parameters
----------
tf_layers_dict: dict
Graph structure of TensorFlow models

Returns
-------
tf_model_inputs: List
List of input OPs for TensorFlow model
"""
tf_model_inputs = [
layer_info['op'] \
for layer_info in tf_layers_dict.values() \
if layer_info['optype'] == 'Input'
]
return tf_model_inputs


def get_tf_model_outputs(
*,
tf_layers_dict: dict,
output_names: List[str],
) -> List[Any]:
"""Get a list of output OPs for a TensorFlow model.

Parameters
----------
tf_layers_dict: dict
Graph structure of TensorFlow models

output_names: List[str]
Name of ONNX output OP to be extracted

Returns
-------
tf_model_outputs: List
List of output OPs for TensorFlow model
"""
tf_model_outputs = [
layer_info['tf_node'] \
for opname, layer_info in tf_layers_dict.items() \
if opname in output_names
]
return tf_model_outputs