Skip to content

Commit

Permalink
refactored unit-tests for front transforms of Slice
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Aug 5, 2020
1 parent 724b44e commit a28d5bc
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from extensions.front.onnx.AttributedSliceToSlice import AttributedSliceToSliceReplacer
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, const, connect
from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, const, connect_front


@generator
Expand Down Expand Up @@ -53,9 +53,9 @@ def test_attributed_slice_replacer(self, attributed_slice_attrs):

graph_ref = build_graph(nodes_attrs=nodes, edges=[
('input', 'slice'),
*connect('start', '1:slice', on_front=True),
*connect('end', '2:slice', on_front=True),
*connect('axis', '3:slice', on_front=True),
*connect_front('start', '1:slice'),
*connect_front('end', '2:slice'),
*connect_front('axis', '3:slice'),
('slice', 'output'),
], nodes_with_edges_only=True)

Expand Down
64 changes: 32 additions & 32 deletions model-optimizer/extensions/front/tf/TFSliceToSlice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from extensions.front.tf.TFSliceToSlice import TFSliceToSliceReplacer
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, const, connect
from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, const, connect_front

nodes = {
**regular_op_with_empty_data('input', {'type': 'Parameter'}),
Expand All @@ -45,61 +45,61 @@ class SliceReplacerTest(unittest.TestCase):
def test_slice_replacer_begin_with_2_inputs(self):
graph = build_graph(nodes_attrs=nodes, edges=[
('input', 'tfslice'),
*connect('begin:0', '1:tfslice', on_front=True),
*connect('begin:0', '0:john_doe', on_front=True),
*connect('size:0', '2:tfslice', on_front=True),
*connect('tfslice:0', 'output', on_front=True),
*connect_front('begin:0', '1:tfslice'),
*connect_front('begin:0', '0:john_doe'),
*connect_front('size:0', '2:tfslice'),
*connect_front('tfslice:0', 'output'),
], nodes_with_edges_only=True)
graph.stage = 'front'

TFSliceToSliceReplacer().find_and_replace_pattern(graph)

graph_ref = build_graph(nodes_attrs=nodes, edges=[
*connect('input:0', 'slice', on_front=True),
*connect('begin:0', 'slice:1', on_front=True),
*connect('begin:0', 'john_doe:1', on_front=True),
*connect_front('input:0', 'slice'),
*connect_front('begin:0', 'slice:1'),
*connect_front('begin:0', 'john_doe:1'),

*connect('begin:0', 'end_const:0', on_front=True),
*connect('size:0', 'end_const:1', on_front=True),
*connect('size:0', 'equal:0', on_front=True),
*connect_front('begin:0', 'end_const:0'),
*connect_front('size:0', 'end_const:1'),
*connect_front('size:0', 'equal:0'),

*connect('int32_max:0', 'select:1', on_front=True),
*connect('minus_one:0', 'equal:1', on_front=True),
*connect_front('int32_max:0', 'select:1'),
*connect_front('minus_one:0', 'equal:1'),

*connect('equal:0', 'select:0', on_front=True),
*connect_front('equal:0', 'select:0'),

*connect('end_const:0', 'select:2', on_front=True),
*connect('select:0', 'slice:2', on_front=True),
*connect_front('end_const:0', 'select:2'),
*connect_front('select:0', 'slice:2'),

*connect('slice:0', 'output', on_front=True),
*connect_front('slice:0', 'output'),
], nodes_with_edges_only=True)

(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)

def test_slice_replacer(self):
graph = build_graph(nodes_attrs=nodes, edges=[
*connect('input:0', 'tfslice', on_front=True),
*connect('begin:0', '1:tfslice', on_front=True),
*connect('size:0', '2:tfslice', on_front=True),
*connect('tfslice:0', 'output', on_front=True),
*connect_front('input:0', 'tfslice'),
*connect_front('begin:0', '1:tfslice'),
*connect_front('size:0', '2:tfslice'),
*connect_front('tfslice:0', 'output'),
], nodes_with_edges_only=True)
graph.stage = 'front'

TFSliceToSliceReplacer().find_and_replace_pattern(graph)

graph_ref = build_graph(nodes_attrs=nodes, edges=[
*connect('input:0', 'slice', on_front=True),
*connect('begin:0', '1:slice', on_front=True),
*connect('begin:0', '0:end_const', on_front=True),
*connect('size:0', '1:end_const', on_front=True),
*connect('size:0', '0:equal', on_front=True),
*connect('int32_max:0', '1:select', on_front=True),
*connect('minus_one:0', '1:equal', on_front=True),
*connect('equal:0', '0:select', on_front=True),
*connect('end_const:0', '2:select', on_front=True),
*connect('select:0', '2:slice', on_front=True),
*connect('slice:0', 'output', on_front=True),
*connect_front('input:0', 'slice'),
*connect_front('begin:0', '1:slice'),
*connect_front('begin:0', '0:end_const'),
*connect_front('size:0', '1:end_const'),
*connect_front('size:0', '0:equal'),
*connect_front('int32_max:0', '1:select'),
*connect_front('minus_one:0', '1:equal'),
*connect_front('equal:0', '0:select'),
*connect_front('end_const:0', '2:select'),
*connect_front('select:0', '2:slice'),
*connect_front('slice:0', 'output'),
], nodes_with_edges_only=True)

(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
Expand Down
15 changes: 11 additions & 4 deletions model-optimizer/mo/utils/unittest/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,10 @@ def build_graph_with_attrs(nodes_with_attrs: list, edges_with_attrs: list, new_n

for node_id in graph.nodes():
node = Node(graph, node_id)
check_and_update_ports(node, [graph.get_edge_data(edge[0], node_id)[0] for edge in graph.in_edges(node_id)], True)
check_and_update_ports(node, [graph.get_edge_data(node_id, edge[1])[0] for edge in graph.out_edges(node_id)], False)
check_and_update_ports(node, [graph.get_edge_data(edge[0], node_id)[0] for edge in graph.in_edges(node_id)],
True)
check_and_update_ports(node, [graph.get_edge_data(node_id, edge[1])[0] for edge in graph.out_edges(node_id)],
False)

for node in graph.get_op_nodes():
# Add in_ports attribute
Expand Down Expand Up @@ -330,17 +332,19 @@ def get_name_and_port(tensor_name):
return node_name, 0


def connect(first_tensor_name, second_tensor_name, skip_data=False, on_front=False):
def connect(first_tensor_name, second_tensor_name, skip_data=False, front_phase=False):
# ports could be skipped -- then zero in/out ports would be used
# first_tensor_name = first_op_name:out_port
# second_tensor_name = in_port:second_op_name
# if skip_data is True connect directly from data node with postfix '_d' to second
# if front_phase is True connect nodes directly without postfixes and data nodes

first_op_name, out_port = get_name_and_port(first_tensor_name)
second_op_name, in_port = get_name_and_port(second_tensor_name)

if skip_data:
return [(first_op_name + '_d', second_op_name, {'in': in_port})]
if on_front:
if front_phase:
return [(first_op_name, second_op_name, {'out': out_port, 'in': in_port})]
return [
(first_op_name, first_op_name + '_d', {'out': out_port}),
Expand All @@ -351,3 +355,6 @@ def connect(first_tensor_name, second_tensor_name, skip_data=False, on_front=Fal
def connect_data(first_tensor_name, second_tensor_name):
return connect(first_tensor_name, second_tensor_name, skip_data=True)


def connect_front(first_tensor_name, second_tensor_name):
return connect(first_tensor_name, second_tensor_name, skip_data=False, front_phase=True)

0 comments on commit a28d5bc

Please sign in to comment.