diff --git a/model-optimizer/mo/middle/passes/fusing/fuse_linear_ops.py b/model-optimizer/mo/middle/passes/fusing/fuse_linear_ops.py index 14d12daae91912..0c021b19e2fb91 100644 --- a/model-optimizer/mo/middle/passes/fusing/fuse_linear_ops.py +++ b/model-optimizer/mo/middle/passes/fusing/fuse_linear_ops.py @@ -107,40 +107,9 @@ def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True) w_mul = node.copy_node({'name': mul_name, 'in_ports_count': len(node.in_ports()), 'out_ports_count': len(node.out_ports()), 'can_be_fused': False}) w_mul.in_port(const_port.idx).connect(mul_const.out_port(0)) - - r""" - In this transformation we remove Mul or Div node (node) that goes after fuse_node and - create new Mul node (w_mul), connect it with the corrected const value (mul_const) and - insert w_mul before the fuse_node. So the input data of fuse_node becomes different. - For this reason we need to use set_destination from previous operation to w_mul which - guaranties that data node will be reused on previous_op -> w_mul connection and its - attributes won't be copied to the data node of w_mul -> fuse_node connection. - - BEFORE AFTER - - previous_op mul_const - \ / - previous_op w_mul - | | - fuse_node const fuse_node - \ / | - node next_op - | - next_op - """ - weights_port.get_connection().set_destination(w_mul.in_port(tensor_port.idx)) - w_mul.out_port(0).connect(weights_port) - - # As fusing is applied to convolutions it is important to keep 'permutation' and 'input_permutation' attributes - # which were obtained from original model. These attributes are stored on the incoming edge to the operation - # node and during the reconnection they are moved to the new connection. But during reconnection in this - # transformation these attributes are moved to the previous node. So we need manually set them at the - # incoming edge to fuse_node. - in_edge = w_mul.in_edge(tensor_port.idx) - if 'permutation' in in_edge: - fuse_node.in_edge(weights_port.idx)['permutation'] = in_edge['permutation'] - if 'input_permutation' in in_edge: - fuse_node.in_edge(weights_port.idx)['input_permutation'] = in_edge['input_permutation'] + w_const = weights_port.get_source() + weights_port.get_connection().set_source(w_mul.out_port(0)) + w_const.connect(w_mul.in_port(tensor_port.idx)) # If we fuse in backward direction we should multiply biases if they exists if backward and len(fuse_node.in_ports()) == 3 and not fuse_node.in_port(2).disconnected() and \ diff --git a/model-optimizer/unit_tests/mo/middle/passes/fusing/fuse_linear_ops_test.py b/model-optimizer/unit_tests/mo/middle/passes/fusing/fuse_linear_ops_test.py index f16b0d4210ef68..6626d61dba0690 100644 --- a/model-optimizer/unit_tests/mo/middle/passes/fusing/fuse_linear_ops_test.py +++ b/model-optimizer/unit_tests/mo/middle/passes/fusing/fuse_linear_ops_test.py @@ -817,84 +817,6 @@ def test_fuse_mul_to_deconv_1(self): (flag, resp) = compare_graphs(graph, graph_ref, 'placeholder_1', 'placeholder_1') self.assertTrue(flag, resp) - def test_fuse_mul_permutation_saving(self): - graph = build_graph(nodes_attributes, - [('placeholder_1', 'placeholder_1_data'), - ('placeholder_1_data', 'mul_1'), - ('const_mul_1_w', 'mul_1_w'), - ('mul_1_w', 'mul_1'), - ('mul_1', 'mul_1_data'), - ('mul_1_data', 'conv_1'), - ('const_conv_1_w', 'conv_1_w'), - ('const_conv_1_b', 'conv_1_b'), - ('conv_1_w', 'conv_1'), - ('conv_1_b', 'conv_1'), - ('conv_1', 'conv_1_data'), - ('conv_1_data', 'op_output') - ], - {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])}, - 'mul_1_data': {'shape': np.array([1, 227, 227, 3])}, - 'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])}, - 'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])}, - 'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), - 'value': np.ones((11, 11, 3, 96))}, - 'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)), - 'output_channel_dim': 3, 'input_channel_dim': 2, - 'dims_number': 4}, - 'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)}, - 'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)}, - 'conv_1_data': {} - }) - conv_node = Node(graph, "conv_1") - for edge_idx in conv_node.in_edges(): - conv_node.in_edge(edge_idx)['permutation'] = 'permutation_value' + str(edge_idx) - conv_node.in_edge(edge_idx)['input_permutation'] = 'input_permutation_value' + str(edge_idx) - - _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False) - - conv_node = Node(graph, "conv_1") - for edge_idx in conv_node.in_edges(): - self.assertTrue(conv_node.in_edge(edge_idx)['permutation'] == 'permutation_value' + str(edge_idx)) - self.assertTrue(conv_node.in_edge(edge_idx)['input_permutation'] == 'input_permutation_value' + str(edge_idx)) - - def test_fuse_mul_data_nodes_names(self): - graph = build_graph(nodes_attributes, - [('placeholder_1', 'placeholder_1_data'), - ('placeholder_1_data', 'mul_1'), - ('const_mul_1_w', 'mul_1_w'), - ('mul_1_w', 'mul_1'), - ('mul_1', 'mul_1_data'), - ('mul_1_data', 'conv_1'), - ('const_conv_1_w', 'conv_1_w'), - ('const_conv_1_b', 'conv_1_b'), - ('conv_1_w', 'conv_1'), - ('conv_1_b', 'conv_1'), - ('conv_1', 'conv_1_data'), - ('conv_1_data', 'op_output') - ], - {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])}, - 'mul_1_data': {'shape': np.array([1, 227, 227, 3])}, - 'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])}, - 'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])}, - 'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), - 'value': np.ones((11, 11, 3, 96))}, - 'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)), - 'output_channel_dim': 3, 'input_channel_dim': 2, - 'dims_number': 4}, - 'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)}, - 'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)}, - 'conv_1_data': {} - }) - - _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False) - - conv_node = Node(graph, 'conv_1') - conv_in_data_name = conv_node.in_node(1)['name'] - const_node = Node(graph, 'const_conv_1_w') - const_out_data_name = const_node.out_node(0)['name'] - - self.assertFalse(conv_in_data_name == const_out_data_name) - # Unit tests for fuse_linear_ops class FuseLinOpsTests(unittest.TestCase):