Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into openvino-cmake-co…
Browse files Browse the repository at this point in the history
…nfig
  • Loading branch information
ilya-lavrenov committed Sep 8, 2021
2 parents ef84fde + 60714ce commit 1f2bb08
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 14 deletions.
27 changes: 15 additions & 12 deletions model-optimizer/extensions/back/ReverseInputChannels.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def pass_rc_through(node: Node, reverse_channels: Node):
returns boolean value whatever we should continue propagating current ReverseChannels operation down or not
"""
# detaching reverse_channels node from the graph
if reverse_channels.is_in_port_connected(0) and reverse_channels.is_out_port_connected(0)\
if reverse_channels.is_in_port_connected(0) and reverse_channels.is_out_port_connected(0) \
and node.is_out_port_connected(0):
reverse_channels.out_port(0).get_connection().set_source(
reverse_channels.in_port(0).get_connection().get_source())
Expand All @@ -137,7 +137,7 @@ def pass_rc_through_conv(node, reverse_channels):
ReverseChannels weights previous_op ReverseChannels
\ / \ /
Conv Conv
For grouped convolution:
BEFORE AFTER
Expand Down Expand Up @@ -295,38 +295,41 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
'Subtract': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Pow': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Convert': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),

'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through(node, rc),
'Pad': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_pad(node, rc),
}

@staticmethod
def lift_up_through(node: Node, reverse_channels: Node):
def lift_up_through_pad(node: Node, reverse_channels: Node):
r"""
BEFORE AFTER
previous_op
\
previous_op previous_op ReverseChannels previous_op
\ / \ /
Node Node
Pad Pad
| |
ReverseChannels next_op
|
next_op
returns boolean value whatever we should continue propagating current ReverseChannels operation up or not
returns two objects:
first - boolean value whatever we should continue propagating current ReverseChannels operation up or not
second - list of ReverseChannels operations that were produced while propagating reverse_channels up
"""
if node.is_in_port_connected(0):
node_input_port_0 = node.in_port(0)
reverse_channels_out_npde = reverse_channels.out_port(0).get_connection().get_destination().node
reverse_channels_out_nodes = reverse_channels.out_port(0).get_connection().get_destinations()
reverse_channels.out_port(0).disconnect()

reverse_channels.in_port(0).disconnect()
src = node_input_port_0.get_connection().get_source()
node_input_port_0.get_connection().set_source(reverse_channels.out_port(0))
src.connect(reverse_channels.in_port(0))
node.out_port(0).get_connection().set_destination(reverse_channels_out_npde.in_port(0))
return True
return False
for reverse_channels_destination in reverse_channels_out_nodes:
node.out_port(0).get_connection().add_destination(reverse_channels_destination)

return True, [reverse_channels]
return False, []

@staticmethod
def lift_up_through_eltwise(node: Node, reverse_channels: Node):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
**regular_op_with_shaped_data('pad', [1, 3, 10, 10], {'type': 'Pad'}),
**regular_op_with_shaped_data('reverse_channels', [1, 3, 10, 10], {'type': 'ReverseChannels', 'axis': 1}),
**result('result'),
**result('result2'),
}

class ReverseInputChannelsTest(unittest.TestCase):
Expand Down Expand Up @@ -64,7 +65,7 @@ def test_lift_up_through_eltwise(self):
ReverseChannelsPropagationUp.lift_up_through_eltwise(node, reverse_channels)
self.check_graph_attrs(graph, ['placeholder1', 'placeholder2'])

def test_lift_up_through(self):
def test_lift_up_through_pad(self):
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
*connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'),
*connect('pad_const_2', '2:pad'), *connect('pad', 'reverse_channels'),
Expand All @@ -74,7 +75,25 @@ def test_lift_up_through(self):
node = Node(graph, 'pad')
reverse_channels = Node(graph, 'reverse_channels')

ReverseChannelsPropagationUp.lift_up_through(node, reverse_channels)
keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels)
self.assertTrue(keep_moving_up is True)
self.assertTrue(len(new_reverses) == 1)
self.check_graph_attrs(graph, ['placeholder'])


def test_lift_up_through_pad2(self):
graph = build_graph(nodes2, [*connect('placeholder', '0:mul'), *connect('mul_const', '1:mul'),
*connect('mul', '0:pad'), *connect('pad_const_1', '1:pad'),
*connect('pad_const_2', '2:pad'), *connect('pad', 'reverse_channels'),
*connect('reverse_channels:0', '0:result'), *connect('reverse_channels:0', '0:result2')])
self.set_graph_attrs(graph, ['placeholder'])

node = Node(graph, 'pad')
reverse_channels = Node(graph, 'reverse_channels')

keep_moving_up, new_reverses = ReverseChannelsPropagationUp.lift_up_through_pad(node, reverse_channels)
self.assertTrue(keep_moving_up is True)
self.assertTrue(len(new_reverses) == 1)
self.check_graph_attrs(graph, ['placeholder'])


Expand Down

0 comments on commit 1f2bb08

Please sign in to comment.