Skip to content

Commit

Permalink
Enhanced parameter replacement and axis transposition of Concat's 1D …
Browse files Browse the repository at this point in the history
…tensor
  • Loading branch information
PINTO0309 committed Sep 29, 2023
1 parent 0ec5156 commit 5bb8d70
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 5 deletions.
14 changes: 14 additions & 0 deletions json_samples/replace_retinaface_dynamic_shape.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
"param_name": "/fpn/Constant_1_output_0",
"values": 2
},
{
"op_name": "/fpn/Concat_1",
"param_target": "outputs",
"param_name": "/fpn/Concat_1_output_0",
"post_process_transpose_perm": [0,2,3,1]
},


{
"op_name": "/fpn/Gather_2",
"param_target": "inputs",
Expand All @@ -24,6 +32,12 @@
"param_target": "inputs",
"param_name": "/fpn/Constant_1_output_0",
"values": 2
},
{
"op_name": "/fpn/Concat_3",
"param_target": "outputs",
"param_name": "/fpn/Concat_3_output_0",
"post_process_transpose_perm": [0,2,3,1]
}
]
}
13 changes: 13 additions & 0 deletions onnx2tf/onnx2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ def convert(
operations['op_name'] = operations['op_name'].replace(':','_')
if output_signaturedefs or output_integer_quantized_tflite:
operations['op_name'] = re.sub('^/', 'wa/', operations['op_name'])
operations['param_name'] = re.sub('^/', 'wa/', operations['param_name'])
except json.decoder.JSONDecodeError as ex:
error(
f'The file specified in param_replacement_file is not in JSON format. \n' +
Expand Down Expand Up @@ -694,6 +695,12 @@ def sanitizing(node):
o._name = o._name.replace(':','__')
if output_signaturedefs or output_integer_quantized_tflite:
node.name = re.sub('^/', 'wa/', node.name)
if hasattr(node, 'inputs'):
for i in node.inputs:
if hasattr(i, 'name'):
i.name = re.sub('^/', 'wa/', i.name)
elif hasattr(i, '_name'):
i._name = re.sub('^/', 'wa/', i._name)
if hasattr(node, 'outputs'):
for o in node.outputs:
if hasattr(o, 'name'):
Expand All @@ -710,6 +717,12 @@ def sanitizing(node):
o._name = o._name.replace(':','__')
if output_signaturedefs or output_integer_quantized_tflite:
node._name = re.sub('^/', 'wa/', node._name)
if hasattr(node, 'inputs'):
for i in node.inputs:
if hasattr(i, 'name'):
i.name = re.sub('^/', 'wa/', i.name)
elif hasattr(i, '_name'):
i._name = re.sub('^/', 'wa/', i._name)
if hasattr(node, 'outputs'):
for o in node.outputs:
if hasattr(o, 'name'):
Expand Down
1 change: 1 addition & 0 deletions onnx2tf/ops/Concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def define_concat(
value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
param_target='outputs',
param_name=graph_node.outputs[0].name,
graph_node=graph_node,
**kwargs,
)

Expand Down
31 changes: 26 additions & 5 deletions onnx2tf/utils/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def post_process_transpose(
value_before_transpose: Any,
param_target: str,
param_name: str,
graph_node: gs.Node = None,
**kwargs: Dict,
):
"""Add Transpose as a post-processing step for Reshape OP.
Expand All @@ -172,11 +173,31 @@ def post_process_transpose(
and op_rep_param['param_name'] == param_name:
transpose_perm = op_rep_param.get('post_process_transpose_perm', None)
if transpose_perm is not None:
transposed_value = transpose_with_flexing_deterrence(
input_tensor=value_before_transpose,
perm=transpose_perm,
**kwargs,
)
if graph_node is not None \
and graph_node.op != "Concat":
transposed_value = \
transpose_with_flexing_deterrence(
input_tensor=value_before_transpose,
perm=transpose_perm,
**kwargs,
)
else:
if value_before_transpose.shape is not None \
and len(value_before_transpose.shape) == 1 \
and value_before_transpose.shape[0] is not None:
# Gather
transposed_value = tf.gather(
params=value_before_transpose,
indices=tf.convert_to_tensor(transpose_perm)
)
else:
# Normal
transposed_value = \
transpose_with_flexing_deterrence(
input_tensor=value_before_transpose,
perm=transpose_perm,
**kwargs,
)
break
return transposed_value

Expand Down

0 comments on commit 5bb8d70

Please sign in to comment.