From eeb783512867f40a7757169d2ad4e46637885af1 Mon Sep 17 00:00:00 2001
From: Anastasia Popova
Date: Fri, 12 Feb 2021 22:44:21 +0300
Subject: [PATCH] Result rename operation (#4242)

* Added result rename operation
* Optimize imports
* Added ResultRename to package_BOM
* ResultRename moved to the end of back phase, code refactoring
* Revert incorrect changes
* Optimize imports
* Added comments and optimized imports.
---
 .../src/readers/ir_reader/ie_ir_parser.cpp  |  2 +-
 model-optimizer/automation/package_BOM.txt  |  1 +
 .../extensions/back/ResultRename.py         | 38 +++++++++++++++
 .../extensions/back/ResultRename_test.py    | 48 +++++++++++++++++++
 .../mo/back/ie_ir_ver_2/emitter.py          |  4 +-
 model-optimizer/mo/graph/port.py            |  6 +--
 model-optimizer/mo/graph/port_test.py       | 36 +++++++-------
 model-optimizer/mo/pipeline/common.py       |  3 ++
 8 files changed, 113 insertions(+), 25 deletions(-)
 create mode 100644 model-optimizer/extensions/back/ResultRename.py
 create mode 100644 model-optimizer/extensions/back/ResultRename_test.py

diff --git a/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp b/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp
index 6eb27e4fdd7320..c6151ad577e67e 100644
--- a/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp
+++ b/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp
@@ -392,7 +392,7 @@ std::shared_ptr V10Parser::XmlDeserializer::parse_function(con
     // Read all layers and store their parameters in params map
     FOREACH_CHILD(node, root.child("layers"), "layer") {
         auto node_param = parseGenericParams(node);
-        if (opName.find(node_param.name) != opName.end())
+        if (opName.find(node_param.name) != opName.end() && node_param.type != "Result")
             THROW_IE_EXCEPTION << "Invalid IR! " << node_param.name << " name is not unique!";
         opName.insert(node_param.name);
         params[node_param.layerId] = {node, node_param};
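
The hunk above relaxes the duplicate-name check in the IR reader: a repeated layer name is still rejected, except when the layer is a Result, because the ResultRename transformation introduced below can give a Result the same name as another layer in the IR. A minimal Python sketch of the relaxed rule, for reference only (the helper name and its arguments are illustrative stand-ins, not part of the parser):

def check_layer_name(seen_names, name, layer_type):
    # Duplicate layer names make the IR invalid, except for Result layers,
    # which may reuse the name of the tensor/layer feeding them.
    if name in seen_names and layer_type != 'Result':
        raise ValueError('Invalid IR! {} name is not unique!'.format(name))
    seen_names.add(name)


seen = set()
check_layer_name(seen, 'conv1', 'Convolution')
check_layer_name(seen, 'conv1', 'Result')  # tolerated after this change
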
diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt
index 33b38ca4a495f4..a54be463cbf70b 100644
--- a/model-optimizer/automation/package_BOM.txt
+++ b/model-optimizer/automation/package_BOM.txt
@@ -47,6 +47,7 @@ extensions/back/RemoveUselessConvert.py
 extensions/back/Reshape0DToSqueeze.py
 extensions/back/ReshapeMutation.py
 extensions/back/ResultNormalizer.py
+extensions/back/ResultRename.py
 extensions/back/ReverseInputChannels.py
 extensions/back/RNNSequenceTypeRename.py
 extensions/back/ScalarConstNormalize.py
diff --git a/model-optimizer/extensions/back/ResultRename.py b/model-optimizer/extensions/back/ResultRename.py
new file mode 100644
index 00000000000000..44bd4f9df4507e
--- /dev/null
+++ b/model-optimizer/extensions/back/ResultRename.py
@@ -0,0 +1,38 @@
+"""
+ Copyright (C) 2018-2021 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from mo.back.replacement import BackReplacementPattern
+from mo.graph.graph import Graph
+
+
+class ResultRename(BackReplacementPattern):
+    # This transformation sets the Result operation name equal to the incoming tensor name.
+    # For some frameworks, such as Kaldi and ONNX, this may result in the appearance of nodes with identical names,
+    # which can lead to errors in other transformations.
+    # So ResultRename should be launched at the end of the back phase.
+    enabled = False
+
+    def find_and_replace_pattern(self, graph: Graph):
+        for node in graph.get_op_nodes(type='Result'):
+            if node.in_ports():
+                prev_node_out_port = node.in_port(0).get_connection().get_source()
+                tensor_names = prev_node_out_port.get_tensor_names()
+                if tensor_names:
+                    result_name = tensor_names[0]
+                else:
+                    result_name = prev_node_out_port.node.soft_get('name', prev_node_out_port.node.id) + \
+                                  '/sink_port_' + str(prev_node_out_port.idx)
+                node['name'] = result_name
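
For clarity, the renaming rule implemented in find_and_replace_pattern above can be summarized without the Model Optimizer graph API. The sketch below is only an illustration (the function and its arguments are hypothetical stand-ins for get_tensor_names(), soft_get('name') and idx); it is not the transformation itself:

def pick_result_name(tensor_names, producer_name, port_idx):
    # Prefer the first framework tensor name attached to the producing port;
    # otherwise fall back to '<producer name>/sink_port_<port index>'.
    if tensor_names:
        return tensor_names[0]
    return '{}/sink_port_{}'.format(producer_name, port_idx)


assert pick_result_name(['Op1_tensor'], 'Op1', 0) == 'Op1_tensor'
assert pick_result_name([], 'Op1', 0) == 'Op1/sink_port_0'
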
diff --git a/model-optimizer/extensions/back/ResultRename_test.py b/model-optimizer/extensions/back/ResultRename_test.py
new file mode 100644
index 00000000000000..ba75438d50ff01
--- /dev/null
+++ b/model-optimizer/extensions/back/ResultRename_test.py
@@ -0,0 +1,48 @@
+"""
+ Copyright (C) 2018-2021 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+
+from extensions.back.ResultRename import ResultRename
+from mo.graph.graph import Node
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph, regular_op, result
+
+nodes = {
+    **regular_op('Op1', {'type': 'Op1', 'kind': 'op', 'op': 'Op1'}),
+    **result('result'),
+    'Op1_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 0, 'Op1_tensor')]}
+}
+
+
+class ResultRenameTest(unittest.TestCase):
+    def test_case1(self):
+        graph = build_graph(nodes, [('Op1', 'Op1_data'), ('Op1_data', 'result')])
+        graph_ref = build_graph(nodes, [('Op1', 'Op1_data'), ('Op1_data', 'result')])
+        res_node = Node(graph_ref, 'result')
+        res_node['name'] = 'Op1_tensor'
+
+        ResultRename().find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+    def test_case2(self):
+        graph = build_graph(nodes, [])
+        graph_ref = build_graph(nodes, [])
+
+        ResultRename().find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
+        self.assertTrue(flag, resp)
diff --git a/model-optimizer/mo/back/ie_ir_ver_2/emitter.py b/model-optimizer/mo/back/ie_ir_ver_2/emitter.py
index c4a11d3765cd2f..1b422ad6c5cd32 100644
--- a/model-optimizer/mo/back/ie_ir_ver_2/emitter.py
+++ b/model-optimizer/mo/back/ie_ir_ver_2/emitter.py
@@ -174,8 +174,8 @@ def xml_ports(node: Node, element: Element, edges: Element):
             assert node.graph.node[v]['shape'] is not None, 'Output shape is not calculated properly for node {}' \
                 ''.format(node.id)
             tensor_names = node.out_port(port_id).get_tensor_names(port_renumber=True)
-            if tensor_names is not None:
-                port.set('names', tensor_names)
+            if tensor_names:
+                port.set('names', ','.join(tensor_names))
 
             xml_shape(node.graph.node[v]['shape'], port)
diff --git a/model-optimizer/mo/graph/port.py b/model-optimizer/mo/graph/port.py
index 2ddb239b244311..163d0e5d5a262c 100644
--- a/model-optimizer/mo/graph/port.py
+++ b/model-optimizer/mo/graph/port.py
@@ -1,5 +1,5 @@
 """
- Copyright (C) 2018-2020 Intel Corporation
+ Copyright (C) 2018-2021 Intel Corporation
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -310,9 +310,7 @@ def get_tensor_names_list(attrs):
             if node_idx in self.node.out_nodes():
                 out_node = self.node.out_node(node_idx)
                 fw_names += get_tensor_names_list(out_node.attrs())
-        if len(fw_names) > 0:
-            return ','.join(fw_names)
-        return None
+        return fw_names
 
     def disconnect(self):
         if self.type == 'out':
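
With the port.py change above, get_tensor_names() always returns a list (possibly empty) instead of a comma-joined string or None; the comma-joining now happens only when the IR is emitted, as the emitter.py hunk shows. A standalone sketch of that serialization step (the helper below is illustrative, not the Model Optimizer API):

def names_attribute(tensor_names):
    # Join tensor names for the 'names' attribute of an IR <port>;
    # an empty list means the attribute is simply not written.
    return ','.join(tensor_names) if tensor_names else None


assert names_attribute(['input', 'Op1\\,Op2']) == 'input,Op1\\,Op2'
assert names_attribute([]) is None
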
diff --git a/model-optimizer/mo/graph/port_test.py b/model-optimizer/mo/graph/port_test.py
index 42b03c327da649..a8ef569bd2f5d4 100644
--- a/model-optimizer/mo/graph/port_test.py
+++ b/model-optimizer/mo/graph/port_test.py
@@ -39,50 +39,50 @@ def test_front(self):
                                                                               ('Op1', 0, 'Op1,Op2')]})])
         graph.stage = 'front'
         input_node = Node(graph, 'input')
-        self.assertTrue(input_node.out_port(0).get_tensor_names() == 'input,Op1\\,Op2')
+        self.assertTrue(input_node.out_port(0).get_tensor_names() == ['input', 'Op1\\,Op2'])
 
         op1_node = Node(graph, 'Op1')
         op1_node.add_output_port(0)
-        self.assertTrue(op1_node.out_port(0).get_tensor_names() is None)
+        self.assertTrue(op1_node.out_port(0).get_tensor_names() == [])
 
     def test_middle(self):
         graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('input_data', 'Op2')])
 
         input_node = Node(graph, 'input')
-        self.assertTrue(input_node.out_port(0).get_tensor_names() == 'input,Op1\\,Op2')
+        self.assertTrue(input_node.out_port(0).get_tensor_names() == ['input', 'Op1\\,Op2'])
 
         op1_node = Node(graph, 'Op1')
         op1_node.add_output_port(0)
-        self.assertTrue(op1_node.out_port(0).get_tensor_names() is None)
+        self.assertTrue(op1_node.out_port(0).get_tensor_names() == [])
 
         op2_node = Node(graph, 'Op2')
         op2_node.add_output_port(0)
-        self.assertTrue(op2_node.out_port(0).get_tensor_names() is None)
+        self.assertTrue(op2_node.out_port(0).get_tensor_names() == [])
 
     def test_port_renumber(self):
         graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op1', 'Op1_data', {'out': 1}),
                                     ('Op1_data', 'Op2')])
         input_node = Node(graph, 'input')
-        self.assertTrue(input_node.out_port(0).get_tensor_names(port_renumber=True) == 'input,Op1\\,Op2')
+        self.assertTrue(input_node.out_port(0).get_tensor_names(port_renumber=True) == ['input', 'Op1\\,Op2'])
 
         op1_node = Node(graph, 'Op1')
         op1_node.add_output_port(0)
-        self.assertTrue(op1_node.out_port(0).get_tensor_names(port_renumber=True) == 'Op1\\,Op2')
+        self.assertTrue(op1_node.out_port(0).get_tensor_names(port_renumber=True) == ['Op1\\,Op2'])
 
     def test_reconnect_middle_case1(self):
         graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op3', 'Op3_data')])
         input_node = Node(graph, 'input')
 
         input_node_out_port = input_node.out_port(0)
-        self.assertTrue(input_node_out_port.get_tensor_names() == 'input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == ['input', 'Op1\\,Op2'])
 
         op3_node = Node(graph, 'Op3')
         input_node_out_port.get_connection().set_source(op3_node.out_port(0))
 
-        self.assertTrue(input_node_out_port.get_tensor_names() is None)
-        self.assertTrue(op3_node.out_port(0).get_tensor_names() == 'Op3,input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == [])
+        self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3', 'input', 'Op1\\,Op2'])
 
     def test_reconnect_front_case1(self):
         graph = build_graph(nodes, [('input', 'Op1', {'in': 0, 'out': 0,
                                                       'fw_tensor_debug_info': [('input', 0, 'input'),
@@ -93,26 +93,26 @@ def test_reconnect_front_case1(self):
         input_node = Node(graph, 'input')
 
         input_node_out_port = input_node.out_port(0)
-        self.assertTrue(input_node_out_port.get_tensor_names() == 'input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == ['input', 'Op1\\,Op2'])
 
         op3_node = Node(graph, 'Op3')
         input_node_out_port.get_connection().set_source(op3_node.out_port(0))
 
-        self.assertTrue(input_node_out_port.get_tensor_names() is None)
-        self.assertTrue(op3_node.out_port(0).get_tensor_names() == 'Op3,input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == [])
+        self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3', 'input', 'Op1\\,Op2'])
 
     def test_reconnect_middle_case1(self):
         graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op3', 'Op3_data')])
         input_node = Node(graph, 'input')
 
         input_node_out_port = input_node.out_port(0)
-        self.assertTrue(input_node_out_port.get_tensor_names() == 'input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == ['input', 'Op1\\,Op2'])
 
         op3_node = Node(graph, 'Op3')
         input_node_out_port.get_connection().set_source(op3_node.out_port(0))
 
-        self.assertTrue(input_node_out_port.get_tensor_names() is None)
-        self.assertTrue(op3_node.out_port(0).get_tensor_names() == 'Op3,input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == [])
+        self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3', 'input', 'Op1\\,Op2'])
 
     def test_reconnect_middle_case2(self):
         graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1', {'out': 0}),
@@ -120,10 +120,10 @@ def test_reconnect_middle_case2(self):
         input_node = Node(graph, 'input')
 
         input_node_out_port = input_node.out_port(0)
-        self.assertTrue(input_node_out_port.get_tensor_names() == 'input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == ['input', 'Op1\\,Op2'])
 
         op3_node = Node(graph, 'Op3')
         input_node_out_port.get_connection().set_source(op3_node.out_port(0))
 
-        self.assertTrue(input_node_out_port.get_tensor_names() is None)
-        self.assertTrue(op3_node.out_port(0).get_tensor_names() == 'Op3,input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == [])
+        self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3', 'input', 'Op1\\,Op2'])
diff --git a/model-optimizer/mo/pipeline/common.py b/model-optimizer/mo/pipeline/common.py
index 65e28cfb33b100..51b8d8bc795199 100644
--- a/model-optimizer/mo/pipeline/common.py
+++ b/model-optimizer/mo/pipeline/common.py
@@ -22,6 +22,7 @@
 import networkx as nx
 
 from extensions.back.RemoveUselessConvert import RemoveUselessConvert
+from extensions.back.ResultRename import ResultRename
 from extensions.back.op_versioning import OpVersioning
 from extensions.ops.Cast import Cast
 from mo.back.ie_ir_ver_2.emitter import port_renumber, serialize_constants, generate_ie_ir, serialize_mean_image
@@ -208,6 +209,8 @@ def prepare_emit_ir(graph: Graph, data_type: str, output_dir: str, output_model_
     type_infer(graph)
     RemoveUselessConvert().find_and_replace_pattern(graph)
 
+    ResultRename().find_and_replace_pattern(graph)
+
     for sub_graph in [graph] + collect_sub_graphs(graph):
         op_order, data_order = determined_sort(get_sorted_outputs(sub_graph))
         mapping = {v: u for u, v in enumerate(op_order)}
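
Note that ResultRename declares enabled = False, so it is not picked up automatically with the rest of the back-phase replacers; instead prepare_emit_ir calls it explicitly, right after RemoveUselessConvert and before the IR is emitted. This keeps the renaming at the very end of the back phase, as the comment in ResultRename.py requires, so no earlier transformation has to cope with the potentially duplicated Result names.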