diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt
index 939f75b685d292..e4dc75cb2142e2 100644
--- a/model-optimizer/automation/package_BOM.txt
+++ b/model-optimizer/automation/package_BOM.txt
@@ -29,7 +29,6 @@ extensions/back/GroupedConvWeightsNormalize.py
 extensions/back/insert_compatibility_l2normalization.py
 extensions/back/InterpolateReshape.py
 extensions/back/kaldi_remove_memory_output.py
-extensions/back/LayoutChangeForEinsum.py
 extensions/back/LayoutChangeForGatherND.py
 extensions/back/LeakyReLUMutation.py
 extensions/back/LinearToLinearONNXReplacer.py
@@ -597,6 +596,7 @@ extensions/middle/InsertSelect.py
 extensions/middle/InterpolateSequenceToInterpolate.py
 extensions/middle/L2NormFusing.py
 extensions/middle/LayoutChangeForConstantShapePaths.py
+extensions/middle/LayoutChangeForEinsum.py
 extensions/middle/LeakyReluPattern.py
 extensions/middle/LSTMRNNSequenceToTensorIterator.py
 extensions/middle/MakeKaldiConstReshapable.py
diff --git a/model-optimizer/extensions/back/LayoutChangeForEinsum.py b/model-optimizer/extensions/middle/LayoutChangeForEinsum.py
similarity index 58%
rename from model-optimizer/extensions/back/LayoutChangeForEinsum.py
rename to model-optimizer/extensions/middle/LayoutChangeForEinsum.py
index f45bff54b931c5..aab33d3f0a762b 100644
--- a/model-optimizer/extensions/back/LayoutChangeForEinsum.py
+++ b/model-optimizer/extensions/middle/LayoutChangeForEinsum.py
@@ -1,12 +1,14 @@
 # Copyright (C) 2018-2021 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+from extensions.middle.InsertLayoutPropagationTransposes import is_input_data_in_correct_layout, \
+    is_output_data_in_correct_layout
 from extensions.ops.einsum import Einsum
-from mo.back.replacement import BackReplacementPattern
 from mo.graph.graph import Graph
+from mo.middle.replacement import MiddleReplacementPattern
 
 
-class LayoutChangeForEinsum(BackReplacementPattern):
+class LayoutChangeForEinsum(MiddleReplacementPattern):
     """
     The transformation adjusts Einsum equation to NCHW layout.
     Subscripts for tensor of rank greater than three must be adjusted
@@ -19,7 +21,15 @@ class LayoutChangeForEinsum(BackReplacementPattern):
     """
     enabled = True
     force_shape_inference = True
-    graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
+    graph_condition = [lambda graph: graph.graph['layout'] == 'NHWC']
+
+    def run_after(self):
+        from extensions.middle.MarkSubgraphsWithCorrectLayout import MarkSubGraphsWithCorrectLayout
+        return [MarkSubGraphsWithCorrectLayout]
+
+    def run_before(self):
+        from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose
+        return [InsertLayoutPropagationTranspose]
 
     def find_and_replace_pattern(self, graph: Graph):
         import extensions.middle.InsertLayoutPropagationTransposes as InsertTransposes
@@ -31,27 +41,35 @@ def find_and_replace_pattern(self, graph: Graph):
             connected_in_ports = [port for port in einsum.in_ports().values() if not port.disconnected()]
             num_inputs = len(connected_in_ports)
 
-            # compute a mask of inputs of rank greater than 3 that are required original layout (NCHW)
-            # due to presence of ellipsis covering multiple tail dimensions in the corresponding input subscript
+            # check if the correct_data_layout attribute is set for the inputs and output;
+            # this attribute can be set by the MarkSubGraphsWithCorrectLayout transformation,
+            # for example, when Einsum is located near a MatMul operation in the graph
+            input_correct_layout_mask = []
+            for input_ind in range(num_inputs):
+                input_correct_layout_mask.append(is_input_data_in_correct_layout(einsum, input_ind))
+            is_output_layout_correct = is_output_data_in_correct_layout(einsum, 0)
+
+            # compute a mask of which inputs and output are adjusted to the required layout;
+            # for those that are not adjusted, a Transpose must be inserted
             input_ranks = [len(einsum.in_port(port_idx).data.get_shape()) for port_idx in range(num_inputs)]
             output_rank = len(einsum.out_port(0).data.get_shape())
-            permuted_equation, is_inputs_permuted, is_output_permuted = Einsum.adjust_equation_with_NCHW_layout(
+            permuted_equation, are_inputs_adjusted, is_output_adjusted = Einsum.adjust_equation_with_NCHW_layout(
                 einsum_name,
                 equation,
                 input_ranks,
-                output_rank)
-            assert len(is_inputs_permuted) == num_inputs
+                output_rank, input_correct_layout_mask, is_output_layout_correct)
+            assert len(are_inputs_adjusted) == num_inputs
 
             # setup adjusted equation
             einsum.equation = permuted_equation
 
             # insert Transpose node to get NHWC layout back (for inputs) that is required due to specifics of equation
             for input_ind in range(num_inputs):
-                if not is_inputs_permuted[input_ind]:
+                if not are_inputs_adjusted[input_ind]:
                     # that means Einsum can only accept input in NHWC layout
                     # so the inserted transpose before the Einsum will convert the layout to NHWC
                     InsertTransposes.insert_transpose(graph, einsum.in_port(input_ind), before_input=True)
 
-            if not is_output_permuted:
+            if not is_output_adjusted:
                 # that means Einsum can only generate output in NHWC layout
                 # so the inserted transpose followed after the output will convert the layout back into NCHW layout
                 InsertTransposes.insert_transpose(graph, einsum.out_port(0), before_input=False)
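
Note: the per-input decision made by find_and_replace_pattern above can be summarized standalone. Below is a minimal sketch in plain Python (no Model Optimizer imports; the helper name plan_input_action is hypothetical): an input already marked as being in the correct layout is left alone, an input whose subscript can absorb the permutation has the equation adjusted, and anything else gets insert_transpose called for it.

def plan_input_action(labels, rank, already_correct):
    # labels: subscript labels of one Einsum input, e.g. ['b', 'c', '...']
    num_broadcasted_dims = rank - len(labels) + 1
    if already_correct:
        # marked by MarkSubGraphsWithCorrectLayout: subscript and data kept as-is
        return 'keep'
    if rank > 3 and (labels[-1] != '...' or num_broadcasted_dims == 1):
        # the channel label can be moved next to the batch label, so the
        # equation itself absorbs the NHWC -> NCHW permutation
        return 'adjust_subscript'
    # the pass calls insert_transpose for this input; judging by the first
    # unit test, a Transpose only materializes for tensors of rank 4 and above
    return 'transpose_to_nhwc'

print(plan_input_action(list('abcd'), 4, already_correct=False))       # adjust_subscript
print(plan_input_action(['b', 'c', '...'], 4, already_correct=False))  # transpose_to_nhwc
print(plan_input_action(['a', 'b', 'c'], 3, already_correct=True))     # keep
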
diff --git a/model-optimizer/extensions/ops/einsum.py b/model-optimizer/extensions/ops/einsum.py
index a30bb96bdfcb89..a37c60def7c76b 100644
--- a/model-optimizer/extensions/ops/einsum.py
+++ b/model-optimizer/extensions/ops/einsum.py
@@ -137,7 +137,8 @@ def extract_subscript_labels(node_name: str, subscript: str) -> list:
         return labels
 
     @staticmethod
-    def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks: list, output_rank: int) -> (
+    def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks: list, output_rank: int,
+                                         input_correct_layout_mask: list, output_correct_layout_mask: bool) -> (
             str, list, bool):
         """
         In order to satisfy NCHW layout, subscripts for tensors with rank greater than three must be adjusted by moving labels
@@ -151,11 +152,13 @@ def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks:
         :param output_rank: output rank
         :return: adjusted equation, boolean mask for inputs, and boolean flag if output subscript is adjusted
         """
-        is_inputs_permuted = []
+        is_inputs_adjusted = []
         input_subscripts, output_subscript = Einsum.parse_equation(node_name, equation)
         num_inputs = len(input_ranks)
         assert len(input_subscripts) == num_inputs, "The number of inputs must match a number " \
                                                     "of input subscripts"
+        assert len(input_correct_layout_mask) == num_inputs, "The number of inputs must match the number " \
+                                                             "of elements in input_correct_layout_mask list"
 
         # permute labels in input subscripts and mark inputs for which inference in NCHW layout is acceptable
         # in case ellipsis covering multiple dimensions in the end, the permutation is impossible
@@ -166,31 +169,35 @@ def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks:
             input_rank = input_ranks[input_ind]
             labels = Einsum.extract_subscript_labels(node_name, input_subscript)
             num_broadcasted_dims = input_rank - len(labels) + 1
-            if input_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
-                is_inputs_permuted.append(True)
+            if input_correct_layout_mask[input_ind]:
+                is_inputs_adjusted.append(True)
+            elif input_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
+                is_inputs_adjusted.append(True)
                 labels.insert(1, labels[-1])
                 del labels[-1]
             else:
-                is_inputs_permuted.append(False)
+                is_inputs_adjusted.append(False)
             permuted_input_subscript = ''.join(labels)
             permuted_input_subscripts.append(permuted_input_subscript)
 
         # perform the same procedure for the output subscript as for the inputs subscripts
         labels = Einsum.extract_subscript_labels(node_name, output_subscript)
         num_broadcasted_dims = output_rank - len(labels) + 1
-        if output_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
-            is_output_permuted = True
+        if output_correct_layout_mask:
+            is_output_adjusted = True
+        elif output_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
+            is_output_adjusted = True
             labels.insert(1, labels[-1])
             del labels[-1]
         else:
-            is_output_permuted = False
+            is_output_adjusted = False
         permuted_output_subscript = ''.join(labels)
 
         # concatenate the left and right hands of the resulted equation
         left_hand = ','.join(permuted_input_subscripts)
         right_hand = permuted_output_subscript
         permuted_equation = left_hand + "->" + right_hand
-        return permuted_equation, is_inputs_permuted, is_output_permuted
+        return permuted_equation, is_inputs_adjusted, is_output_adjusted
 
     @staticmethod
     def infer(node: Node):
diff --git a/model-optimizer/unit_tests/extensions/back/LayoutChangeForEinsum_test.py b/model-optimizer/unit_tests/extensions/middle/LayoutChangeForEinsum_test.py
similarity index 59%
rename from model-optimizer/unit_tests/extensions/back/LayoutChangeForEinsum_test.py
rename to model-optimizer/unit_tests/extensions/middle/LayoutChangeForEinsum_test.py
index 45e0f2badabf20..aa908ab04f7304 100644
--- a/model-optimizer/unit_tests/extensions/back/LayoutChangeForEinsum_test.py
+++ b/model-optimizer/unit_tests/extensions/middle/LayoutChangeForEinsum_test.py
@@ -5,7 +5,7 @@
 
 import numpy as np
 
-from extensions.back.LayoutChangeForEinsum import LayoutChangeForEinsum
+from extensions.middle.LayoutChangeForEinsum import LayoutChangeForEinsum
 from mo.front.common.partial_infer.utils import int64_array
 from mo.utils.ir_engine.compare_graphs import compare_graphs
 from unit_tests.utils.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect
@@ -47,7 +47,7 @@ def test_layout_change_einsum(self):
                             # this input does not require additional transpose
                             # since the corresponding subscript can be adjusted
                             'placeholder_2_d': {'shape': np.array([3, 8, 5, 7])},
-                            # [3, 5, 10, 12] - NHWC, [3, 12, 5, 10] - NCHW
+                            # [3, 8, 10, 12] - NHWC, [3, 12, 8, 10] - NCHW
                             # the third input must be transposed to NHWC layout
                             # since ellipsis covers multiple dimensions in the end
                             # the corresponding subscript is not changed
@@ -60,7 +60,7 @@ def test_layout_change_einsum(self):
                             # and additional transpose to NCHW will be inserted
                             'einsum_d': {'shape': np.array([2, 12, 7, 8, 10])},
                             }, nodes_with_edges_only=True)
-        graph.graph['fw'] = 'tf'
+        graph.graph['layout'] = 'NHWC'
 
         graph_ref = build_graph(nodes_attributes,
                                 [*connect('placeholder_3', '0:transpose_1'),
@@ -80,3 +80,46 @@ def test_layout_change_einsum(self):
         LayoutChangeForEinsum().find_and_replace_pattern(graph)
         (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
         self.assertTrue(flag, resp)
+
+    def test_no_adjustment_layout_einsum(self):
+        graph = build_graph(nodes_attributes,
+                            [*connect('placeholder_1', '0:einsum'),
+                             *connect('placeholder_2', '1:einsum'),
+                             *connect('placeholder_3', '2:einsum'),
+                             *connect('einsum', 'output')],
+                            {  # this input stays as is since it is of a rank equal to 3
+                                'placeholder_1_d': {'shape': np.array([2, 3, 5])},
+                                # [3, 5, 7, 8] - NHWC
+                                # this input does not require additional transpose
+                                # since the corresponding layout is correct
+                                'placeholder_2_d': {'shape': np.array([3, 5, 7, 8])},
+                                # [3, 8, 10, 12] - NHWC
+                                # this input does not require additional transpose
+                                # since the corresponding layout is correct
+                                'placeholder_3_d': {'shape': np.array([3, 8, 10, 12])},
+                                # equation is still for NHWC layout
+                                'einsum': {'equation': "abc,bcde,bc...->ade...",
+                                           'correct_in_data_layout': [0, 1, 2],
+                                           'correct_out_data_layout': [0]},
+                                # [2, 7, 8, 10, 12] - NHWC
+                                # this output does not require additional transpose
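
For illustration, the label move performed by adjust_equation_with_NCHW_layout is the same list manipulation for inputs and output: the innermost (channel) label of an NHWC subscript is inserted right after the batch label to express NCHW order. A standalone example (the helper name permute_subscript is hypothetical, and it handles single-character labels only; the real code goes through extract_subscript_labels to also cope with ellipsis):

def permute_subscript(subscript):
    labels = list(subscript)
    labels.insert(1, labels[-1])  # move the innermost label next to the batch label
    del labels[-1]                # and drop it from the tail
    return ''.join(labels)

assert permute_subscript('abcd') == 'adbc'    # rank-4: N,H,W,C -> N,C,H,W
assert permute_subscript('abcde') == 'aebcd'  # rank-5: N,D,H,W,C -> N,C,D,H,W
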
+                                # since the corresponding layout is correct
+                                'einsum_d': {'shape': np.array([2, 7, 8, 10, 12])},
+                            }, nodes_with_edges_only=True)
+        graph.graph['layout'] = 'NHWC'
+
+        graph_ref = build_graph(nodes_attributes,
+                                [*connect('placeholder_1', '0:einsum'),
+                                 *connect('placeholder_2', '1:einsum'),
+                                 *connect('placeholder_3', '2:einsum'),
+                                 *connect('einsum', 'output')],
+                                {'placeholder_1_d': {'shape': np.array([2, 3, 5])},
+                                 'placeholder_2_d': {'shape': np.array([3, 5, 7, 8])},
+                                 'placeholder_3_d': {'shape': np.array([3, 8, 10, 12])},
+                                 'einsum': {'equation': "abc,bcde,bc...->ade..."},
+                                 'einsum_d': {'shape': np.array([2, 7, 8, 10, 12])}
+                                 })
+
+        LayoutChangeForEinsum().find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
+        self.assertTrue(flag, resp)
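
As a sanity check independent of the Model Optimizer graph machinery, the arithmetic behind the rewritten equation can be verified with numpy: evaluating the adjusted subscripts on NCHW-transposed data must reproduce the original NHWC result. A minimal sketch, with shapes and equation made up for the example:

import numpy as np

rng = np.random.default_rng(0)
x_nhwc = rng.random((2, 5, 7, 3))          # N, H, W, C
w = rng.random((7, 3))                     # rank 2, so its subscript 'cd' stays as-is

ref = np.einsum('abcd,cd->ab', x_nhwc, w)  # original NHWC equation

x_nchw = x_nhwc.transpose(0, 3, 1, 2)      # the same data in N, C, H, W order
out = np.einsum('adbc,cd->ab', x_nchw, w)  # adjusted equation: 'abcd' -> 'adbc'

assert np.allclose(ref, out)
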