From 60714ce40af3b89aa667600776728321c4577778 Mon Sep 17 00:00:00 2001 From: iliya mironov Date: Wed, 8 Sep 2021 14:48:59 +0300 Subject: [PATCH] Fix return values for lift_up_through func (#7323) * Fix return values for lift_up_through func * Update unit test * Refactoring code according to code review * Fix reverse outputs * Fix unit test * Fix comment * Add multioutput support * Add unit test for case with several outputs from ReverseChannels op * Fix destination connect --- .../extensions/back/ReverseInputChannels.py | 27 ++++++++++--------- .../back/ReverseInputChannels_test.py | 23 ++++++++++++++-- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/model-optimizer/extensions/back/ReverseInputChannels.py b/model-optimizer/extensions/back/ReverseInputChannels.py index 088130e0bf26e1..ce0cbb952ae363 100644 --- a/model-optimizer/extensions/back/ReverseInputChannels.py +++ b/model-optimizer/extensions/back/ReverseInputChannels.py @@ -114,7 +114,7 @@ def pass_rc_through(node: Node, reverse_channels: Node): returns boolean value whatever we should continue propagating current ReverseChannels operation down or not """ # detaching reverse_channels node from the graph - if reverse_channels.is_in_port_connected(0) and reverse_channels.is_out_port_connected(0)\ + if reverse_channels.is_in_port_connected(0) and reverse_channels.is_out_port_connected(0) \ and node.is_out_port_connected(0): reverse_channels.out_port(0).get_connection().set_source( reverse_channels.in_port(0).get_connection().get_source()) @@ -137,7 +137,7 @@ def pass_rc_through_conv(node, reverse_channels): ReverseChannels weights previous_op ReverseChannels \ / \ / Conv Conv - + For grouped convolution: BEFORE AFTER @@ -295,12 +295,11 @@ class ReverseChannelsPropagationUp(BackReplacementPattern): 'Subtract': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc), 'Pow': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc), 'Convert': lambda
node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc), - - 'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through(node, rc), + 'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_pad(node, rc), } @staticmethod - def lift_up_through(node: Node, reverse_channels: Node): + def lift_up_through_pad(node: Node, reverse_channels: Node): r""" BEFORE AFTER @@ -308,25 +307,29 @@ def lift_up_through(node: Node, reverse_channels: Node): \ previous_op previous_op ReverseChannels previous_op \ / \ / - Node Node + Pad Pad | | ReverseChannels next_op | next_op - returns boolean value whatever we should continue propagating current ReverseChannels operation up or not + returns two objects: + first - boolean value whatever we should continue propagating current ReverseChannels operation up or not + second - list of ReverseChannels operations that were produced while propagating reverse_channels up """ if node.is_in_port_connected(0): node_input_port_0 = node.in_port(0) - reverse_channels_out_npde = reverse_channels.out_port(0).get_connection().get_destination().node + reverse_channels_out_nodes = reverse_channels.out_port(0).get_connection().get_destinations() reverse_channels.out_port(0).disconnect() - + reverse_channels.in_port(0).disconnect() src = node_input_port_0.get_connection().get_source() node_input_port_0.get_connection().set_source(reverse_channels.out_port(0)) src.connect(reverse_channels.in_port(0)) - node.out_port(0).get_connection().set_destination(reverse_channels_out_npde.in_port(0)) - return True - return False + for reverse_channels_destination in reverse_channels_out_nodes: + node.out_port(0).get_connection().add_destination(reverse_channels_destination) + + return True, [reverse_channels] + return False, [] @staticmethod def lift_up_through_eltwise(node: Node, reverse_channels: Node): diff --git a/model-optimizer/unit_tests/extensions/back/ReverseInputChannels_test.py 
b/model-optimizer/unit_tests/extensions/back/ReverseInputChannels_test.py index 8ac90c8708fdaf..634f3ea9aef39c 100644 --- a/model-optimizer/unit_tests/extensions/back/ReverseInputChannels_test.py +++ b/model-optimizer/unit_tests/extensions/back/ReverseInputChannels_test.py @@ -32,6 +32,7 @@ **regular_op_with_shaped_data('pad', [1, 3, 10, 10], {'type': 'Pad'}), **regular_op_with_shaped_data('reverse_channels', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 1}), **result('result'), + **result('result2'), } class ReverseInputChannelsTest(unittest.TestCase): @@ -64,7 +65,7 @@ def test_lift_up_through_eltwise(self): ReverseChannelsPropagationUp.lift_up_through_eltwise(node, reverse_channels) self.check_graph_attrs(graph, ['placeholder1', 'placeholder2']) - def test_lift_up_through(self): + def test_lift_up_through_pad(self): graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'), *connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'), *connect('pad_const_2', '2:pad'), *connect('pad', 'reverse_channels'), @@ -74,7 +75,25 @@ def test_lift_up_through(self): node = Node(graph, 'pad') reverse_channels = Node(graph, 'reverse_channels') - ReverseChannelsPropagationUp.lift_up_through(node, reverse_channels) + keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels) + self.assertTrue(keep_moving_up is True) + self.assertTrue(len(new_reverses) == 1) + self.check_graph_attrs(graph, ['placeholder']) + + + def test_lift_up_through_pad2(self): + graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'), + *connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'), + *connect('pad_const_2', '2:pad'), *connect('pad', 'reverse_channels'), + *connect('reverse_channels:0', '0:result'), *connect('reverse_channels:0', '0:result2')]) + self.set_graph_attrs(graph, ['placeholder']) + + node = Node(graph, 'pad') + reverse_channels = Node(graph, 'reverse_channels') + 
+ keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels) + self.assertTrue(keep_moving_up is True) + self.assertTrue(len(new_reverses) == 1) self.check_graph_attrs(graph, ['placeholder'])