Skip to content

Commit

Permalink
Correct layout for Einsum inputs and output (openvinotoolkit#6696)
Browse files Browse the repository at this point in the history
* Fix recovery of output subscript in Einsum implicit mode

Signed-off-by: Roman Kazantsev <[email protected]>

* Fix code style

Signed-off-by: Roman Kazantsev <[email protected]>

* Correct layout adjustment for Einsum inputs and output

Signed-off-by: Roman Kazantsev <[email protected]>

* Correct a comment in the unit-test

Signed-off-by: Roman Kazantsev <[email protected]>

* Setup correct transformation dependencies for LayoutChangeForEinsum

Signed-off-by: Roman Kazantsev <[email protected]>
  • Loading branch information
rkazants authored and rnugmanx committed Aug 26, 2021
1 parent b0300c9 commit af95fa8
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 23 deletions.
2 changes: 1 addition & 1 deletion model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ extensions/back/GroupedConvWeightsNormalize.py
extensions/back/insert_compatibility_l2normalization.py
extensions/back/InterpolateReshape.py
extensions/back/kaldi_remove_memory_output.py
extensions/back/LayoutChangeForEinsum.py
extensions/back/LayoutChangeForGatherND.py
extensions/back/LeakyReLUMutation.py
extensions/back/LinearToLinearONNXReplacer.py
Expand Down Expand Up @@ -597,6 +596,7 @@ extensions/middle/InsertSelect.py
extensions/middle/InterpolateSequenceToInterpolate.py
extensions/middle/L2NormFusing.py
extensions/middle/LayoutChangeForConstantShapePaths.py
extensions/middle/LayoutChangeForEinsum.py
extensions/middle/LeakyReluPattern.py
extensions/middle/LSTMRNNSequenceToTensorIterator.py
extensions/middle/MakeKaldiConstReshapable.py
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from extensions.middle.InsertLayoutPropagationTransposes import is_input_data_in_correct_layout, \
is_output_data_in_correct_layout
from extensions.ops.einsum import Einsum
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern


class LayoutChangeForEinsum(BackReplacementPattern):
class LayoutChangeForEinsum(MiddleReplacementPattern):
"""
The transformation adjusts Einsum equation to NCHW layout.
Subscripts for tensor of rank greater than three must be adjusted
Expand All @@ -19,7 +21,15 @@ class LayoutChangeForEinsum(BackReplacementPattern):
"""
enabled = True
force_shape_inference = True
graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
graph_condition = [lambda graph: graph.graph['layout'] == 'NHWC']

def run_after(self):
from extensions.middle.MarkSubgraphsWithCorrectLayout import MarkSubGraphsWithCorrectLayout
return [MarkSubGraphsWithCorrectLayout]

def run_before(self):
from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose
return [InsertLayoutPropagationTranspose]

def find_and_replace_pattern(self, graph: Graph):
import extensions.middle.InsertLayoutPropagationTransposes as InsertTransposes
Expand All @@ -31,27 +41,35 @@ def find_and_replace_pattern(self, graph: Graph):
connected_in_ports = [port for port in einsum.in_ports().values() if not port.disconnected()]
num_inputs = len(connected_in_ports)

# compute a mask of inputs of rank greater than 3 that are required original layout (NCHW)
# due to presence of ellipsis covering multiple tail dimensions in the corresponding input subscript
# check if correct_data_layout attribute is set for inputs and output
# this attribute can be set up within MarkSubgraphWithCorrectLayout transformation
# for example, when Einsum is located near to MatMul operation in a graph
input_correct_layout_mask = []
for input_ind in range(num_inputs):
input_correct_layout_mask.append(is_input_data_in_correct_layout(einsum, input_ind))
is_output_layout_correct = is_output_data_in_correct_layout(einsum, 0)

# compute a mask of which inputs/output are adjusted to the required layout
# if they are not adjusted, it means to require transpose
input_ranks = [len(einsum.in_port(port_idx).data.get_shape()) for port_idx in range(num_inputs)]
output_rank = len(einsum.out_port(0).data.get_shape())
permuted_equation, is_inputs_permuted, is_output_permuted = Einsum.adjust_equation_with_NCHW_layout(
permuted_equation, are_inputs_adjusted, is_output_adjusted = Einsum.adjust_equation_with_NCHW_layout(
einsum_name,
equation,
input_ranks,
output_rank)
assert len(is_inputs_permuted) == num_inputs
output_rank, input_correct_layout_mask, is_output_layout_correct)
assert len(are_inputs_adjusted) == num_inputs

# setup adjusted equation
einsum.equation = permuted_equation

# insert Transpose node to get NHWC layout back (for inputs) that is required due to specifics of equation
for input_ind in range(num_inputs):
if not is_inputs_permuted[input_ind]:
if not are_inputs_adjusted[input_ind]:
# that means Einsum can only accept input in NHWC layout
# so the inserted transpose before the Einsum will convert the layout to NHWC
InsertTransposes.insert_transpose(graph, einsum.in_port(input_ind), before_input=True)
if not is_output_permuted:
if not is_output_adjusted:
# that means Einsum can only generate output in NHWC layout
# so the inserted transpose followed after the output will convert the layout back into NCHW layout
InsertTransposes.insert_transpose(graph, einsum.out_port(0), before_input=False)
25 changes: 16 additions & 9 deletions model-optimizer/extensions/ops/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def extract_subscript_labels(node_name: str, subscript: str) -> list:
return labels

@staticmethod
def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks: list, output_rank: int) -> (
def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks: list, output_rank: int,
input_correct_layout_mask: list, output_correct_layout_mask: bool) -> (
str, list, bool):
"""
In order to satisfy NCHW layout, subscripts for tensors with rank greater than three must be adjusted by moving labels
Expand All @@ -151,11 +152,13 @@ def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks:
:param output_rank: output rank
:return: adjusted equation, boolean mask for inputs, and boolean flag if output subscript is adjusted
"""
is_inputs_permuted = []
is_inputs_adjusted = []
input_subscripts, output_subscript = Einsum.parse_equation(node_name, equation)
num_inputs = len(input_ranks)
assert len(input_subscripts) == num_inputs, "The number of inputs must match a number " \
"of input subscripts"
assert len(input_correct_layout_mask) == num_inputs, "The number of inputs must match a number " \
"elements in input_correct_layout_mask list"

