diff --git a/model-optimizer/extensions/middle/AddMeanScaleValues.py b/model-optimizer/extensions/middle/AddMeanScaleValues.py index d6b3c637153dfe..1df8ca3bd0143d 100644 --- a/model-optimizer/extensions/middle/AddMeanScaleValues.py +++ b/model-optimizer/extensions/middle/AddMeanScaleValues.py @@ -64,7 +64,14 @@ def insert_pre_processing(graph: Graph, input_node: Node, node_mean_scale_values for dst in input_node.out_port(0).get_destinations(): if dst.node.soft_get('type') != 'ShapeOf': + # After the insertion of additional operations model optimizer + # should keep the link to the input layer. Parameter node in framework + # should map to parameter node in IR. + # For this reason 'fw_tensor_debug_info' should be kept in data node. + fw_name = input_node.out_node(0)['fw_tensor_debug_info'] dst.get_connection().set_source(preprocessing.out_port(0)) + input_node.out_node(0)['fw_tensor_debug_info'] = fw_name + del preprocessing.out_node(0)['fw_tensor_debug_info'] input_node.out_port(0).connect(preprocessing.in_port(0)) diff --git a/model-optimizer/extensions/middle/AddMeanScaleValues_test.py b/model-optimizer/extensions/middle/AddMeanScaleValues_test.py index 7c31a2ef984dd7..b2163d90d425ab 100644 --- a/model-optimizer/extensions/middle/AddMeanScaleValues_test.py +++ b/model-optimizer/extensions/middle/AddMeanScaleValues_test.py @@ -20,6 +20,7 @@ from extensions.middle.AddMeanScaleValues import AddMeanScaleValues from extensions.middle.ScaleInput import ScaleInput +from mo.graph.graph import Graph, Node from mo.utils.cli_parser import get_mean_scale_dictionary, parse_tuple_pairs from mo.utils.ir_engine.compare_graphs import compare_graphs from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, result, connect, connect_data, \ @@ -45,6 +46,25 @@ class AddMeanScaleValuesTest(unittest.TestCase): + def check_graph_attrs(self, graph: Graph, graph_ref: Graph, parameter_node_names: list): + for node in graph.get_op_nodes(): + if node.soft_get('name') in parameter_node_names: + self.assertTrue(node.soft_get('type') == 'Parameter') + out_node = node.out_node(0) + out_node_ref = Node(graph_ref, node.id).out_node(0) + self.assertTrue(out_node['fw_tensor_debug_info'] == out_node_ref['fw_tensor_debug_info']) + else: + if 0 in node.out_nodes(): + out_node = node.out_node(0) + self.assertFalse('fw_tensor_debug_info' in out_node) + + def set_graph_attrs(self, graph: Graph, parameter_node_names: list): + for node in graph.get_op_nodes(): + if node.soft_get('name') in parameter_node_names: + self.assertTrue(node.soft_get('type') == 'Parameter') + out_node = node.out_node(0) + out_node['fw_tensor_debug_info'] = ['fw_name', 0] + def test_mean_values_with_data_name(self): graph_ref = build_graph(nodes, [ *connect('parameter', '0:add_mean'), @@ -58,18 +78,21 @@ def test_mean_values_with_data_name(self): argv = Namespace(mean_scale_values=mean_scale) graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv) + self.set_graph_attrs(graph, ['parameter']) + self.set_graph_attrs(graph_ref, ['parameter']) graph.graph['layout'] = 'NCHW' AddMeanScaleValues().find_and_replace_pattern(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True) self.assertTrue(flag, resp) + self.check_graph_attrs(graph, graph_ref, ['parameter']) def test_mean_values_without_data_name(self): graph_ref = build_graph(nodes, [ *connect('parameter', '0:add_mean'), *connect('mean', '1:add_mean'), *connect('add_mean', 'result'), - ]) + ], {'parameter': {'name': 'None'}}) mean_values = parse_tuple_pairs('(1,2,3)') scale_values = parse_tuple_pairs('') @@ -78,11 +101,14 @@ def test_mean_values_without_data_name(self): graph = build_graph(nodes, [*connect('parameter', 'result')], {'parameter': {'name': 'None'}}, nodes_with_edges_only=True, cli=argv) + self.set_graph_attrs(graph, ['None']) + self.set_graph_attrs(graph_ref, ['None']) graph.graph['layout'] = 'NCHW' AddMeanScaleValues().find_and_replace_pattern(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True) self.assertTrue(flag, resp) + self.check_graph_attrs(graph, graph_ref, ['None']) def test_mean_values_explicit_and_optimized(self): graph_ref = build_graph(nodes, [ @@ -96,6 +122,8 @@ def test_mean_values_explicit_and_optimized(self): 'parameter_2': {'mean': np.array([0., 0., 0.])}}) graph = build_graph(nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'result_2')], nodes_with_edges_only=True, cli=argv) + self.set_graph_attrs(graph, ['parameter', 'parameter_2']) + self.set_graph_attrs(graph_ref, ['parameter', 'parameter_2']) graph.graph['layout'] = 'NCHW' AddMeanScaleValues().find_and_replace_pattern(graph) @@ -103,6 +131,7 @@ def test_mean_values_explicit_and_optimized(self): self.assertTrue(flag, resp) (flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True) self.assertTrue(flag, resp) + self.check_graph_attrs(graph, graph_ref, ['parameter', 'parameter_2']) def test_mean_values_explicit_and_scale_values_optimized(self): graph_ref = build_graph(nodes, [ @@ -113,11 +142,14 @@ def test_mean_values_explicit_and_scale_values_optimized(self): argv = Namespace(mean_scale_values={'parameter': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}}) graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv) + self.set_graph_attrs(graph, ['parameter']) + self.set_graph_attrs(graph_ref, ['parameter']) graph.graph['layout'] = 'NCHW' AddMeanScaleValues().find_and_replace_pattern(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True) self.assertTrue(flag, resp) + self.check_graph_attrs(graph, graph_ref, ['parameter']) def test_mean_values_optimized_and_scale_values_explicit(self): graph_ref = build_graph(nodes, [ @@ -129,11 +161,14 @@ def test_mean_values_optimized_and_scale_values_explicit(self): argv = Namespace( mean_scale_values={'parameter': {'scale': np.array([1., 2., 3.]), 'mean': np.array([0., 0., 0.])}}) graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv) + self.set_graph_attrs(graph, ['parameter']) + self.set_graph_attrs(graph_ref, ['parameter']) graph.graph['layout'] = 'NCHW' AddMeanScaleValues().find_and_replace_pattern(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True) self.assertTrue(flag, resp) + self.check_graph_attrs(graph, graph_ref, ['parameter']) def test_mean_values_explicit_and_scale_values_explicit(self): graph_ref = build_graph(nodes, [ @@ -147,11 +182,14 @@ def test_mean_values_explicit_and_scale_values_explicit(self): argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]), np.array([1., 2., 3.])]]) graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv) + self.set_graph_attrs(graph, ['parameter']) + self.set_graph_attrs(graph_ref, ['parameter']) graph.graph['layout'] = 'NCHW' AddMeanScaleValues().find_and_replace_pattern(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True) self.assertTrue(flag, resp) + self.check_graph_attrs(graph, graph_ref, ['parameter']) def test_mean_values_explicit_and_scale_values_explicit_on_cutted_graph(self): """ @@ -173,6 +211,8 @@ def test_mean_values_explicit_and_scale_values_explicit_on_cutted_graph(self): graph = build_graph( nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'op'), *connect('op', 'result_2')], {'parameter_2': {'initial_node_name': 'op'}}, nodes_with_edges_only=True, cli=argv) + self.set_graph_attrs(graph, ['parameter', 'parameter_2']) + self.set_graph_attrs(graph_ref, ['parameter', 'parameter_2']) graph.graph['layout'] = 'NCHW' AddMeanScaleValues().find_and_replace_pattern(graph) @@ -180,6 +220,7 @@ def test_mean_values_explicit_and_scale_values_explicit_on_cutted_graph(self): self.assertTrue(flag, resp) (flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True) self.assertTrue(flag, resp) + self.check_graph_attrs(graph, graph_ref, ['parameter', 'parameter_2']) def test_mean_values_explicit_and_scale_values_explicit_with_shape_of(self): graph_ref = build_graph(nodes, @@ -203,6 +244,8 @@ def test_mean_values_explicit_and_scale_values_explicit_with_shape_of(self): *connect('shape_of', 'result_2'), ], nodes_with_edges_only=True, cli=argv) + self.set_graph_attrs(graph, ['parameter']) + self.set_graph_attrs(graph_ref, ['parameter']) graph.graph['layout'] = 'NCHW' AddMeanScaleValues().find_and_replace_pattern(graph) @@ -210,29 +253,34 @@ def test_mean_values_explicit_and_scale_values_explicit_with_shape_of(self): self.assertTrue(flag, resp) (flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True) self.assertTrue(flag, resp) + self.check_graph_attrs(graph, graph_ref, ['parameter']) - -class ScaleInputTests(unittest.TestCase): def test_scale_input(self): graph_ref = build_graph(nodes, [ *connect('parameter', '0:mul_scale'), *connect('scale', '1:mul_scale'), *connect('mul_scale', 'result'), - ], {'scale': {'shape': [1, 1, 1, 1], 'value': np.array(1/255)}, - 'scale_d': {'shape': [1, 1, 1, 1], 'value': np.array(1/255)}}) + ], {'scale': {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)}, + 'scale_d': {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)}}) graph = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True, cli=Namespace(scale=255)) + self.set_graph_attrs(graph, ['parameter']) + self.set_graph_attrs(graph_ref, ['parameter']) graph.graph['layout'] = 'NCHW' ScaleInput().find_and_replace_pattern(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) + self.check_graph_attrs(graph, graph_ref, ['parameter']) def test_scale_input_2(self): graph_ref = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True) graph = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True, cli=Namespace(scale=1)) + self.set_graph_attrs(graph, ['parameter']) + self.set_graph_attrs(graph_ref, ['parameter']) graph.graph['layout'] = 'NCHW' ScaleInput().find_and_replace_pattern(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'result') - self.assertTrue(flag, resp) \ No newline at end of file + self.assertTrue(flag, resp) + self.check_graph_attrs(graph, graph_ref, ['parameter'])