From fab8524999e6e5c387235f5d82ca3e8d69ba11bb Mon Sep 17 00:00:00 2001 From: Evgenya Stepyreva Date: Thu, 27 Aug 2020 09:45:08 +0300 Subject: [PATCH] [MO] Relax Reshape layer hardcode under MatMul (#1921) * [MO] Relax Reshape layer hardcode under MatMul * Memory fix --- .../extensions/back/MatMulNormalizer.py | 77 +++++++++++- .../extensions/back/MatMulNormalizer_test.py | 110 ++++++++++++++++++ .../mo/front/common/partial_infer/utils.py | 2 +- 3 files changed, 187 insertions(+), 2 deletions(-) create mode 100644 model-optimizer/extensions/back/MatMulNormalizer_test.py diff --git a/model-optimizer/extensions/back/MatMulNormalizer.py b/model-optimizer/extensions/back/MatMulNormalizer.py index bc5616db2ce401..d19d6b2b3a6102 100644 --- a/model-optimizer/extensions/back/MatMulNormalizer.py +++ b/model-optimizer/extensions/back/MatMulNormalizer.py @@ -18,11 +18,14 @@ from extensions.ops.transpose import Transpose from mo.back.replacement import BackReplacementPattern +from mo.front.caffe.extractors.utils import get_canonical_axis_index from mo.front.common.partial_infer.utils import int64_array from mo.front.tf.graph_utils import create_op_node_with_second_input -from mo.graph.graph import Graph +from mo.graph.graph import Graph, Node from mo.ops.const import Const +from mo.ops.shape import Shape from mo.ops.unsqueeze import Unsqueeze +from mo.utils.shape import node_to_get_shape_value_of_indices, new_shape_node_from_shape_nodes class MatMulConstTransposesExtraction(BackReplacementPattern): @@ -142,3 +145,75 @@ def replace_pattern(graph: Graph, match: dict): src = port.get_source() port.get_connection().set_source(transpose_copy.out_port(0)) src.connect(start_port) + + +class SmartReshape_HC_Reshape_MatMul(BackReplacementPattern): + """ + Relaxes hard-coded input of Reshape in such sub-graphs: + + input_1 Constant + \ / + Reshape input_2 + \ / + MatMul + | + """ + enabled = True + force_clean_up = True + + def run_after(self): + return [MatMulConstTransposesExtraction] + + def pattern(self): + return dict( + nodes=[ + ('output_shape', dict(type='Const')), + ('output_shape_d', dict()), + ('reshape', dict(type='Reshape')), + ('reshape_d', dict()), + ('other_input', dict(type=lambda t: t not in ['Reshape', 'Transpose'])), + ('other_input_d', dict()), + ('matmul', dict(type='MatMul')), + ], + edges=[ + ('output_shape', 'output_shape_d'), + ('output_shape_d', 'reshape', {'in': 1}), + ('reshape', 'reshape_d'), + ('reshape_d', 'matmul'), + ('other_input', 'other_input_d'), + ('other_input_d', 'matmul'), + ] + ) + + def replace_pattern(self, graph: Graph, match: dict): + matmul = match['matmul'] + reshape = match['reshape'] + other_input_port_idx = 0 if match['matmul'].in_port(0).get_source().node.id == match['other_input'].id else 1 + shape_source = match['matmul'].in_port(other_input_port_idx).get_source() + initial_reshape_pattern = reshape.in_port(1).data.get_value() + if len(initial_reshape_pattern) != 2: + return + + reshape_is_A_input = matmul.in_port(0).get_source().node.id == reshape.id + if reshape_is_A_input: + idx = -1 if matmul.transpose_b else -2 + else: + idx = -2 if matmul.transpose_a else -1 + idx = get_canonical_axis_index(initial_reshape_pattern, idx) + + shape_name = shape_source.node.soft_get('name', shape_source.node.id) + shape = Shape(graph, {'name': shape_name + '/Shape'}).create_node() + shape.in_port(0).connect(shape_source) + C = node_to_get_shape_value_of_indices(shape, [idx]) + N = Const(graph, {'name': shape_name + '/MinusOne', 'value': int64_array([-1])}).create_node() + + if len(initial_reshape_pattern) == 2: + if reshape_is_A_input: + reshape_pattern = [C, N] if matmul.transpose_a else [N, C] + else: + reshape_pattern = [N, C] if matmul.transpose_b else [C, N] + new_reshape_pattern = new_shape_node_from_shape_nodes(reshape_pattern) + reshape.in_port(1).get_connection().set_source(new_reshape_pattern.out_port(0)) + else: + return + diff --git a/model-optimizer/extensions/back/MatMulNormalizer_test.py b/model-optimizer/extensions/back/MatMulNormalizer_test.py new file mode 100644 index 00000000000000..8cb0505eaa224d --- /dev/null +++ b/model-optimizer/extensions/back/MatMulNormalizer_test.py @@ -0,0 +1,110 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import unittest +from argparse import Namespace + +from generator import generate, generator + +from extensions.back.MatMulNormalizer import SmartReshape_HC_Reshape_MatMul +from extensions.ops.MatMul import MatMul +from mo.front.common.partial_infer.utils import int64_array +from mo.ops.reshape import Reshape +from mo.utils.ir_engine.compare_graphs import compare_graphs +from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, const_with_data, \ + result, connect +from mo.utils.unittest.graph import regular_op_with_empty_data as op_with_empty_data + + +@generator +class SmartReshape_HC_Reshape_MatMulTest(unittest.TestCase): + @generate( + *[ + ([1, 20, 30], [30, 40], [20, -1], False, False, [-1, 30]), + ([1, 20, 30], [40, 30], [20, -1], False, True, [-1, 30]), + ([1, 30, 20], [30, 40], [-1, 20], True, False, [30, -1]), + ([1, 30, 20], [40, 30], [-1, 20], True, True, [30, -1]), + ] + ) + def test_reshape_on_the_A_input(self, + in1_shape, in2_shape, reshape_pattern, transpose_a, transpose_b, updated_pattern): + nodes = { + **regular_op_with_shaped_data('in_1', in1_shape, dict(type='Parameter', op='Parameter')), + **regular_op_with_shaped_data('in_2', in2_shape, dict(type='Parameter', op='Parameter')), + **const_with_data('dim', int64_array(reshape_pattern)), + **op_with_empty_data('reshape', + dict(type='Reshape', op='Reshape', infer=Reshape.infer, need_shape_inference=True)), + **op_with_empty_data('matmul', + dict(type='MatMul', op='MatMul', infer=MatMul.infer, need_shape_inference=True, + transpose_a=transpose_a, transpose_b=transpose_b, dim_attrs={})), + **result(), + } + edges = [ + *connect('in_1:0', '0:reshape'), + *connect('dim:0', '1:reshape'), + *connect('reshape:0', '0:matmul'), + *connect('in_2:0', '1:matmul'), + *connect('matmul:0', 'output'), + ] + graph = build_graph(nodes_attrs=nodes, edges=edges, cli=Namespace(static_shape=True)) + graph.clean_up() + SmartReshape_HC_Reshape_MatMul().find_and_replace_pattern(graph) + graph.clean_up() + + graph_ref = build_graph(nodes_attrs=nodes, edges=edges, update_attributes={ + 'dim': {'value': int64_array(updated_pattern)}, 'dim_d': {'value': int64_array(updated_pattern)}}) + graph_ref.clean_up() + + (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) + self.assertTrue(flag, resp) + + @generate(*[ + ([20, 30], [1, 30, 40], [-1, 40], False, False, [30, -1]), + ([20, 30], [1, 40, 30], [40, -1], False, True, [-1, 30]), + ([30, 20], [1, 30, 40], [-1, 40], True, False, [30, -1]), + ([30, 20], [1, 40, 30], [40, -1], True, True, [-1, 30]), + ]) + def test_reshape_on_the_B_input(self, + in1_shape, in2_shape, reshape_pattern, transpose_a, transpose_b, updated_pattern): + nodes = { + **regular_op_with_shaped_data('in_1', in1_shape, dict(type='Parameter', op='Parameter')), + **regular_op_with_shaped_data('in_2', in2_shape, dict(type='Parameter', op='Parameter')), + **const_with_data('dim', int64_array(reshape_pattern)), + **op_with_empty_data('reshape', + dict(type='Reshape', op='Reshape', infer=Reshape.infer, need_shape_inference=True)), + **op_with_empty_data('matmul', + dict(type='MatMul', op='MatMul', infer=MatMul.infer, need_shape_inference=True, + transpose_a=transpose_a, transpose_b=transpose_b, dim_attrs={})), + **result(), + } + edges = [ + *connect('in_1:0', '0:matmul'), + *connect('in_2:0', '0:reshape'), + *connect('dim:0', '1:reshape'), + *connect('reshape:0', '1:matmul'), + *connect('matmul:0', 'output'), + ] + graph = build_graph(nodes_attrs=nodes, edges=edges, cli=Namespace(static_shape=True)) + graph.clean_up() + SmartReshape_HC_Reshape_MatMul().find_and_replace_pattern(graph) + graph.clean_up() + + graph_ref = build_graph(nodes_attrs=nodes, edges=edges, update_attributes={ + 'dim': {'value': int64_array(updated_pattern)}, 'dim_d': {'value': int64_array(updated_pattern)}}) + graph_ref.clean_up() + + (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) + self.assertTrue(flag, resp) diff --git a/model-optimizer/mo/front/common/partial_infer/utils.py b/model-optimizer/mo/front/common/partial_infer/utils.py index 4917c167e59d67..cbddae0ec292b1 100644 --- a/model-optimizer/mo/front/common/partial_infer/utils.py +++ b/model-optimizer/mo/front/common/partial_infer/utils.py @@ -44,7 +44,7 @@ def assign_dims_to_weights(node, spatial, input_channel, output_channel=None, di node['spatial_dims'] = np.array(spatial, dtype=np.int64) node['input_channel_dim'] = np.array(input_channel, dtype=np.int64) node['output_channel_dim'] = np.array(output_channel, dtype=np.int64) - if 'input_channel_dim' not in node['dim_attrs']: + if 'dim_attrs' in node and 'input_channel_dim' not in node['dim_attrs']: node['dim_attrs'].append('input_channel_dim') node['dims_number'] = dims_number