Result rename operation (openvinotoolkit#4242)
* Added result rename operation

* Optimize imports

* Added ResultRename to package_BOM

* ResultRename moved to the end of back phase, code refactoring

* Revert incorrect changes

* Optimize imports

* Added comments and optimized imports.
popovaan authored Feb 12, 2021
1 parent 2bf0e88 commit eeb7835
Showing 8 changed files with 113 additions and 25 deletions.
2 changes: 1 addition & 1 deletion inference-engine/src/readers/ir_reader/ie_ir_parser.cpp
@@ -392,7 +392,7 @@ std::shared_ptr<ngraph::Function> V10Parser::XmlDeserializer::parse_function(con
     // Read all layers and store their parameters in params map
     FOREACH_CHILD(node, root.child("layers"), "layer") {
         auto node_param = parseGenericParams(node);
-        if (opName.find(node_param.name) != opName.end())
+        if (opName.find(node_param.name) != opName.end() && node_param.type != "Result")
             THROW_IE_EXCEPTION << "Invalid IR! " << node_param.name << " name is not unique!";
         opName.insert(node_param.name);
         params[node_param.layerId] = {node, node_param};
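Note: this exemption appears to be needed because the ResultRename transformation introduced below sets a Result operation's name to the name of its incoming tensor, which may coincide with the name of another layer in the serialized IR; only layers of type Result are therefore excluded from the uniqueness check.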
1 change: 1 addition & 0 deletions model-optimizer/automation/package_BOM.txt
@@ -47,6 +47,7 @@ extensions/back/RemoveUselessConvert.py
 extensions/back/Reshape0DToSqueeze.py
 extensions/back/ReshapeMutation.py
 extensions/back/ResultNormalizer.py
+extensions/back/ResultRename.py
 extensions/back/ReverseInputChannels.py
 extensions/back/RNNSequenceTypeRename.py
 extensions/back/ScalarConstNormalize.py
38 changes: 38 additions & 0 deletions model-optimizer/extensions/back/ResultRename.py
@@ -0,0 +1,38 @@
"""
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph


class ResultRename(BackReplacementPattern):
    # This transformation sets the Result operation name equal to the incoming tensor name.
    # For some frameworks like Kaldi and ONNX this may result in the appearance of nodes with identical names,
    # which can lead to errors in other transformations.
    # So ResultRename should be launched at the end of the back phase.
    enabled = False

    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(type='Result'):
            if node.in_ports():
                prev_node_out_port = node.in_port(0).get_connection().get_source()
                tensor_names = prev_node_out_port.get_tensor_names()
                if tensor_names:
                    result_name = tensor_names[0]
                else:
                    result_name = prev_node_out_port.node.soft_get('name', prev_node_out_port.node.id) + \
                                  '/sink_port_' + str(prev_node_out_port.idx)
                node['name'] = result_name
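For orientation, a minimal sketch of the naming rule implemented above; the tensor and node names below are made up for illustration and are not part of the commit:

    # Illustrative values only (hypothetical graph):
    tensor_names = ['conv1_output']        # framework tensor names on the producing output port
    producer_name, port_idx = 'conv1', 0   # producer node name and its output port index

    if tensor_names:
        result_name = tensor_names[0]      # Result takes the first tensor name: 'conv1_output'
    else:
        result_name = producer_name + '/sink_port_' + str(port_idx)   # fallback: 'conv1/sink_port_0'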
48 changes: 48 additions & 0 deletions model-optimizer/extensions/back/ResultRename_test.py
@@ -0,0 +1,48 @@
"""
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

from extensions.back.ResultRename import ResultRename
from mo.graph.graph import Node
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op, result

nodes = {
    **regular_op('Op1', {'type': 'Op1', 'kind': 'op', 'op': 'Op1'}),
    **result('result'),
    'Op1_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 0, 'Op1_tensor')]}
}


class ResultRenameTest(unittest.TestCase):
    def test_case1(self):
        graph = build_graph(nodes, [('Op1', 'Op1_data'), ('Op1_data', 'result')])
        graph_ref = build_graph(nodes, [('Op1', 'Op1_data'), ('Op1_data', 'result')])
        res_node = Node(graph_ref, 'result')
        res_node['name'] = 'Op1_tensor'

        ResultRename().find_and_replace_pattern(graph)
        (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_case2(self):
        graph = build_graph(nodes, [])
        graph_ref = build_graph(nodes, [])

        ResultRename().find_and_replace_pattern(graph)
        (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
        self.assertTrue(flag, resp)
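If needed, the new test can be run on its own with the standard unittest runner (hypothetical invocation, assuming it is executed from the model-optimizer directory so that the mo and extensions packages are importable):

    python -m unittest extensions.back.ResultRename_test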
4 changes: 2 additions & 2 deletions model-optimizer/mo/back/ie_ir_ver_2/emitter.py
@@ -174,8 +174,8 @@ def xml_ports(node: Node, element: Element, edges: Element):
             assert node.graph.node[v]['shape'] is not None, 'Output shape is not calculated properly for node {}' \
                                                             ''.format(node.id)
             tensor_names = node.out_port(port_id).get_tensor_names(port_renumber=True)
-            if tensor_names is not None:
-                port.set('names', tensor_names)
+            if tensor_names:
+                port.set('names', ','.join(tensor_names))
             xml_shape(node.graph.node[v]['shape'], port)
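A brief sketch of the effect of this change (values are made up): get_tensor_names() now returns a list of names, and the emitter joins them with commas into the port's 'names' attribute; a name that itself contains a comma arrives pre-escaped as '\,', as the port_test expectations below show:

    tensor_names = ['input', 'Op1\\,Op2']      # second name contains an escaped comma
    port.set('names', ','.join(tensor_names))  # serialized as names="input,Op1\,Op2"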


6 changes: 2 additions & 4 deletions model-optimizer/mo/graph/port.py
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -310,9 +310,7 @@ def get_tensor_names_list(attrs):
         if node_idx in self.node.out_nodes():
             out_node = self.node.out_node(node_idx)
             fw_names += get_tensor_names_list(out_node.attrs())
-        if len(fw_names) > 0:
-            return ','.join(fw_names)
-        return None
+        return fw_names

     def disconnect(self):
         if self.type == 'out':
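Sketch of the updated caller-side contract implied by this diff: get_tensor_names() now always returns a list (possibly empty) instead of a comma-joined string or None, so callers test for emptiness rather than for None:

    names = node.out_port(0).get_tensor_names()
    if names:                  # previously: if names is not None
        first_name = names[0]  # e.g. ResultRename picks the first tensor name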
36 changes: 18 additions & 18 deletions model-optimizer/mo/graph/port_test.py
@@ -39,50 +39,50 @@ def test_front(self):
                                             ('Op1', 0, 'Op1,Op2')]})])
         graph.stage = 'front'
         input_node = Node(graph, 'input')
-        self.assertTrue(input_node.out_port(0).get_tensor_names() == 'input,Op1\\,Op2')
+        self.assertTrue(input_node.out_port(0).get_tensor_names() == ['input', 'Op1\\,Op2'])

         op1_node = Node(graph, 'Op1')
         op1_node.add_output_port(0)
-        self.assertTrue(op1_node.out_port(0).get_tensor_names() is None)
+        self.assertTrue(op1_node.out_port(0).get_tensor_names() == [])

     def test_middle(self):
         graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'),
                                     ('input_data', 'Op2')])

         input_node = Node(graph, 'input')
-        self.assertTrue(input_node.out_port(0).get_tensor_names() == 'input,Op1\\,Op2')
+        self.assertTrue(input_node.out_port(0).get_tensor_names() == ['input', 'Op1\\,Op2'])

         op1_node = Node(graph, 'Op1')
         op1_node.add_output_port(0)
-        self.assertTrue(op1_node.out_port(0).get_tensor_names() is None)
+        self.assertTrue(op1_node.out_port(0).get_tensor_names() == [])

         op2_node = Node(graph, 'Op2')
         op2_node.add_output_port(0)
-        self.assertTrue(op2_node.out_port(0).get_tensor_names() is None)
+        self.assertTrue(op2_node.out_port(0).get_tensor_names() == [])

     def test_port_renumber(self):
         graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'),
                                     ('Op1', 'Op1_data', {'out': 1}), ('Op1_data', 'Op2')])
         input_node = Node(graph, 'input')
-        self.assertTrue(input_node.out_port(0).get_tensor_names(port_renumber=True) == 'input,Op1\\,Op2')
+        self.assertTrue(input_node.out_port(0).get_tensor_names(port_renumber=True) == ['input', 'Op1\\,Op2'])

         op1_node = Node(graph, 'Op1')
         op1_node.add_output_port(0)

-        self.assertTrue(op1_node.out_port(0).get_tensor_names(port_renumber=True) == 'Op1\\,Op2')
+        self.assertTrue(op1_node.out_port(0).get_tensor_names(port_renumber=True) == ['Op1\\,Op2'])

     def test_reconnect_middle_case1(self):
         graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op3', 'Op3_data')])
         input_node = Node(graph, 'input')

         input_node_out_port = input_node.out_port(0)
-        self.assertTrue(input_node_out_port.get_tensor_names() == 'input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == ['input', 'Op1\\,Op2'])

         op3_node = Node(graph, 'Op3')
         input_node_out_port.get_connection().set_source(op3_node.out_port(0))

         self.assertTrue(input_node_out_port.get_tensor_names() is None)
-        self.assertTrue(op3_node.out_port(0).get_tensor_names() == 'Op3,input,Op1\\,Op2')
+        self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3', 'input', 'Op1\\,Op2'])

     def test_reconnect_front_case1(self):
         graph = build_graph(nodes, [('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input'),
@@ -93,37 +93,37 @@ def test_reconnect_front_case1(self):
         input_node = Node(graph, 'input')

         input_node_out_port = input_node.out_port(0)
-        self.assertTrue(input_node_out_port.get_tensor_names() == 'input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == ['input', 'Op1\\,Op2'])

         op3_node = Node(graph, 'Op3')
         input_node_out_port.get_connection().set_source(op3_node.out_port(0))

-        self.assertTrue(input_node_out_port.get_tensor_names() is None)
-        self.assertTrue(op3_node.out_port(0).get_tensor_names() == 'Op3,input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == [])
+        self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3', 'input', 'Op1\\,Op2'])

     def test_reconnect_middle_case1(self):
         graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op3', 'Op3_data')])
         input_node = Node(graph, 'input')

         input_node_out_port = input_node.out_port(0)
-        self.assertTrue(input_node_out_port.get_tensor_names() == 'input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == ['input', 'Op1\\,Op2'])

         op3_node = Node(graph, 'Op3')
         input_node_out_port.get_connection().set_source(op3_node.out_port(0))

-        self.assertTrue(input_node_out_port.get_tensor_names() is None)
-        self.assertTrue(op3_node.out_port(0).get_tensor_names() == 'Op3,input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == [])
+        self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3', 'input', 'Op1\\,Op2'])

     def test_reconnect_middle_case2(self):
         graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1', {'out': 0}),
                                     ('input_data', 'Op1', {'out': 1}), ('Op3', 'Op3_data')])
         input_node = Node(graph, 'input')

         input_node_out_port = input_node.out_port(0)
-        self.assertTrue(input_node_out_port.get_tensor_names() == 'input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == ['input', 'Op1\\,Op2'])

         op3_node = Node(graph, 'Op3')
         input_node_out_port.get_connection().set_source(op3_node.out_port(0))

-        self.assertTrue(input_node_out_port.get_tensor_names() is None)
-        self.assertTrue(op3_node.out_port(0).get_tensor_names() == 'Op3,input,Op1\\,Op2')
+        self.assertTrue(input_node_out_port.get_tensor_names() == [])
+        self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3', 'input', 'Op1\\,Op2'])
3 changes: 3 additions & 0 deletions model-optimizer/mo/pipeline/common.py
@@ -22,6 +22,7 @@
 import networkx as nx

 from extensions.back.RemoveUselessConvert import RemoveUselessConvert
+from extensions.back.ResultRename import ResultRename
 from extensions.back.op_versioning import OpVersioning
 from extensions.ops.Cast import Cast
 from mo.back.ie_ir_ver_2.emitter import port_renumber, serialize_constants, generate_ie_ir, serialize_mean_image
@@ -208,6 +209,8 @@ def prepare_emit_ir(graph: Graph, data_type: str, output_dir: str, output_model_
     type_infer(graph)
     RemoveUselessConvert().find_and_replace_pattern(graph)

+    ResultRename().find_and_replace_pattern(graph)
+
     for sub_graph in [graph] + collect_sub_graphs(graph):
         op_order, data_order = determined_sort(get_sorted_outputs(sub_graph))
         mapping = {v: u for u, v in enumerate(op_order)}
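Because the pass is declared with enabled = False, it is not picked up automatically during the back phase; prepare_emit_ir calls it explicitly, just before IR serialization, so the Result names it assigns are not changed by any later transformation. A minimal sketch of that explicit call, taken from the diff above:

    from extensions.back.ResultRename import ResultRename
    ResultRename().find_and_replace_pattern(graph)   # graph is the mo.graph.graph.Graph being emitted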
