Skip to content

Commit

Permalink
Merge pull request #152 from PINTO0309/sofmax_fix
Browse files Browse the repository at this point in the history
`Softmax`: Detect axis conversion errors, identify the axis that yields the smallest error, and use it as the replacement.
  • Loading branch information
PINTO0309 authored Jan 28, 2023
2 parents c7f6d9b + a8b6f17 commit 8e0b913
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 74 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test-models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ jobs:
pip install tensorflow==2.10.0
pip install nvidia-pyindex
pip install onnx-graphsurgeon
pip install onnxruntime
pip install protobuf==3.20.3
pip install onnxsim
pip install sng4onnx
pip install -e .
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.5.32
ghcr.io/pinto0309/onnx2tf:1.5.33
or
Expand Down
54 changes: 0 additions & 54 deletions json_samples/replace_MobileFormer-e9.json
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_126",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_129",
"param_target": "attributes",
Expand All @@ -73,12 +67,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_297",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_300",
"param_target": "attributes",
Expand All @@ -91,12 +79,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_468",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_471",
"param_target": "attributes",
Expand All @@ -109,12 +91,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_638",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_641",
"param_target": "attributes",
Expand All @@ -127,12 +103,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_809",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_812",
"param_target": "attributes",
Expand All @@ -145,12 +115,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_979",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_982",
"param_target": "attributes",
Expand All @@ -163,12 +127,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_1148",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_1151",
"param_target": "attributes",
Expand All @@ -181,12 +139,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_1319",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_1322",
"param_target": "attributes",
Expand All @@ -199,12 +151,6 @@
"param_name": "perm",
"values": [1,2,0,3]
},
{
"op_name": "Softmax_1449",
"param_target": "attributes",
"param_name": "axis",
"values": 3
},
{
"op_name": "Transpose_1452",
"param_target": "attributes",
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.5.32'
__version__ = '1.5.33'
24 changes: 11 additions & 13 deletions onnx2tf/onnx2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import sys
sys.setrecursionlimit(10000)
import ast
import copy
import json
import logging
import warnings
Expand Down Expand Up @@ -45,6 +44,8 @@
onnx_tf_tensor_validation,
weights_export,
download_test_image_data,
get_tf_model_inputs,
get_tf_model_outputs,
)
from onnx2tf.utils.colors import Color
from sng4onnx import generate as op_name_auto_generate
Expand Down Expand Up @@ -599,7 +600,7 @@ def convert(
output_op_name \
for output_op_name in output_names_to_interrupt_model_conversion
]
onnx_graph = extraction(
onnx_graph: onnx.ModelProto = extraction(
input_op_names=[graph_input.name for graph_input in graph.inputs],
output_op_names=output_names,
onnx_graph=onnx_graph,
Expand Down Expand Up @@ -720,16 +721,13 @@ def convert(
)

# List "optype"="Input"
inputs = [
layer_info['op'] \
for layer_info in tf_layers_dict.values() \
if layer_info['optype'] == 'Input'
]
outputs = [
layer_info['tf_node'] \
for opname, layer_info in tf_layers_dict.items() \
if opname in output_names
]
inputs = get_tf_model_inputs(
tf_layers_dict=tf_layers_dict,
)
outputs = get_tf_model_outputs(
tf_layers_dict=tf_layers_dict,
output_names=output_names,
)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
if not non_verbose:
Expand Down Expand Up @@ -1177,7 +1175,7 @@ def representative_dataset_gen():
equal_nan=True,
)
check_results: Dict[str, List[np.ndarray, bool]]
check_results: Dict[str, List[np.ndarray, int, float|int]]
{
onnx_output_name: [
onnx_tensor,
Expand Down
76 changes: 75 additions & 1 deletion onnx2tf/ops/Softmax.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import sys
import random
random.seed(0)
import numpy as np
np.random.seed(0)
import onnx
import tensorflow as tf
import onnx_graphsurgeon as gs
from onnx2tf.utils.common_functions import (
Expand All @@ -14,7 +16,12 @@
get_replacement_parameter,
pre_process_transpose,
post_process_transpose,
dummy_onnx_inference,
dummy_tf_inference,
get_tf_model_inputs,
onnx_tf_tensor_validation,
)
from typing import List, Any, Dict


@print_node_info
Expand Down Expand Up @@ -98,11 +105,78 @@ def make_node(
and before_trans_shape != after_trans_shape:
tf_layers_dict[graph_node_output.name].pop('nhwc')

# Detect conversion errors in axis and identify the axis
# with the smallest possible error and replace it.
# ONNX dummy inference
min_abs_err = sys.maxsize
min_abs_err_axis: int = axis
try:
onnx_graph: onnx.ModelProto = kwargs['onnx_graph']
check_axes = reversed([idx for idx in range(tensor_rank)])
dummy_onnx_outputs: List[np.ndarray] = dummy_onnx_inference(
onnx_graph=onnx_graph,
output_names=[graph_node_output.name],
)
del onnx_graph
# Search for the axis with the smallest error
tf_model_inputs = get_tf_model_inputs(
tf_layers_dict=tf_layers_dict,
)
for check_axis in check_axes:
# TF dummy inference
val_model = tf.keras.Model(
inputs=tf_model_inputs,
outputs=[
tf.nn.softmax(
logits=input_tensor,
axis=check_axis,
name=graph_node.name,
)
],
)
tf_tensor_infos: Dict[Any] = dummy_tf_inference(
model=val_model,
inputs=tf_model_inputs,
)
del val_model
# Validation
onnx_tensor_infos = {
output_name: dummy_onnx_output \
for output_name, dummy_onnx_output in zip([graph_node_output.name], dummy_onnx_outputs)
}
onnx_tf_output_pairs = {
(oi[0], ti[0]): (oi[1], ti[1]) \
for oi, ti in zip(onnx_tensor_infos.items(), tf_tensor_infos.items())
}
"""
check_results: Dict[str, List[np.ndarray, int, float|int]]
{
onnx_output_name: [
onnx_tensor,
matched_flg, <--- 0: Unmatched, 1: Matched, 2: Skipped (Deleted or Shape Unmatched)
max_abs_err,
]
}
"""
check_results = onnx_tf_tensor_validation(
output_pairs=onnx_tf_output_pairs,
rtol=0.0,
atol=0.0,
)
result_err = sum([val[2] for val in check_results.values()])
if result_err < min_abs_err:
min_abs_err = result_err
min_abs_err_axis = check_axis
if min_abs_err < 1e-3:
break
except tf.errors.InvalidArgumentError as ex:
pass

# Generation of TF OP
tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.nn.softmax(
logits=input_tensor,
axis=axis,
axis=min_abs_err_axis,
name=graph_node.name,
)

Expand Down
59 changes: 55 additions & 4 deletions onnx2tf/utils/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2812,7 +2812,7 @@ def dummy_onnx_inference(
# reduce all axes except batch axis
gs_graph.nodes[i].attrs['axes'] = [
i for i in range(1, len(gs_graph.nodes[i].inputs[0].shape))
]
] if len(gs_graph.nodes[i].inputs[0].shape) > 1 else [0]

# instead, modify onnx graph manually
gs_graph.outputs = []
Expand Down Expand Up @@ -2965,7 +2965,7 @@ def onnx_tf_tensor_validation(
Returns
----------
check_results: Dict[str, List[np.ndarray, int]]
check_results: Dict[str, List[np.ndarray, int, float|int]]
Tensor Comparison Results
{
onnx_output_name: [
Expand Down Expand Up @@ -3273,8 +3273,7 @@ def y_tile(


def calc_tf_pooling_pads(input_shape, kernel, strides, func):
"""
Calculate how much padding is needed for tensorflow mode 'SAME'
"""Calculate how much padding is needed for tensorflow mode 'SAME'.
Parameters
----------
Expand Down Expand Up @@ -3316,3 +3315,55 @@ def calc_tf_pooling_pads(input_shape, kernel, strides, func):
same_pads.extend(same_pads_end)

return same_pads


def get_tf_model_inputs(
    *,
    tf_layers_dict: dict,
) -> List[Any]:
    """Collect the input OPs registered in a TensorFlow layer dictionary.

    Every entry of ``tf_layers_dict`` whose ``'optype'`` equals ``'Input'``
    contributes its ``'op'`` value to the result. Entries are visited in the
    dictionary's iteration order, so the result follows that order.

    Parameters
    ----------
    tf_layers_dict: dict
        Graph structure of TensorFlow models

    Returns
    -------
    tf_model_inputs: List
        List of input OPs for TensorFlow model
    """
    input_ops: List[Any] = []
    for entry in tf_layers_dict.values():
        # Only layers flagged as model inputs are of interest here.
        if entry['optype'] == 'Input':
            input_ops.append(entry['op'])
    return input_ops


def get_tf_model_outputs(
    *,
    tf_layers_dict: dict,
    output_names: List[str],
) -> List[Any]:
    """Get a list of output OPs for a TensorFlow model.

    Every entry of ``tf_layers_dict`` whose key appears in ``output_names``
    contributes its ``'tf_node'`` value to the result. The result follows the
    iteration order of ``tf_layers_dict``, not the order of ``output_names``.

    Parameters
    ----------
    tf_layers_dict: dict
        Graph structure of TensorFlow models

    output_names: List[str]
        Name of ONNX output OP to be extracted

    Returns
    -------
    tf_model_outputs: List
        List of output OPs for TensorFlow model
    """
    # Build the lookup set once so each membership test is O(1) instead of
    # rescanning the whole output_names list for every layer.
    wanted_names = set(output_names)
    tf_model_outputs = [
        layer_info['tf_node']
        for opname, layer_info in tf_layers_dict.items()
        if opname in wanted_names
    ]
    return tf_model_outputs

0 comments on commit 8e0b913

Please sign in to comment.