# permute labels in input subscripts and mark inputs for which inference in NCHW layout is acceptable
# in case ellipsis covering multiple dimensions in the end, the permutation is impossible
Expand All @@ -166,31 +169,35 @@ def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks:
input_rank = input_ranks[input_ind]
labels = Einsum.extract_subscript_labels(node_name, input_subscript)
num_broadcasted_dims = input_rank - len(labels) + 1
if input_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
is_inputs_permuted.append(True)
if input_correct_layout_mask[input_ind]:
is_inputs_adjusted.append(True)
elif input_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
is_inputs_adjusted.append(True)
labels.insert(1, labels[-1])
del labels[-1]
else:
is_inputs_permuted.append(False)
is_inputs_adjusted.append(False)
permuted_input_subscript = ''.join(labels)
permuted_input_subscripts.append(permuted_input_subscript)

# perform the same procedure for the output subscript as for the inputs subscripts
labels = Einsum.extract_subscript_labels(node_name, output_subscript)
num_broadcasted_dims = output_rank - len(labels) + 1
if output_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
is_output_permuted = True
if output_correct_layout_mask:
is_output_adjusted = True
elif output_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
is_output_adjusted = True
labels.insert(1, labels[-1])
del labels[-1]
else:
is_output_permuted = False
is_output_adjusted = False
permuted_output_subscript = ''.join(labels)

# concatenate the left and right hands of the resulted equation
left_hand = ','.join(permuted_input_subscripts)
right_hand = permuted_output_subscript
permuted_equation = left_hand + "->" + right_hand
return permuted_equation, is_inputs_permuted, is_output_permuted
return permuted_equation, is_inputs_adjusted, is_output_adjusted

@staticmethod
def infer(node: Node):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from extensions.back.LayoutChangeForEinsum import LayoutChangeForEinsum
from extensions.middle.LayoutChangeForEinsum import LayoutChangeForEinsum
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect
Expand Down Expand Up @@ -47,7 +47,7 @@ def test_layout_change_einsum(self):
# this input does not require additional transpose
# since the corresponding subscript can be adjusted
'placeholder_2_d': {'shape': np.array([3, 8, 5, 7])},
# [3, 5, 10, 12] - NHWC, [3, 12, 5, 10] - NCHW
# [3, 8, 10, 12] - NHWC, [3, 12, 8, 10] - NCHW
# the third input must be transposed to NHWC layout
# since ellipsis covers multiple dimensions in the end
# the corresponding subscript is not changed
Expand All @@ -60,7 +60,7 @@ def test_layout_change_einsum(self):
# and additional transpose to NCHW will be inserted
'einsum_d': {'shape': np.array([2, 12, 7, 8, 10])},
}, nodes_with_edges_only=True)
graph.graph['fw'] = 'tf'
graph.graph['layout'] = 'NHWC'

graph_ref = build_graph(nodes_attributes,
[*connect('placeholder_3', '0:transpose_1'),
Expand All @@ -80,3 +80,46 @@ def test_layout_change_einsum(self):
LayoutChangeForEinsum().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)

def test_no_adjustment_layout_einsum(self):
graph = build_graph(nodes_attributes,
[*connect('placeholder_1', '0:einsum'),
*connect('placeholder_2', '1:einsum'),
*connect('placeholder_3', '2:einsum'),
*connect('einsum', 'output')],
{ # this input stays as is since it is of a rank equal to 3
'placeholder_1_d': {'shape': np.array([2, 3, 5])},
# [3, 5, 7, 8] - NHWC
# this input does not require additional transpose
# since the corresponding layout is correct
'placeholder_2_d': {'shape': np.array([3, 5, 7, 8])},
# [3, 8, 10, 12] - NHWC
# this input does not require additional transpose
# since the corresponding layout is correct
'placeholder_3_d': {'shape': np.array([3, 8, 10, 12])},
# equation is still for NHWC layout
'einsum': {'equation': "abc,bcde,bc...->ade...",
'correct_in_data_layout': [0, 1, 2],
'correct_out_data_layout': [0]},
# [2, 7, 8, 10, 12] - NHWC
# this output does not require additional transpose
# since the corresponding layout is correct
'einsum_d': {'shape': np.array([2, 7, 8, 10, 12])},
}, nodes_with_edges_only=True)
graph.graph['layout'] = 'NHWC'

graph_ref = build_graph(nodes_attributes,
[*connect('placeholder_1', '0:einsum'),
*connect('placeholder_2', '1:einsum'),
*connect('placeholder_3', '2:einsum'),
*connect('einsum', 'output')],
{'placeholder_1_d': {'shape': np.array([2, 3, 5])},
'placeholder_2_d': {'shape': np.array([3, 5, 7, 8])},
'placeholder_3_d': {'shape': np.array([3, 8, 10, 12])},
'einsum': {'equation': "abc,bcde,bc...->ade..."},
'einsum_d': {'shape': np.array([2, 7, 8, 10, 12])}
})

LayoutChangeForEinsum().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)

0 comments on commit af95fa8

Please sign in to comment.