Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed mapping of input name #3737

Merged
merged 10 commits into from
Dec 28, 2020
5 changes: 5 additions & 0 deletions model-optimizer/extensions/middle/AddMeanScaleValues.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ def insert_pre_processing(graph: Graph, input_node: Node, node_mean_scale_values

for dst in input_node.out_port(0).get_destinations():
if dst.node.soft_get('type') != 'ShapeOf':
# 'fw_tensor_debug_info' should be kept in data node for the correct
# mapping of input names in mapping file
tmp = input_node.out_node(0)['fw_tensor_debug_info']
popovaan marked this conversation as resolved.
Show resolved Hide resolved
popovaan marked this conversation as resolved.
Show resolved Hide resolved
dst.get_connection().set_source(preprocessing.out_port(0))
input_node.out_node(0)['fw_tensor_debug_info'] = tmp
del preprocessing.out_node(0)['fw_tensor_debug_info']

input_node.out_port(0).connect(preprocessing.in_port(0))

Expand Down
65 changes: 64 additions & 1 deletion model-optimizer/extensions/middle/AddMeanScaleValues_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,28 @@
}


def set_graph_attrs(graph):
for node in graph.get_op_nodes():
if node.has_valid('type') and node['type'] == 'Parameter':
if 0 in node.out_nodes():
out_node = node.out_node(0)
out_node['fw_tensor_debug_info'] = ['data_attributes', 0]


class AddMeanScaleValuesTest(unittest.TestCase):
def check_graph_attrs(self, graph, graph_ref):
for node in graph.get_op_nodes():
if node.has_valid('type') and node['type'] == 'Parameter':
if 0 in node.out_nodes():
out_node = node.out_node(0)
self.assertTrue(out_node['fw_tensor_debug_info'] ==
graph_ref.node[out_node['name']]['fw_tensor_debug_info'])
else:
if 0 in node.out_nodes():
out_node = node.out_node(0)
self.assertFalse('fw_tensor_debug_info' in out_node)


def test_mean_values_with_data_name(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:add_mean'),
Expand All @@ -58,11 +79,14 @@ def test_mean_values_with_data_name(self):
argv = Namespace(mean_scale_values=mean_scale)

graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
set_graph_attrs(graph)
set_graph_attrs(graph_ref)
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref)

def test_mean_values_without_data_name(self):
graph_ref = build_graph(nodes, [
Expand All @@ -78,11 +102,14 @@ def test_mean_values_without_data_name(self):

graph = build_graph(nodes, [*connect('parameter', 'result')], {'parameter': {'name': 'None'}},
nodes_with_edges_only=True, cli=argv)
set_graph_attrs(graph)
set_graph_attrs(graph_ref)
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref)

def test_mean_values_explicit_and_optimized(self):
graph_ref = build_graph(nodes, [
Expand All @@ -96,13 +123,16 @@ def test_mean_values_explicit_and_optimized(self):
'parameter_2': {'mean': np.array([0., 0., 0.])}})
graph = build_graph(nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'result_2')],
nodes_with_edges_only=True, cli=argv)
set_graph_attrs(graph)
set_graph_attrs(graph_ref)
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref)

def test_mean_values_explicit_and_scale_values_optimized(self):
graph_ref = build_graph(nodes, [
Expand All @@ -113,11 +143,14 @@ def test_mean_values_explicit_and_scale_values_optimized(self):

argv = Namespace(mean_scale_values={'parameter': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}})
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
set_graph_attrs(graph)
set_graph_attrs(graph_ref)
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref)

def test_mean_values_optimized_and_scale_values_explicit(self):
graph_ref = build_graph(nodes, [
Expand All @@ -129,11 +162,14 @@ def test_mean_values_optimized_and_scale_values_explicit(self):
argv = Namespace(
mean_scale_values={'parameter': {'scale': np.array([1., 2., 3.]), 'mean': np.array([0., 0., 0.])}})
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
set_graph_attrs(graph)
set_graph_attrs(graph_ref)
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref)

def test_mean_values_explicit_and_scale_values_explicit(self):
graph_ref = build_graph(nodes, [
Expand All @@ -147,11 +183,14 @@ def test_mean_values_explicit_and_scale_values_explicit(self):
argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]), np.array([1., 2., 3.])]])
graph = build_graph(nodes, [*connect('parameter', 'result')],
nodes_with_edges_only=True, cli=argv)
set_graph_attrs(graph)
set_graph_attrs(graph_ref)
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref)

def test_mean_values_explicit_and_scale_values_explicit_on_cutted_graph(self):
"""
Expand All @@ -173,13 +212,16 @@ def test_mean_values_explicit_and_scale_values_explicit_on_cutted_graph(self):
graph = build_graph(
nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'op'), *connect('op', 'result_2')],
{'parameter_2': {'initial_node_name': 'op'}}, nodes_with_edges_only=True, cli=argv)
set_graph_attrs(graph)
set_graph_attrs(graph_ref)
graph.graph['layout'] = 'NCHW'
AddMeanScaleValues().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref)

def test_mean_values_explicit_and_scale_values_explicit_with_shape_of(self):
graph_ref = build_graph(nodes,
Expand All @@ -203,16 +245,31 @@ def test_mean_values_explicit_and_scale_values_explicit_with_shape_of(self):
*connect('shape_of', 'result_2'),
],
nodes_with_edges_only=True, cli=argv)
set_graph_attrs(graph)
set_graph_attrs(graph_ref)
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref)


class ScaleInputTests(unittest.TestCase):
def check_graph_attrs(self, graph, graph_ref):
for node in graph.get_op_nodes():
if node.has_valid('type') and node['type'] == 'Parameter':
if 0 in node.out_nodes():
out_node = node.out_node(0)
self.assertTrue(out_node['fw_tensor_debug_info'] ==
graph_ref.node[out_node['name']]['fw_tensor_debug_info'])
else:
if 0 in node.out_nodes():
out_node = node.out_node(0)
self.assertFalse('fw_tensor_debug_info' in out_node)

def test_scale_input(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:mul_scale'),
Expand All @@ -222,17 +279,23 @@ def test_scale_input(self):
'scale_d': {'shape': [1, 1, 1, 1], 'value': np.array(1/255)}})

graph = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True, cli=Namespace(scale=255))
set_graph_attrs(graph)
set_graph_attrs(graph_ref)
graph.graph['layout'] = 'NCHW'

ScaleInput().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref)

def test_scale_input_2(self):
graph_ref = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True)
graph = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True, cli=Namespace(scale=1))
set_graph_attrs(graph)
set_graph_attrs(graph_ref)
graph.graph['layout'] = 'NCHW'

ScaleInput().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref)