Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed mapping of input name #3737

Merged
merged 10 commits into from
Dec 28, 2020
7 changes: 7 additions & 0 deletions model-optimizer/extensions/middle/AddMeanScaleValues.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,14 @@ def insert_pre_processing(graph: Graph, input_node: Node, node_mean_scale_values

for dst in input_node.out_port(0).get_destinations():
if dst.node.soft_get('type') != 'ShapeOf':
# After the insertion of additional operations model optimizer
# should keep the link to the input layer. Parameter node in framework
# should map to parameter node in IR.
# For this reason 'fw_tensor_debug_info' should be kept in data node.
fw_name = input_node.out_node(0)['fw_tensor_debug_info']
dst.get_connection().set_source(preprocessing.out_port(0))
input_node.out_node(0)['fw_tensor_debug_info'] = fw_name
del preprocessing.out_node(0)['fw_tensor_debug_info']

input_node.out_port(0).connect(preprocessing.in_port(0))

Expand Down
60 changes: 54 additions & 6 deletions model-optimizer/extensions/middle/AddMeanScaleValues_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from extensions.middle.AddMeanScaleValues import AddMeanScaleValues
from extensions.middle.ScaleInput import ScaleInput
from mo.graph.graph import Graph, Node
from mo.utils.cli_parser import get_mean_scale_dictionary, parse_tuple_pairs
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, result, connect, connect_data, \
Expand All @@ -45,6 +46,25 @@


class AddMeanScaleValuesTest(unittest.TestCase):
def check_graph_attrs(self, graph: Graph, graph_ref: Graph, parameter_node_names: list):
for node in graph.get_op_nodes():
if node.soft_get('name') in parameter_node_names:
self.assertTrue(node.soft_get('type') == 'Parameter')
out_node = node.out_node(0)
out_node_ref = Node(graph_ref, node.id).out_node(0)
self.assertTrue(out_node['fw_tensor_debug_info'] == out_node_ref['fw_tensor_debug_info'])
else:
if 0 in node.out_nodes():
out_node = node.out_node(0)
self.assertFalse('fw_tensor_debug_info' in out_node)

def set_graph_attrs(self, graph: Graph, parameter_node_names: list):
for node in graph.get_op_nodes():
if node.soft_get('name') in parameter_node_names:
self.assertTrue(node.soft_get('type') == 'Parameter')
out_node = node.out_node(0)
out_node['fw_tensor_debug_info'] = ['fw_name', 0]

def test_mean_values_with_data_name(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:add_mean'),
Expand All @@ -58,18 +78,21 @@ def test_mean_values_with_data_name(self):
argv = Namespace(mean_scale_values=mean_scale)

graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])

def test_mean_values_without_data_name(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:add_mean'),
*connect('mean', '1:add_mean'),
*connect('add_mean', 'result'),
])
], {'parameter': {'name': 'None'}})

mean_values = parse_tuple_pairs('(1,2,3)')
scale_values = parse_tuple_pairs('')
Expand All @@ -78,11 +101,14 @@ def test_mean_values_without_data_name(self):

graph = build_graph(nodes, [*connect('parameter', 'result')], {'parameter': {'name': 'None'}},
nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['None'])
self.set_graph_attrs(graph_ref, ['None'])
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['None'])

def test_mean_values_explicit_and_optimized(self):
graph_ref = build_graph(nodes, [
Expand All @@ -96,13 +122,16 @@ def test_mean_values_explicit_and_optimized(self):
'parameter_2': {'mean': np.array([0., 0., 0.])}})
graph = build_graph(nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'result_2')],
nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter', 'parameter_2'])
self.set_graph_attrs(graph_ref, ['parameter', 'parameter_2'])
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter', 'parameter_2'])

def test_mean_values_explicit_and_scale_values_optimized(self):
graph_ref = build_graph(nodes, [
Expand All @@ -113,11 +142,14 @@ def test_mean_values_explicit_and_scale_values_optimized(self):

argv = Namespace(mean_scale_values={'parameter': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}})
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])

def test_mean_values_optimized_and_scale_values_explicit(self):
graph_ref = build_graph(nodes, [
Expand All @@ -129,11 +161,14 @@ def test_mean_values_optimized_and_scale_values_explicit(self):
argv = Namespace(
mean_scale_values={'parameter': {'scale': np.array([1., 2., 3.]), 'mean': np.array([0., 0., 0.])}})
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])

def test_mean_values_explicit_and_scale_values_explicit(self):
graph_ref = build_graph(nodes, [
Expand All @@ -147,11 +182,14 @@ def test_mean_values_explicit_and_scale_values_explicit(self):
argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]), np.array([1., 2., 3.])]])
graph = build_graph(nodes, [*connect('parameter', 'result')],
nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])

def test_mean_values_explicit_and_scale_values_explicit_on_cutted_graph(self):
"""
Expand All @@ -173,13 +211,16 @@ def test_mean_values_explicit_and_scale_values_explicit_on_cutted_graph(self):
graph = build_graph(
nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'op'), *connect('op', 'result_2')],
{'parameter_2': {'initial_node_name': 'op'}}, nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter', 'parameter_2'])
self.set_graph_attrs(graph_ref, ['parameter', 'parameter_2'])
graph.graph['layout'] = 'NCHW'
AddMeanScaleValues().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter', 'parameter_2'])

def test_mean_values_explicit_and_scale_values_explicit_with_shape_of(self):
graph_ref = build_graph(nodes,
Expand All @@ -203,36 +244,43 @@ def test_mean_values_explicit_and_scale_values_explicit_with_shape_of(self):
*connect('shape_of', 'result_2'),
],
nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])


class ScaleInputTests(unittest.TestCase):
def test_scale_input(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:mul_scale'),
*connect('scale', '1:mul_scale'),
*connect('mul_scale', 'result'),
], {'scale': {'shape': [1, 1, 1, 1], 'value': np.array(1/255)},
'scale_d': {'shape': [1, 1, 1, 1], 'value': np.array(1/255)}})
], {'scale': {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)},
'scale_d': {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)}})

graph = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True, cli=Namespace(scale=255))
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'

ScaleInput().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])

def test_scale_input_2(self):
graph_ref = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True)
graph = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True, cli=Namespace(scale=1))
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'

ScaleInput().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])