diff --git a/model-optimizer/extensions/front/onnx/AttributedSliceToSlice_test.py b/model-optimizer/extensions/front/onnx/AttributedSliceToSlice_test.py index 89a675438bdae2..e9837a1c489ed2 100644 --- a/model-optimizer/extensions/front/onnx/AttributedSliceToSlice_test.py +++ b/model-optimizer/extensions/front/onnx/AttributedSliceToSlice_test.py @@ -21,7 +21,7 @@ from extensions.front.onnx.AttributedSliceToSlice import AttributedSliceToSliceReplacer from mo.utils.ir_engine.compare_graphs import compare_graphs -from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, const, connect +from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, const, connect_front @generator @@ -53,9 +53,9 @@ def test_attributed_slice_replacer(self, attributed_slice_attrs): graph_ref = build_graph(nodes_attrs=nodes, edges=[ ('input', 'slice'), - *connect('start', '1:slice', on_front=True), - *connect('end', '2:slice', on_front=True), - *connect('axis', '3:slice', on_front=True), + *connect_front('start', '1:slice'), + *connect_front('end', '2:slice'), + *connect_front('axis', '3:slice'), ('slice', 'output'), ], nodes_with_edges_only=True) diff --git a/model-optimizer/extensions/front/tf/TFSliceToSlice_test.py b/model-optimizer/extensions/front/tf/TFSliceToSlice_test.py index 508b4156541fc3..14be81eb3c43cd 100644 --- a/model-optimizer/extensions/front/tf/TFSliceToSlice_test.py +++ b/model-optimizer/extensions/front/tf/TFSliceToSlice_test.py @@ -20,7 +20,7 @@ from extensions.front.tf.TFSliceToSlice import TFSliceToSliceReplacer from mo.utils.ir_engine.compare_graphs import compare_graphs -from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, const, connect +from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, const, connect_front nodes = { **regular_op_with_empty_data('input', {'type': 'Parameter'}), @@ -45,33 +45,33 @@ class SliceReplacerTest(unittest.TestCase): def test_slice_replacer_begin_with_2_inputs(self): graph = build_graph(nodes_attrs=nodes, edges=[ ('input', 'tfslice'), - *connect('begin:0', '1:tfslice', on_front=True), - *connect('begin:0', '0:john_doe', on_front=True), - *connect('size:0', '2:tfslice', on_front=True), - *connect('tfslice:0', 'output', on_front=True), + *connect_front('begin:0', '1:tfslice'), + *connect_front('begin:0', '0:john_doe'), + *connect_front('size:0', '2:tfslice'), + *connect_front('tfslice:0', 'output'), ], nodes_with_edges_only=True) graph.stage = 'front' TFSliceToSliceReplacer().find_and_replace_pattern(graph) graph_ref = build_graph(nodes_attrs=nodes, edges=[ - *connect('input:0', 'slice', on_front=True), - *connect('begin:0', 'slice:1', on_front=True), - *connect('begin:0', 'john_doe:1', on_front=True), + *connect_front('input:0', 'slice'), + *connect_front('begin:0', 'slice:1'), + *connect_front('begin:0', 'john_doe:1'), - *connect('begin:0', 'end_const:0', on_front=True), - *connect('size:0', 'end_const:1', on_front=True), - *connect('size:0', 'equal:0', on_front=True), + *connect_front('begin:0', 'end_const:0'), + *connect_front('size:0', 'end_const:1'), + *connect_front('size:0', 'equal:0'), - *connect('int32_max:0', 'select:1', on_front=True), - *connect('minus_one:0', 'equal:1', on_front=True), + *connect_front('int32_max:0', 'select:1'), + *connect_front('minus_one:0', 'equal:1'), - *connect('equal:0', 'select:0', on_front=True), + *connect_front('equal:0', 'select:0'), - *connect('end_const:0', 'select:2', on_front=True), - *connect('select:0', 'slice:2', on_front=True), + *connect_front('end_const:0', 'select:2'), + *connect_front('select:0', 'slice:2'), - *connect('slice:0', 'output', on_front=True), + *connect_front('slice:0', 'output'), ], nodes_with_edges_only=True) (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) @@ -79,27 +79,27 @@ def test_slice_replacer_begin_with_2_inputs(self): def test_slice_replacer(self): graph = build_graph(nodes_attrs=nodes, edges=[ - *connect('input:0', 'tfslice', on_front=True), - *connect('begin:0', '1:tfslice', on_front=True), - *connect('size:0', '2:tfslice', on_front=True), - *connect('tfslice:0', 'output', on_front=True), + *connect_front('input:0', 'tfslice'), + *connect_front('begin:0', '1:tfslice'), + *connect_front('size:0', '2:tfslice'), + *connect_front('tfslice:0', 'output'), ], nodes_with_edges_only=True) graph.stage = 'front' TFSliceToSliceReplacer().find_and_replace_pattern(graph) graph_ref = build_graph(nodes_attrs=nodes, edges=[ - *connect('input:0', 'slice', on_front=True), - *connect('begin:0', '1:slice', on_front=True), - *connect('begin:0', '0:end_const', on_front=True), - *connect('size:0', '1:end_const', on_front=True), - *connect('size:0', '0:equal', on_front=True), - *connect('int32_max:0', '1:select', on_front=True), - *connect('minus_one:0', '1:equal', on_front=True), - *connect('equal:0', '0:select', on_front=True), - *connect('end_const:0', '2:select', on_front=True), - *connect('select:0', '2:slice', on_front=True), - *connect('slice:0', 'output', on_front=True), + *connect_front('input:0', 'slice'), + *connect_front('begin:0', '1:slice'), + *connect_front('begin:0', '0:end_const'), + *connect_front('size:0', '1:end_const'), + *connect_front('size:0', '0:equal'), + *connect_front('int32_max:0', '1:select'), + *connect_front('minus_one:0', '1:equal'), + *connect_front('equal:0', '0:select'), + *connect_front('end_const:0', '2:select'), + *connect_front('select:0', '2:slice'), + *connect_front('slice:0', 'output'), ], nodes_with_edges_only=True) (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) diff --git a/model-optimizer/mo/utils/unittest/graph.py b/model-optimizer/mo/utils/unittest/graph.py index 55b024556ddbe4..3ee413cde47c34 100644 --- a/model-optimizer/mo/utils/unittest/graph.py +++ b/model-optimizer/mo/utils/unittest/graph.py @@ -125,8 +125,10 @@ def build_graph_with_attrs(nodes_with_attrs: list, edges_with_attrs: list, new_n for node_id in graph.nodes(): node = Node(graph, node_id) - check_and_update_ports(node, [graph.get_edge_data(edge[0], node_id)[0] for edge in graph.in_edges(node_id)], True) - check_and_update_ports(node, [graph.get_edge_data(node_id, edge[1])[0] for edge in graph.out_edges(node_id)], False) + check_and_update_ports(node, [graph.get_edge_data(edge[0], node_id)[0] for edge in graph.in_edges(node_id)], + True) + check_and_update_ports(node, [graph.get_edge_data(node_id, edge[1])[0] for edge in graph.out_edges(node_id)], + False) for node in graph.get_op_nodes(): # Add in_ports attribute @@ -330,17 +332,19 @@ def get_name_and_port(tensor_name): return node_name, 0 -def connect(first_tensor_name, second_tensor_name, skip_data=False, on_front=False): +def connect(first_tensor_name, second_tensor_name, skip_data=False, front_phase=False): # ports could be skipped -- then zero in/out ports would be used # first_tensor_name = first_op_name:out_port # second_tensor_name = in_port:second_op_name + # if skip_data is True connect directly from data node with postfix '_d' to second + # if front_phase is True connect nodes directly without postfixes and data nodes first_op_name, out_port = get_name_and_port(first_tensor_name) second_op_name, in_port = get_name_and_port(second_tensor_name) if skip_data: return [(first_op_name + '_d', second_op_name, {'in': in_port})] - if on_front: + if front_phase: return [(first_op_name, second_op_name, {'out': out_port, 'in': in_port})] return [ (first_op_name, first_op_name + '_d', {'out': out_port}), @@ -351,3 +355,6 @@ def connect(first_tensor_name, second_tensor_name, skip_data=False, on_front=Fal def connect_data(first_tensor_name, second_tensor_name): return connect(first_tensor_name, second_tensor_name, skip_data=True) + +def connect_front(first_tensor_name, second_tensor_name): + return connect(first_tensor_name, second_tensor_name, skip_data=False, front_phase=True)