Revert "Fuse mul transformation fix (openvinotoolkit#5518)" (openvino…
Browse files Browse the repository at this point in the history
…toolkit#5831)

This reverts commit 84b94c9.
Evgenya Stepyreva authored and rnugmanx committed Aug 26, 2021
1 parent d207916 commit fd58481
Showing 2 changed files with 3 additions and 112 deletions.
37 changes: 3 additions & 34 deletions model-optimizer/mo/middle/passes/fusing/fuse_linear_ops.py
@@ -107,40 +107,9 @@ def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True)
        w_mul = node.copy_node({'name': mul_name, 'in_ports_count': len(node.in_ports()),
                                'out_ports_count': len(node.out_ports()), 'can_be_fused': False})
        w_mul.in_port(const_port.idx).connect(mul_const.out_port(0))

r"""
In this transformation we remove Mul or Div node (node) that goes after fuse_node and
create new Mul node (w_mul), connect it with the corrected const value (mul_const) and
insert w_mul before the fuse_node. So the input data of fuse_node becomes different.
For this reason we need to use set_destination from previous operation to w_mul which
guaranties that data node will be reused on previous_op -> w_mul connection and its
attributes won't be copied to the data node of w_mul -> fuse_node connection.
BEFORE AFTER
previous_op mul_const
\ /
previous_op w_mul
| |
fuse_node const fuse_node
\ / |
node next_op
|
next_op
"""
weights_port.get_connection().set_destination(w_mul.in_port(tensor_port.idx))
w_mul.out_port(0).connect(weights_port)

# As fusing is applied to convolutions it is important to keep 'permutation' and 'input_permutation' attributes
# which were obtained from original model. These attributes are stored on the incoming edge to the operation
# node and during the reconnection they are moved to the new connection. But during reconnection in this
# transformation these attributes are moved to the previous node. So we need manually set them at the
# incoming edge to fuse_node.
in_edge = w_mul.in_edge(tensor_port.idx)
if 'permutation' in in_edge:
fuse_node.in_edge(weights_port.idx)['permutation'] = in_edge['permutation']
if 'input_permutation' in in_edge:
fuse_node.in_edge(weights_port.idx)['input_permutation'] = in_edge['input_permutation']
+        w_const = weights_port.get_source()
+        weights_port.get_connection().set_source(w_mul.out_port(0))
+        w_const.connect(w_mul.in_port(tensor_port.idx))

        # If we fuse in the backward direction we should multiply biases if they exist
        if backward and len(fuse_node.in_ports()) == 3 and not fuse_node.in_port(2).disconnected() and \
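The behavioral difference between the reverted rewiring (set_destination into w_mul, then a new w_mul -> fuse_node connection) and the restored one (set_source on the edge that already enters fuse_node) is subtle. The sketch below is a toy model written for this note, not the Model Optimizer API; the Connection class and every name in it are illustrative. It mirrors the behavior the docstring above describes: redirecting a connection's destination leaves the edge attributes on a different edge than redirecting its source.

class Connection:
    """Toy stand-in for a graph edge that carries a dict of attributes."""

    def __init__(self, source, destination, attrs=None):
        self.source = source
        self.destination = destination
        self.attrs = dict(attrs or {})

    def set_destination(self, new_destination):
        # The connection object (and its attrs) survives; only the receiving
        # end changes. This mirrors the reverted strategy, which left the
        # 'permutation' attributes on the previous_op -> w_mul edge and had
        # to copy them back onto fuse_node's incoming edge by hand.
        self.destination = new_destination

    def set_source(self, new_source):
        # The edge entering the destination (and its attrs) survives; only
        # the feeding end changes. This mirrors the restored strategy.
        self.source = new_source


# Edge entering the convolution, carrying an attribute the layout logic needs.
weights_edge = Connection('previous_op', 'fuse_node', {'permutation': 'layout-info'})

# Reverted strategy: the attributed edge now ends at w_mul, and a fresh,
# attribute-less edge has to be created to feed fuse_node.
weights_edge.set_destination('w_mul')
fresh_edge = Connection('w_mul', 'fuse_node')
assert 'permutation' not in fresh_edge.attrs

# Restored strategy: the attributed edge keeps feeding fuse_node; only its
# source is swapped to w_mul, so the attribute stays where it is expected.
restored_edge = Connection('previous_op', 'fuse_node', {'permutation': 'layout-info'})
restored_edge.set_source('w_mul')
assert restored_edge.attrs['permutation'] == 'layout-info'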
78 changes: 0 additions & 78 deletions
@@ -817,84 +817,6 @@ def test_fuse_mul_to_deconv_1(self):
        (flag, resp) = compare_graphs(graph, graph_ref, 'placeholder_1', 'placeholder_1')
        self.assertTrue(flag, resp)

-    def test_fuse_mul_permutation_saving(self):
-        graph = build_graph(nodes_attributes,
-                            [('placeholder_1', 'placeholder_1_data'),
-                             ('placeholder_1_data', 'mul_1'),
-                             ('const_mul_1_w', 'mul_1_w'),
-                             ('mul_1_w', 'mul_1'),
-                             ('mul_1', 'mul_1_data'),
-                             ('mul_1_data', 'conv_1'),
-                             ('const_conv_1_w', 'conv_1_w'),
-                             ('const_conv_1_b', 'conv_1_b'),
-                             ('conv_1_w', 'conv_1'),
-                             ('conv_1_b', 'conv_1'),
-                             ('conv_1', 'conv_1_data'),
-                             ('conv_1_data', 'op_output')
-                             ],
-                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
-                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
-                             'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
-                             'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
-                             'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]),
-                                                'value': np.ones((11, 11, 3, 96))},
-                             'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
-                                          'output_channel_dim': 3, 'input_channel_dim': 2,
-                                          'dims_number': 4},
-                             'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
-                             'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
-                             'conv_1_data': {}
-                             })
-        conv_node = Node(graph, "conv_1")
-        for edge_idx in conv_node.in_edges():
-            conv_node.in_edge(edge_idx)['permutation'] = 'permutation_value' + str(edge_idx)
-            conv_node.in_edge(edge_idx)['input_permutation'] = 'input_permutation_value' + str(edge_idx)
-
-        _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False)
-
-        conv_node = Node(graph, "conv_1")
-        for edge_idx in conv_node.in_edges():
-            self.assertTrue(conv_node.in_edge(edge_idx)['permutation'] == 'permutation_value' + str(edge_idx))
-            self.assertTrue(
-                conv_node.in_edge(edge_idx)['input_permutation'] == 'input_permutation_value' + str(edge_idx))
-
-    def test_fuse_mul_data_nodes_names(self):
-        graph = build_graph(nodes_attributes,
-                            [('placeholder_1', 'placeholder_1_data'),
-                             ('placeholder_1_data', 'mul_1'),
-                             ('const_mul_1_w', 'mul_1_w'),
-                             ('mul_1_w', 'mul_1'),
-                             ('mul_1', 'mul_1_data'),
-                             ('mul_1_data', 'conv_1'),
-                             ('const_conv_1_w', 'conv_1_w'),
-                             ('const_conv_1_b', 'conv_1_b'),
-                             ('conv_1_w', 'conv_1'),
-                             ('conv_1_b', 'conv_1'),
-                             ('conv_1', 'conv_1_data'),
-                             ('conv_1_data', 'op_output')
-                             ],
-                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
-                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
-                             'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
-                             'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
-                             'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]),
-                                                'value': np.ones((11, 11, 3, 96))},
-                             'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
-                                          'output_channel_dim': 3, 'input_channel_dim': 2,
-                                          'dims_number': 4},
-                             'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
-                             'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
-                             'conv_1_data': {}
-                             })
-
-        _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False)
-
-        conv_node = Node(graph, 'conv_1')
-        conv_in_data_name = conv_node.in_node(1)['name']
-        const_node = Node(graph, 'const_conv_1_w')
-        const_out_data_name = const_node.out_node(0)['name']
-
-        self.assertFalse(conv_in_data_name == const_out_data_name)

# Unit tests for fuse_linear_ops
class FuseLinOpsTests(unittest.TestCase):
