Skip to content

Commit

Permalink
Merge pull request #522 from PINTO0309/fix/dynamic_shape
Browse files Browse the repository at this point in the history
Fix `Gather` parameter substitution logic
  • Loading branch information
PINTO0309 authored Sep 29, 2023
2 parents 81ee38d + 5bb8d70 commit c9e578e
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 8 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,15 +255,15 @@ Video speed is adjusted approximately 50 times slower than actual speed.
$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.18.1
ghcr.io/pinto0309/onnx2tf:1.18.2

or

# Authentication is not required for pulls from Docker Hub.
$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
docker.io/pinto0309/onnx2tf:1.18.1
docker.io/pinto0309/onnx2tf:1.18.2

or

Expand Down
43 changes: 43 additions & 0 deletions json_samples/replace_retinaface_dynamic_shape.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
{
"format_version": 1,
"operations": [
{
"op_name": "/fpn/Gather",
"param_target": "inputs",
"param_name": "/fpn/Constant_output_0",
"values": 1
},
{
"op_name": "/fpn/Gather_1",
"param_target": "inputs",
"param_name": "/fpn/Constant_1_output_0",
"values": 2
},
{
"op_name": "/fpn/Concat_1",
"param_target": "outputs",
"param_name": "/fpn/Concat_1_output_0",
"post_process_transpose_perm": [0,2,3,1]
},


{
"op_name": "/fpn/Gather_2",
"param_target": "inputs",
"param_name": "/fpn/Constant_output_0",
"values": 1
},
{
"op_name": "/fpn/Gather_3",
"param_target": "inputs",
"param_name": "/fpn/Constant_1_output_0",
"values": 2
},
{
"op_name": "/fpn/Concat_3",
"param_target": "outputs",
"param_name": "/fpn/Concat_3_output_0",
"post_process_transpose_perm": [0,2,3,1]
}
]
}
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.18.1'
__version__ = '1.18.2'
13 changes: 13 additions & 0 deletions onnx2tf/onnx2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ def convert(
operations['op_name'] = operations['op_name'].replace(':','_')
if output_signaturedefs or output_integer_quantized_tflite:
operations['op_name'] = re.sub('^/', 'wa/', operations['op_name'])
operations['param_name'] = re.sub('^/', 'wa/', operations['param_name'])
except json.decoder.JSONDecodeError as ex:
error(
f'The file specified in param_replacement_file is not in JSON format. \n' +
Expand Down Expand Up @@ -694,6 +695,12 @@ def sanitizing(node):
o._name = o._name.replace(':','__')
if output_signaturedefs or output_integer_quantized_tflite:
node.name = re.sub('^/', 'wa/', node.name)
if hasattr(node, 'inputs'):
for i in node.inputs:
if hasattr(i, 'name'):
i.name = re.sub('^/', 'wa/', i.name)
elif hasattr(i, '_name'):
i._name = re.sub('^/', 'wa/', i._name)
if hasattr(node, 'outputs'):
for o in node.outputs:
if hasattr(o, 'name'):
Expand All @@ -710,6 +717,12 @@ def sanitizing(node):
o._name = o._name.replace(':','__')
if output_signaturedefs or output_integer_quantized_tflite:
node._name = re.sub('^/', 'wa/', node._name)
if hasattr(node, 'inputs'):
for i in node.inputs:
if hasattr(i, 'name'):
i.name = re.sub('^/', 'wa/', i.name)
elif hasattr(i, '_name'):
i._name = re.sub('^/', 'wa/', i._name)
if hasattr(node, 'outputs'):
for o in node.outputs:
if hasattr(o, 'name'):
Expand Down
1 change: 1 addition & 0 deletions onnx2tf/ops/Concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def define_concat(
value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
param_target='outputs',
param_name=graph_node.outputs[0].name,
graph_node=graph_node,
**kwargs,
)

Expand Down
7 changes: 7 additions & 0 deletions onnx2tf/ops/Gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ def make_node(
param_name=graph_node.inputs[1].name,
**kwargs,
)
if simple_indices is not None:
simple_indices = replace_parameter(
value_before_replacement=simple_indices,
param_target='inputs',
param_name=graph_node.inputs[1].name,
**kwargs,
)

# Pre-process transpose
input_tensor = pre_process_transpose(
Expand Down
31 changes: 26 additions & 5 deletions onnx2tf/utils/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def post_process_transpose(
value_before_transpose: Any,
param_target: str,
param_name: str,
graph_node: gs.Node = None,
**kwargs: Dict,
):
"""Add Transpose as a post-processing step for Reshape OP.
Expand All @@ -172,11 +173,31 @@ def post_process_transpose(
and op_rep_param['param_name'] == param_name:
transpose_perm = op_rep_param.get('post_process_transpose_perm', None)
if transpose_perm is not None:
transposed_value = transpose_with_flexing_deterrence(
input_tensor=value_before_transpose,
perm=transpose_perm,
**kwargs,
)
if graph_node is not None \
and graph_node.op != "Concat":
transposed_value = \
transpose_with_flexing_deterrence(
input_tensor=value_before_transpose,
perm=transpose_perm,
**kwargs,
)
else:
if value_before_transpose.shape is not None \
and len(value_before_transpose.shape) == 1 \
and value_before_transpose.shape[0] is not None:
# Gather
transposed_value = tf.gather(
params=value_before_transpose,
indices=tf.convert_to_tensor(transpose_perm)
)
else:
# Normal
transposed_value = \
transpose_with_flexing_deterrence(
input_tensor=value_before_transpose,
perm=transpose_perm,
**kwargs,
)
break
return transposed_value

Expand Down

0 comments on commit c9e578e

Please sign in to comment.