Correct layout for Einsum inputs and output #6696

2 changes: 1 addition & 1 deletion model-optimizer/automation/package_BOM.txt
@@ -29,7 +29,6 @@ extensions/back/GroupedConvWeightsNormalize.py
extensions/back/insert_compatibility_l2normalization.py
extensions/back/InterpolateReshape.py
extensions/back/kaldi_remove_memory_output.py
-extensions/back/LayoutChangeForEinsum.py
extensions/back/LayoutChangeForGatherND.py
extensions/back/LeakyReLUMutation.py
extensions/back/LinearToLinearONNXReplacer.py
@@ -597,6 +596,7 @@ extensions/middle/InsertSelect.py
extensions/middle/InterpolateSequenceToInterpolate.py
extensions/middle/L2NormFusing.py
extensions/middle/LayoutChangeForConstantShapePaths.py
+extensions/middle/LayoutChangeForEinsum.py
extensions/middle/LeakyReluPattern.py
extensions/middle/LSTMRNNSequenceToTensorIterator.py
extensions/middle/MakeKaldiConstReshapable.py
model-optimizer/extensions/{back → middle}/LayoutChangeForEinsum.py
@@ -1,12 +1,14 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

+from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose, \
+    is_input_data_in_correct_layout, is_output_data_in_correct_layout
from extensions.ops.einsum import Einsum
-from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph
+from mo.middle.replacement import MiddleReplacementPattern


-class LayoutChangeForEinsum(BackReplacementPattern):
+class LayoutChangeForEinsum(MiddleReplacementPattern):
"""
    The transformation adjusts the Einsum equation to the NCHW layout.
    Subscripts for tensors of rank greater than three must be adjusted
@@ -19,7 +21,13 @@ class LayoutChangeForEinsum(BackReplacementPattern):
"""
enabled = True
force_shape_inference = True
-    graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
+    graph_condition = [lambda graph: graph.graph['layout'] == 'NHWC']

+    def run_after(self):
+        return [InsertLayoutPropagationTranspose]
+
+    def run_before(self):
+        return []

def find_and_replace_pattern(self, graph: Graph):
import extensions.middle.InsertLayoutPropagationTransposes as InsertTransposes
@@ -31,27 +39,33 @@ def find_and_replace_pattern(self, graph: Graph):
connected_in_ports = [port for port in einsum.in_ports().values() if not port.disconnected()]
num_inputs = len(connected_in_ports)

+            # check if the correct_in_data_layout / correct_out_data_layout attributes are set for inputs and the output
+            input_correct_layout_mask = []
+            for input_ind in range(num_inputs):
+                input_correct_layout_mask.append(is_input_data_in_correct_layout(einsum, input_ind))
+            output_correct_layout_mask = is_output_data_in_correct_layout(einsum, 0)

            # compute a mask of inputs of rank greater than 3 that require the original layout (NCHW)
            # due to an ellipsis covering multiple tail dimensions in the corresponding input subscript
input_ranks = [len(einsum.in_port(port_idx).data.get_shape()) for port_idx in range(num_inputs)]
output_rank = len(einsum.out_port(0).data.get_shape())
-            permuted_equation, is_inputs_permuted, is_output_permuted = Einsum.adjust_equation_with_NCHW_layout(
+            permuted_equation, is_inputs_adjusted, is_output_adjusted = Einsum.adjust_equation_with_NCHW_layout(
einsum_name,
equation,
input_ranks,
-                output_rank)
-            assert len(is_inputs_permuted) == num_inputs
+                output_rank, input_correct_layout_mask, output_correct_layout_mask)
+            assert len(is_inputs_adjusted) == num_inputs

# setup adjusted equation
einsum.equation = permuted_equation

            # insert Transpose nodes to get the NHWC layout back for inputs, as required by the specifics of the equation
for input_ind in range(num_inputs):
-                if not is_inputs_permuted[input_ind]:
+                if not is_inputs_adjusted[input_ind]:
                    # that means Einsum can only accept this input in NHWC layout,
                    # so the transpose inserted before the Einsum will convert the layout to NHWC
InsertTransposes.insert_transpose(graph, einsum.in_port(input_ind), before_input=True)
-            if not is_output_permuted:
+            if not is_output_adjusted:
                # that means Einsum can only generate output in NHWC layout,
                # so the transpose inserted after the output will convert the layout back to NCHW
InsertTransposes.insert_transpose(graph, einsum.out_port(0), before_input=False)
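
For intuition: for a subscript of rank four or more, the adjustment moves the last label to position 1, mirroring the NHWC → NCHW permutation (N, H, W, C) → (N, C, H, W). A minimal standalone sketch of this rule (plain Python, no Model Optimizer imports; the helper name is made up):

def move_channel_label(subscript: str) -> str:
    labels = list(subscript)      # e.g. ['a', 'b', 'c', 'd'] for an NHWC tensor
    labels.insert(1, labels[-1])  # the channel label moves next to the batch label
    del labels[-1]
    return ''.join(labels)

assert move_channel_label("abcd") == "adbc"  # NHWC subscript -> NCHW subscript

A subscript ending with an ellipsis that spans several tail dimensions cannot be rewritten this way, which is why such ports keep their subscripts and get explicit Transpose nodes instead.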
25 changes: 16 additions & 9 deletions model-optimizer/extensions/ops/einsum.py
@@ -137,7 +137,8 @@ def extract_subscript_labels(node_name: str, subscript: str) -> list:
return labels

@staticmethod
-    def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks: list, output_rank: int) -> (
+    def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks: list, output_rank: int,
+                                         input_correct_layout_mask: list, output_correct_layout_mask: bool) -> (
str, list, bool):
"""
In order to satisfy NCHW layout, subscripts for tensors with rank greater than three must be adjusted by moving labels
@@ -151,11 +152,13 @@ def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks:
        :param output_rank: output rank
+        :param input_correct_layout_mask: boolean mask of inputs whose layout is already correct
+        :param output_correct_layout_mask: True if the output layout is already correct
        :return: adjusted equation, boolean mask for inputs, and boolean flag if output subscript is adjusted
"""
-        is_inputs_permuted = []
+        is_inputs_adjusted = []
input_subscripts, output_subscript = Einsum.parse_equation(node_name, equation)
num_inputs = len(input_ranks)
        assert len(input_subscripts) == num_inputs, "The number of inputs must match the number " \
                                                    "of input subscripts"
+        assert len(input_correct_layout_mask) == num_inputs, "The number of inputs must match the number of " \
+                                                             "elements in the input_correct_layout_mask list"

        # permute labels in input subscripts and mark inputs for which inference in NCHW layout is acceptable
        # in case an ellipsis covers multiple dimensions at the end, the permutation is impossible
@@ -166,31 +169,35 @@ def adjust_equation_with_NCHW_layout(node_name: str, equation: str, input_ranks:
input_rank = input_ranks[input_ind]
labels = Einsum.extract_subscript_labels(node_name, input_subscript)
num_broadcasted_dims = input_rank - len(labels) + 1
-            if input_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
-                is_inputs_permuted.append(True)
+            if input_correct_layout_mask[input_ind]:
+                is_inputs_adjusted.append(True)
+            elif input_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
+                is_inputs_adjusted.append(True)
labels.insert(1, labels[-1])
del labels[-1]
else:
-                is_inputs_permuted.append(False)
+                is_inputs_adjusted.append(False)
permuted_input_subscript = ''.join(labels)
permuted_input_subscripts.append(permuted_input_subscript)

        # perform the same procedure for the output subscript as for the input subscripts
labels = Einsum.extract_subscript_labels(node_name, output_subscript)
num_broadcasted_dims = output_rank - len(labels) + 1
-        if output_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
-            is_output_permuted = True
+        if output_correct_layout_mask:
+            is_output_adjusted = True
+        elif output_rank > 3 and (labels[-1] != "..." or labels[-1] == "..." and num_broadcasted_dims == 1):
+            is_output_adjusted = True
labels.insert(1, labels[-1])
del labels[-1]
else:
-            is_output_permuted = False
+            is_output_adjusted = False
permuted_output_subscript = ''.join(labels)

        # concatenate the left and right hands of the resulting equation
left_hand = ','.join(permuted_input_subscripts)
right_hand = permuted_output_subscript
permuted_equation = left_hand + "->" + right_hand
-        return permuted_equation, is_inputs_permuted, is_output_permuted
+        return permuted_equation, is_inputs_adjusted, is_output_adjusted

@staticmethod
def infer(node: Node):
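To make the new signature concrete, a traced call (the node name is made up; the equation and ranks mirror the unit test below, with no port marked as already correct):

equation, inputs_adjusted, output_adjusted = Einsum.adjust_equation_with_NCHW_layout(
    'einsum_node',  # used only for error reporting
    "abc,bcde,bc...->ade...",
    input_ranks=[3, 4, 4],
    output_rank=5,
    input_correct_layout_mask=[False, False, False],
    output_correct_layout_mask=False)

# Only the second subscript can be rewritten by moving its last label to position 1.
# The first input has rank 3 and the third ends with an ellipsis covering two
# dimensions, so both are reported as not adjusted; the same holds for the output.
assert equation == "abc,becd,bc...->ade..."
assert inputs_adjusted == [False, True, False]
assert output_adjusted is False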
model-optimizer/unit_tests/extensions/{back → middle}/LayoutChangeForEinsum_test.py
@@ -5,7 +5,7 @@

import numpy as np

-from extensions.back.LayoutChangeForEinsum import LayoutChangeForEinsum
+from extensions.middle.LayoutChangeForEinsum import LayoutChangeForEinsum
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect
@@ -47,7 +47,7 @@ def test_layout_change_einsum(self):
# this input does not require additional transpose
# since the corresponding subscript can be adjusted
'placeholder_2_d': {'shape': np.array([3, 8, 5, 7])},
-            # [3, 5, 10, 12] - NHWC, [3, 12, 5, 10] - NCHW
+            # [3, 8, 10, 12] - NHWC, [3, 12, 8, 10] - NCHW
# the third input must be transposed to NHWC layout
            # since the ellipsis covers multiple dimensions at the end
# the corresponding subscript is not changed
@@ -60,7 +60,7 @@ def test_layout_change_einsum(self):
            # and an additional transpose to NCHW will be inserted
'einsum_d': {'shape': np.array([2, 12, 7, 8, 10])},
}, nodes_with_edges_only=True)
-        graph.graph['fw'] = 'tf'
+        graph.graph['layout'] = 'NHWC'

graph_ref = build_graph(nodes_attributes,
[*connect('placeholder_3', '0:transpose_1'),
@@ -80,3 +80,46 @@ def test_layout_change_einsum(self):
LayoutChangeForEinsum().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)

+    def test_no_adjustment_layout_einsum(self):
+        graph = build_graph(nodes_attributes,
+                            [*connect('placeholder_1', '0:einsum'),
+                             *connect('placeholder_2', '1:einsum'),
+                             *connect('placeholder_3', '2:einsum'),
+                             *connect('einsum', 'output')],
+                            {  # this input stays as is since it is of a rank equal to 3
+                                'placeholder_1_d': {'shape': np.array([2, 3, 5])},
+                                # [3, 5, 7, 8] - NHWC
+                                # this input does not require additional transpose
+                                # since the corresponding layout is correct
+                                'placeholder_2_d': {'shape': np.array([3, 5, 7, 8])},
+                                # [3, 8, 10, 12] - NHWC
+                                # this input does not require additional transpose
+                                # since the corresponding layout is correct
+                                'placeholder_3_d': {'shape': np.array([3, 8, 10, 12])},
+                                # equation is still for NHWC layout
+                                'einsum': {'equation': "abc,bcde,bc...->ade...",
+                                           'correct_in_data_layout': [0, 1, 2],
+                                           'correct_out_data_layout': [0]},
+                                # [2, 7, 8, 10, 12] - NHWC
+                                # this output does not require additional transpose
+                                # since the corresponding layout is correct
+                                'einsum_d': {'shape': np.array([2, 7, 8, 10, 12])},
+                            }, nodes_with_edges_only=True)
+        graph.graph['layout'] = 'NHWC'
+
+        graph_ref = build_graph(nodes_attributes,
+                                [*connect('placeholder_1', '0:einsum'),
+                                 *connect('placeholder_2', '1:einsum'),
+                                 *connect('placeholder_3', '2:einsum'),
+                                 *connect('einsum', 'output')],
+                                {'placeholder_1_d': {'shape': np.array([2, 3, 5])},
+                                 'placeholder_2_d': {'shape': np.array([3, 5, 7, 8])},
+                                 'placeholder_3_d': {'shape': np.array([3, 8, 10, 12])},
+                                 'einsum': {'equation': "abc,bcde,bc...->ade..."},
+                                 'einsum_d': {'shape': np.array([2, 7, 8, 10, 12])}
+                                 })
+
+        LayoutChangeForEinsum().find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
+        self.assertTrue(flag, resp)
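
For contrast with the first test, a traced call through the early-exit path that this new test exercises (illustrative node name; masks mark every port as already laid out correctly):

equation, inputs_adjusted, output_adjusted = Einsum.adjust_equation_with_NCHW_layout(
    'einsum_node',
    "abc,bcde,bc...->ade...",
    input_ranks=[3, 4, 4],
    output_rank=5,
    input_correct_layout_mask=[True, True, True],
    output_correct_layout_mask=True)

# Every subscript is left untouched and all flags report "adjusted",
# so find_and_replace_pattern inserts no Transpose nodes at all.
assert equation == "abc,bcde,bc...->ade..."
assert inputs_adjusted == [True, True, True]
assert output_adjusted is True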