Skip to content

Commit

Permalink
Fixed mapping of input name (#3737)
Browse files Browse the repository at this point in the history
* Fixed mapping of input name

* Fixed unit tests

* Fixed mapping of input name

* Fixed unit tests

* attributes check fix

* PEP8 code format

* code duplicate removal

* variable rename
  • Loading branch information
popovaan authored Dec 28, 2020
1 parent 37b6e75 commit 631d452
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 6 deletions.
7 changes: 7 additions & 0 deletions model-optimizer/extensions/middle/AddMeanScaleValues.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,14 @@ def insert_pre_processing(graph: Graph, input_node: Node, node_mean_scale_values

for dst in input_node.out_port(0).get_destinations():
if dst.node.soft_get('type') != 'ShapeOf':
# After the insertion of additional operations model optimizer
# should keep the link to the input layer. Parameter node in framework
# should map to parameter node in IR.
# For this reason 'fw_tensor_debug_info' should be kept in data node.
fw_name = input_node.out_node(0)['fw_tensor_debug_info']
dst.get_connection().set_source(preprocessing.out_port(0))
input_node.out_node(0)['fw_tensor_debug_info'] = fw_name
del preprocessing.out_node(0)['fw_tensor_debug_info']

input_node.out_port(0).connect(preprocessing.in_port(0))

Expand Down
60 changes: 54 additions & 6 deletions model-optimizer/extensions/middle/AddMeanScaleValues_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from extensions.middle.AddMeanScaleValues import AddMeanScaleValues
from extensions.middle.ScaleInput import ScaleInput
from mo.graph.graph import Graph, Node
from mo.utils.cli_parser import get_mean_scale_dictionary, parse_tuple_pairs
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, result, connect, connect_data, \
Expand All @@ -45,6 +46,25 @@


class AddMeanScaleValuesTest(unittest.TestCase):
def check_graph_attrs(self, graph: Graph, graph_ref: Graph, parameter_node_names: list):
for node in graph.get_op_nodes():
if node.soft_get('name') in parameter_node_names:
self.assertTrue(node.soft_get('type') == 'Parameter')
out_node = node.out_node(0)
out_node_ref = Node(graph_ref, node.id).out_node(0)
self.assertTrue(out_node['fw_tensor_debug_info'] == out_node_ref['fw_tensor_debug_info'])
else:
if 0 in node.out_nodes():
out_node = node.out_node(0)
self.assertFalse('fw_tensor_debug_info' in out_node)

def set_graph_attrs(self, graph: Graph, parameter_node_names: list):
for node in graph.get_op_nodes():
if node.soft_get('name') in parameter_node_names:
self.assertTrue(node.soft_get('type') == 'Parameter')
out_node = node.out_node(0)
out_node['fw_tensor_debug_info'] = ['fw_name', 0]

def test_mean_values_with_data_name(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:add_mean'),
Expand All @@ -58,18 +78,21 @@ def test_mean_values_with_data_name(self):
argv = Namespace(mean_scale_values=mean_scale)

graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])

def test_mean_values_without_data_name(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:add_mean'),
*connect('mean', '1:add_mean'),
*connect('add_mean', 'result'),
])
], {'parameter': {'name': 'None'}})

mean_values = parse_tuple_pairs('(1,2,3)')
scale_values = parse_tuple_pairs('')
Expand All @@ -78,11 +101,14 @@ def test_mean_values_without_data_name(self):

graph = build_graph(nodes, [*connect('parameter', 'result')], {'parameter': {'name': 'None'}},
nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['None'])
self.set_graph_attrs(graph_ref, ['None'])
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['None'])

def test_mean_values_explicit_and_optimized(self):
graph_ref = build_graph(nodes, [
Expand All @@ -96,13 +122,16 @@ def test_mean_values_explicit_and_optimized(self):
'parameter_2': {'mean': np.array([0., 0., 0.])}})
graph = build_graph(nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'result_2')],
nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter', 'parameter_2'])
self.set_graph_attrs(graph_ref, ['parameter', 'parameter_2'])
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter', 'parameter_2'])

def test_mean_values_explicit_and_scale_values_optimized(self):
graph_ref = build_graph(nodes, [
Expand All @@ -113,11 +142,14 @@ def test_mean_values_explicit_and_scale_values_optimized(self):

argv = Namespace(mean_scale_values={'parameter': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}})
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])

def test_mean_values_optimized_and_scale_values_explicit(self):
graph_ref = build_graph(nodes, [
Expand All @@ -129,11 +161,14 @@ def test_mean_values_optimized_and_scale_values_explicit(self):
argv = Namespace(
mean_scale_values={'parameter': {'scale': np.array([1., 2., 3.]), 'mean': np.array([0., 0., 0.])}})
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])

def test_mean_values_explicit_and_scale_values_explicit(self):
graph_ref = build_graph(nodes, [
Expand All @@ -147,11 +182,14 @@ def test_mean_values_explicit_and_scale_values_explicit(self):
argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]), np.array([1., 2., 3.])]])
graph = build_graph(nodes, [*connect('parameter', 'result')],
nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])

def test_mean_values_explicit_and_scale_values_explicit_on_cutted_graph(self):
"""
Expand All @@ -173,13 +211,16 @@ def test_mean_values_explicit_and_scale_values_explicit_on_cutted_graph(self):
graph = build_graph(
nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'op'), *connect('op', 'result_2')],
{'parameter_2': {'initial_node_name': 'op'}}, nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter', 'parameter_2'])
self.set_graph_attrs(graph_ref, ['parameter', 'parameter_2'])
graph.graph['layout'] = 'NCHW'
AddMeanScaleValues().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter', 'parameter_2'])

def test_mean_values_explicit_and_scale_values_explicit_with_shape_of(self):
graph_ref = build_graph(nodes,
Expand All @@ -203,36 +244,43 @@ def test_mean_values_explicit_and_scale_values_explicit_with_shape_of(self):
*connect('shape_of', 'result_2'),
],
nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'

AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])


class ScaleInputTests(unittest.TestCase):
def test_scale_input(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:mul_scale'),
*connect('scale', '1:mul_scale'),
*connect('mul_scale', 'result'),
], {'scale': {'shape': [1, 1, 1, 1], 'value': np.array(1/255)},
'scale_d': {'shape': [1, 1, 1, 1], 'value': np.array(1/255)}})
], {'scale': {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)},
'scale_d': {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)}})

graph = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True, cli=Namespace(scale=255))
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'

ScaleInput().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])

def test_scale_input_2(self):
graph_ref = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True)
graph = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True, cli=Namespace(scale=1))
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'

ScaleInput().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])

0 comments on commit 631d452

Please sign in to comment.