diff --git a/README.md b/README.md index e68a18ba..91f16297 100644 --- a/README.md +++ b/README.md @@ -255,7 +255,7 @@ Video speed is adjusted approximately 50 times slower than actual speed. $ docker run --rm -it \ -v `pwd`:/workdir \ -w /workdir \ - ghcr.io/pinto0309/onnx2tf:1.18.1 + ghcr.io/pinto0309/onnx2tf:1.18.2 or @@ -263,7 +263,7 @@ Video speed is adjusted approximately 50 times slower than actual speed. $ docker run --rm -it \ -v `pwd`:/workdir \ -w /workdir \ - docker.io/pinto0309/onnx2tf:1.18.1 + docker.io/pinto0309/onnx2tf:1.18.2 or diff --git a/json_samples/replace_retinaface_dynamic_shape.json b/json_samples/replace_retinaface_dynamic_shape.json new file mode 100644 index 00000000..a04fb3b4 --- /dev/null +++ b/json_samples/replace_retinaface_dynamic_shape.json @@ -0,0 +1,43 @@ +{ + "format_version": 1, + "operations": [ + { + "op_name": "/fpn/Gather", + "param_target": "inputs", + "param_name": "/fpn/Constant_output_0", + "values": 1 + }, + { + "op_name": "/fpn/Gather_1", + "param_target": "inputs", + "param_name": "/fpn/Constant_1_output_0", + "values": 2 + }, + { + "op_name": "/fpn/Concat_1", + "param_target": "outputs", + "param_name": "/fpn/Concat_1_output_0", + "post_process_transpose_perm": [0,2,3,1] + }, + + + { + "op_name": "/fpn/Gather_2", + "param_target": "inputs", + "param_name": "/fpn/Constant_output_0", + "values": 1 + }, + { + "op_name": "/fpn/Gather_3", + "param_target": "inputs", + "param_name": "/fpn/Constant_1_output_0", + "values": 2 + }, + { + "op_name": "/fpn/Concat_3", + "param_target": "outputs", + "param_name": "/fpn/Concat_3_output_0", + "post_process_transpose_perm": [0,2,3,1] + } + ] +} \ No newline at end of file diff --git a/onnx2tf/__init__.py b/onnx2tf/__init__.py index 7362a5d0..d5e0a5dd 100644 --- a/onnx2tf/__init__.py +++ b/onnx2tf/__init__.py @@ -1,3 +1,3 @@ from onnx2tf.onnx2tf import convert, main -__version__ = '1.18.1' +__version__ = '1.18.2' diff --git a/onnx2tf/onnx2tf.py b/onnx2tf/onnx2tf.py index 5c002d74..f9b6ccd2 100644 --- a/onnx2tf/onnx2tf.py +++ b/onnx2tf/onnx2tf.py @@ -564,6 +564,7 @@ def convert( operations['op_name'] = operations['op_name'].replace(':','_') if output_signaturedefs or output_integer_quantized_tflite: operations['op_name'] = re.sub('^/', 'wa/', operations['op_name']) + operations['param_name'] = re.sub('^/', 'wa/', operations['param_name']) except json.decoder.JSONDecodeError as ex: error( f'The file specified in param_replacement_file is not in JSON format. \n' + @@ -694,6 +695,12 @@ def sanitizing(node): o._name = o._name.replace(':','__') if output_signaturedefs or output_integer_quantized_tflite: node.name = re.sub('^/', 'wa/', node.name) + if hasattr(node, 'inputs'): + for i in node.inputs: + if hasattr(i, 'name'): + i.name = re.sub('^/', 'wa/', i.name) + elif hasattr(i, '_name'): + i._name = re.sub('^/', 'wa/', i._name) if hasattr(node, 'outputs'): for o in node.outputs: if hasattr(o, 'name'): @@ -710,6 +717,12 @@ def sanitizing(node): o._name = o._name.replace(':','__') if output_signaturedefs or output_integer_quantized_tflite: node._name = re.sub('^/', 'wa/', node._name) + if hasattr(node, 'inputs'): + for i in node.inputs: + if hasattr(i, 'name'): + i.name = re.sub('^/', 'wa/', i.name) + elif hasattr(i, '_name'): + i._name = re.sub('^/', 'wa/', i._name) if hasattr(node, 'outputs'): for o in node.outputs: if hasattr(o, 'name'): diff --git a/onnx2tf/ops/Concat.py b/onnx2tf/ops/Concat.py index 0fe11d2f..9ee9917a 100644 --- a/onnx2tf/ops/Concat.py +++ b/onnx2tf/ops/Concat.py @@ -497,6 +497,7 @@ def define_concat( value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'], param_target='outputs', param_name=graph_node.outputs[0].name, + graph_node=graph_node, **kwargs, ) diff --git a/onnx2tf/ops/Gather.py b/onnx2tf/ops/Gather.py index 008ab8c9..26c9d69d 100644 --- a/onnx2tf/ops/Gather.py +++ b/onnx2tf/ops/Gather.py @@ -158,6 +158,13 @@ def make_node( param_name=graph_node.inputs[1].name, **kwargs, ) + if simple_indices is not None: + simple_indices = replace_parameter( + value_before_replacement=simple_indices, + param_target='inputs', + param_name=graph_node.inputs[1].name, + **kwargs, + ) # Pre-process transpose input_tensor = pre_process_transpose( diff --git a/onnx2tf/utils/common_functions.py b/onnx2tf/utils/common_functions.py index 336efb2c..06bf6232 100644 --- a/onnx2tf/utils/common_functions.py +++ b/onnx2tf/utils/common_functions.py @@ -150,6 +150,7 @@ def post_process_transpose( value_before_transpose: Any, param_target: str, param_name: str, + graph_node: gs.Node = None, **kwargs: Dict, ): """Add Transpose as a post-processing step for Reshape OP. @@ -172,11 +173,31 @@ def post_process_transpose( and op_rep_param['param_name'] == param_name: transpose_perm = op_rep_param.get('post_process_transpose_perm', None) if transpose_perm is not None: - transposed_value = transpose_with_flexing_deterrence( - input_tensor=value_before_transpose, - perm=transpose_perm, - **kwargs, - ) + if graph_node is not None \ + and graph_node.op != "Concat": + transposed_value = \ + transpose_with_flexing_deterrence( + input_tensor=value_before_transpose, + perm=transpose_perm, + **kwargs, + ) + else: + if value_before_transpose.shape is not None \ + and len(value_before_transpose.shape) == 1 \ + and value_before_transpose.shape[0] is not None: + # Gather + transposed_value = tf.gather( + params=value_before_transpose, + indices=tf.convert_to_tensor(transpose_perm) + ) + else: + # Normal + transposed_value = \ + transpose_with_flexing_deterrence( + input_tensor=value_before_transpose, + perm=transpose_perm, + **kwargs, + ) break return transposed_value