Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MO] StridedSlice improvements #4139

Merged
merged 34 commits into from
Feb 16, 2021
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
2bce1da
fix ss
pavel-esir Jan 21, 2021
eabf5f3
successfully converted
pavel-esir Jan 27, 2021
166976e
successfully run moved infer and normalizer unit-tests
pavel-esir Jan 28, 2021
04df120
successfully rewritten StridedSlice infer unittests
pavel-esir Jan 29, 2021
1486d87
int64 array
pavel-esir Jan 30, 2021
dde8937
Successfully converter crash-when-loading, xj_feauture and toy nets (…
pavel-esir Feb 1, 2021
ab3b7b2
successfully moved PermuteAttrs to general mechanism
pavel-esir Feb 1, 2021
e9aa2d7
successfully converted xj_feauture and crash when loading with the ne…
pavel-esir Feb 2, 2021
b7b5dd0
fixed get_shape_from_slice and moved to common utils
pavel-esir Feb 2, 2021
e79d72f
fixed extending masks and some other
pavel-esir Feb 3, 2021
003efb0
some refactoring
pavel-esir Feb 3, 2021
25da477
fixed extending masks in extractor, fixed licence year and some other…
pavel-esir Feb 3, 2021
5cf59a8
corrected a couple of unittests
pavel-esir Feb 3, 2021
37fc9c5
fox permute for 5 rank slice and 4 rank inputs/
pavel-esir Feb 3, 2021
2e01ad4
WIP
pavel-esir Feb 4, 2021
9fd9b2a
Added comments
pavel-esir Feb 5, 2021
bb2fdb1
fixed StridedSlice in ProposalMutation.py
pavel-esir Feb 5, 2021
44e6087
rechecked shape_infer unittests added some new cases
pavel-esir Feb 9, 2021
020bd68
added shape_infer unit-tests after StridedSliceNormalizer pass and Pe…
pavel-esir Feb 9, 2021
f4f1406
corrected unittests
pavel-esir Feb 9, 2021
f3aa124
Applied review comments
pavel-esir Feb 9, 2021
c95280d
general permutations for inputs implemented, corrected ellipsis unrol…
pavel-esir Feb 11, 2021
cc365f4
removed code duplication in infer and normalizer, moved 'slices' attr…
pavel-esir Feb 11, 2021
e88bdcb
removed some code duplication and other minor improvements
pavel-esir Feb 11, 2021
b755a41
Added tests
pavel-esir Feb 11, 2021
4425c56
Merge branch 'master' into strided_slice_fix
pavel-esir Feb 11, 2021
376378a
minor corrections
pavel-esir Feb 11, 2021
b4496de
wider range of unittests added (froze the number)
pavel-esir Feb 15, 2021
d59be03
review comments applied
pavel-esir Feb 15, 2021
2115e73
Merge remote-tracking branch 'upstream/master' into strided_slice_fix
pavel-esir Feb 15, 2021
d01b0e2
enabled skipped unit-test
pavel-esir Feb 15, 2021
73046d7
comment corrections
pavel-esir Feb 15, 2021
a8da55a
applied review comments: changed op -> type, added some asserts, corr…
pavel-esir Feb 15, 2021
b359de5
sorted inputs, updated Supported_Frameworks_Layers.md, some minor
pavel-esir Feb 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ extensions/middle/SliceLikeToStridedSlice.py
extensions/middle/sparse_reshape.py
extensions/middle/split_tdnn_memoryoffset.py
extensions/middle/SplitConcatPairToInterpolate.py
extensions/middle/StridedSliceNormalizer.py
extensions/middle/SwapAxesMiddleReplacer.py
extensions/middle/TensorIterator_utils.py
extensions/middle/TensorIteratorBackEdge.py
Expand Down Expand Up @@ -800,7 +801,6 @@ mo/front/common/partial_infer/multi_box_prior.py
mo/front/common/partial_infer/random_uniform.py
mo/front/common/partial_infer/reshape.py
mo/front/common/partial_infer/roipooling.py
mo/front/common/partial_infer/slice.py
mo/front/common/partial_infer/utils.py
mo/front/common/register_custom_ops.py
mo/front/common/replacement.py
Expand Down
10 changes: 6 additions & 4 deletions model-optimizer/extensions/back/CropToStridedSlice.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -64,8 +64,10 @@ def replace_pattern(self, graph: Graph, match: [str, Node]):
end_mask = axis_mask.copy()

ss = StridedSlice(graph, {'name': node.soft_get('name', node.id) + '/strided_slice', 'begin_mask': begin_mask,
'end_mask': end_mask, 'new_axis_mask': np.array([0]),
'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
'end_mask': end_mask,
'new_axis_mask': np.zeros(len(end_mask)),
'shrink_axis_mask': np.zeros(len(end_mask)),
'ellipsis_mask': np.zeros(len(end_mask))}).create_node()

if len(node.in_nodes()) == 2 and node.has_valid('offset'):
# Crop Type 1
Expand Down Expand Up @@ -112,7 +114,7 @@ def replace_pattern(self, graph: Graph, match: [str, Node]):
source = node.in_port(0).get_connection().get_source()

stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64),
'name': ss.name + '/stride'}).create_node()
'name': ss.name + '/stride'}).create_node()

source.connect(ss.in_port(0))
begin.out_port(0).connect(ss.in_port(1))
Expand Down
6 changes: 3 additions & 3 deletions model-optimizer/extensions/back/ProposalMutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def replace_pattern(graph: Graph, match: dict):
{'name': 'cropped_im_info',
'begin_mask': int64_array([1, 1]),
'end_mask': int64_array([1, 1]),
'new_axis_mask': int64_array([0]),
'shrink_axis_mask': int64_array([0]),
'ellipsis_mask': int64_array([0]),
'new_axis_mask': int64_array([0, 0]),
'shrink_axis_mask': int64_array([0, 0]),
'ellipsis_mask': int64_array([0, 0]),
'override_output_shape': True,
})

Expand Down
29 changes: 10 additions & 19 deletions model-optimizer/extensions/back/StridedSliceMasksNormalizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -14,32 +14,23 @@
limitations under the License.
"""

from extensions.back.ConvolutionNormalizer import DeconvolutionNormalizer
from extensions.back.CropToStridedSlice import CropToStridedSlice
from mo.back.replacement import BackReplacementPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.graph.graph import Graph


class StridedSliceMasksNormalizer(BackReplacementPattern):
enabled = True
force_clean_up = True

def run_after(self):
from extensions.back.ConvolutionNormalizer import DeconvolutionNormalizer
from extensions.back.CropToStridedSlice import CropToStridedSlice
return [CropToStridedSlice, DeconvolutionNormalizer]

@staticmethod
def pattern():
return dict(
nodes=[
('strided_slice', dict(type='StridedSlice'))
],
edges=[]
)

def replace_pattern(self, graph: Graph, match: [str, Node]):
node = match['strided_slice']
assert node.has_valid('begin_mask')
assert node.has_valid('end_mask')
node.begin_mask = int64_array([1 - i for i in node.begin_mask])
node.end_mask = int64_array([1 - i for i in node.end_mask])
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(type='StridedSlice'):
assert node.has_valid('begin_mask')
assert node.has_valid('end_mask')
node.begin_mask = int64_array([1 - i for i in node.begin_mask])
node.end_mask = int64_array([1 - i for i in node.end_mask])
7 changes: 3 additions & 4 deletions model-optimizer/extensions/middle/ApplyPermutations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -133,15 +133,14 @@ def permute_input_data(graph: Graph):
input_permutations = [(in_port, edge_attrs['input_permutation']) for in_port, edge_attrs in
node.in_edges().items() if edge_attrs.get('input_permutation') is not None]
for in_port, input_perm in input_permutations:
permutation, port_info = input_perm
permutation, port_info, check_shape = input_perm
direction, port = port_info.split(':')
port = int(port)
port_to_check = node.in_port(port) if direction == 'input' else node.out_port(port)
permutation_data_node = get_node_with_permutation(node, port_info)

if permutation_data_node.has_and_set('permutation') and \
not is_input_data_in_correct_layout(node, in_port) and \
len(port_to_check.data.get_shape()) >= 4:
not is_input_data_in_correct_layout(node, in_port) and check_shape(port_to_check):
permutation(node, port_info, in_port)
if node.has_and_set('need_shape_inference'):
node.infer(node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class ConvertGroupedStridedSlice(MiddleReplacementPattern):
enabled = True

def run_after(self):
return [ConvertSlice]
from extensions.middle.StridedSliceNormalizer import StridedSliceNormalizer
return [ConvertSlice, StridedSliceNormalizer]

def run_before(self):
from extensions.middle.pass_separator import MiddleFinish
Expand Down
251 changes: 251 additions & 0 deletions model-optimizer/extensions/middle/StridedSliceNormalizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
"""
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np

from extensions.ops.split import VariadicSplit
from mo.front.common.partial_infer.utils import int64_array
from mo.ops.strided_slice import StridedSlice
pavel-esir marked this conversation as resolved.
Show resolved Hide resolved
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, Node
from mo.graph.perm_inputs import PermuteInputs
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.concat import Concat
from mo.utils.error import Error
from mo.ops.const import Const
from mo.ops.op import PermuteAttrs


class StridedSliceNormalizer(MiddleReplacementPattern):
"""
StridedSlice is not normal if it cannot be permuted by ApplyPermutations. This normalizer
inserts blank colons ':' in slice expression so that it can be correctly permuted
from NHWC to NCHW layout. It changes masks and inserts blank begin, end and strides values.
In order to successfully handle StridedSlice in ShapeOf subgraphs
changes must be done by inserting nodes not just by overwriting constants.

StridedSlice is not normal in 2 cases:
1. rank of a slice expression is less than rank of input tensor
2. there is an ellipsis

1st case example
BEFORE:
|
begin
value=[0, 0]
|

AFTER:
|
begin Const
value=[0, 0] value=[0, 0]
\ /
\ /
Concat
value=[0, 0, 0, 0]
|

Input of a shape [16, 100, 100, 3] in NHWC layout, output = input[:, 0:50].
StridedSlice will be extended to input[:, 0:50, :, :].
After permutation to NCHW output = input[:, :, 0:50, :].
Example for 'begin' input transformation is shown above on the picture.
'end' and 'strides' inputs will be transformed the same way.

2nd case example
BEFORE:
|
begin
value=[1, 50]
|

AFTER:
|
begin
value=[1, 1, 1]
|
VariadicSplit
/ \
/ \
/ Const \
\ val=[0, 0] /
\ | /
\ | /
Concat
value=[1, 0, 0, 1, 1]
|

Input of a shape [16, 10, 100, 100, 3] in NDHWC layout, output = input[1:4, ..., 1:51, 1:3],
output_shape = [3, 10, 100, 50, 2]. In order to perform correct layout permutation
ellipsis must be replaced with colons: input[1:4, ..., 1:51, 1:3] => input[1:4, :, :, 1:51, 1:3].
After layour permutation input[1:4, 1:3, :, : 1:5].

In the places of colons blank begin, end and strides values should be inserted.
In order to do that we split input and insert blank zeros to the middle.
Example for 'begin' input transformation is shown above on the picture.
'end' and 'strides' inputs will be transformed the same way.
"""
enabled = True
pavel-esir marked this conversation as resolved.
Show resolved Hide resolved

def run_before(self):
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
return [LayoutChangeForConstantShapePaths]

def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(type='StridedSlice'):
StridedSliceNormalizer.normalize_strided_slice(graph, node)
PermuteAttrs.create_permute_attrs(node,
attrs=[('begin_mask', 'input:0'), # but indeed depends from slice_rank
('end_mask', 'input:0'),
('new_axis_mask', 'input:0'),
('shrink_axis_mask', 'input:0'),
('ellipsis_mask', 'input:0')])

# StridedSliceNormalizer inserted nodes that changed original begin, end, and strides data nodes
# Until now it was not possible to set correct permutations
PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:1', 'slice', 'dim_size')
PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:2', 'slice', 'dim_size')
PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:3', 'slice', 'dim_size')

@staticmethod
def normalize_strided_slice(graph: Graph, node: Node):
input_shape = node.in_port(0).data.get_shape()
input_rank = len(input_shape)
begin, _, _ = StridedSlice.validate_inputs_and_get_args(node)
slice_rank = len(begin)

StridedSlice.align_mask_with_slice_rank(node, slice_rank) # if StridedSlice is created after partial_infer
StridedSliceNormalizer.normalize_slices_attr(node)

num_insertions = input_rank - slice_rank + np.count_nonzero(node.new_axis_mask)
pavel-esir marked this conversation as resolved.
Show resolved Hide resolved
assert num_insertions >= 0, 'slice_rank - num_new_axis must <= input rank. Got instead: ' \
'input_rank = {}, slice_rank = {}, num_new_axis = {}'. \
format(input_rank, slice_rank, np.count_nonzero(node.new_axis_mask))

if np.any(node.ellipsis_mask):
assert np.count_nonzero(node.ellipsis_mask) == 1, 'only one ellipsis_mask nonzero value is allowed'
ellipsis_start = np.nonzero(node.ellipsis_mask)[0][0]
# since we don't expect values in begin and end: take the whole range along ellipsis_start
node.begin_mask[ellipsis_start] = 0
node.end_mask[ellipsis_start] = 0
node.ellipsis_mask[ellipsis_start] = 0
insertion_start_idx = ellipsis_start + 1

StridedSliceNormalizer.unroll_ellipsis_for_inputs(graph, node, ellipsis_start, num_insertions)
elif num_insertions > 0:
insertion_start_idx = slice_rank # insert blank values to mask ends
StridedSliceNormalizer.extend_inputs(node, num_insertions)

if num_insertions > 0:
# insert blank values for ellipsis unrolling and extending
for mask_name in StridedSlice.get_mask_names():
node[mask_name] = np.insert(node[mask_name], insertion_start_idx, [0] * num_insertions).astype(int)

@staticmethod
def unroll_ellipsis_for_inputs(graph: Graph, node: Node, ellipsis_start: int, num_insertions: int):
node_name = node.soft_get('name', node.id)

for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
blank_values_node = Const(graph, {'name': node_name + '/const_to_unroll_{}_ellipsis'.format(input_name),
'value': int64_array(blank_values_arr)}).create_node()

if input_name == 'strides' and node.in_port(3).disconnected():
pavel-esir marked this conversation as resolved.
Show resolved Hide resolved
continue # no need to extend strides if they are not connected

concat_in_ports_count = 3 if ellipsis_start != 0 else 2
concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name),
'in_ports_count': concat_in_ports_count}).create_node()

if ellipsis_start != 0:
split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0),
2: int64_array([ellipsis_start, -1])},
{'name': node_name + '/split_for_{}_ellipsis'.format(input_name),
'out_ports_count': 2})
node.in_port(i).get_connection().set_destination(split.in_port(0))

concat.in_port(0).connect(split.out_port(0))
concat.in_port(1).connect(blank_values_node.out_port(0))
concat.in_port(2).connect(split.out_port(1))
else:
concat.in_port(0).connect(blank_values_node.out_port(0))
node.in_port(i).get_connection().set_destination(concat.in_port(1))

concat.out_port(0).get_connection().set_destination(node.in_port(i))

@staticmethod
def extend_inputs(node: Node, num_insertions: int):
graph = node.graph
node_name = node.soft_get('name', node.id)

for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
blank_values_node = Const(graph, {'name': node_name + '/extend_{}_const'.format(input_name),
'value': int64_array(blank_values_arr)}).create_node()

if input_name == 'strides' and node.in_port(3).disconnected():
pavel-esir marked this conversation as resolved.
Show resolved Hide resolved
continue # no need to extend strides if they are not connected

if node.in_port(i).get_source().node.soft_get('type') == 'Concat':
# concat already exists
concat = node.in_port(i).get_source().node
last_in_port = max(concat.in_ports().keys())
assert not concat.in_port(last_in_port).disconnected(), 'The last in_port of Concat node {}' \
'should be connected'. \
format(concat.soft_get('name', node.id))

concat.add_input_port(last_in_port + 1)
concat.in_port(last_in_port + 1).connect(blank_values_node.out_port(0))
else:
# have to create concat
concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name),
'in_ports_count': 2}).create_node()
node.in_port(i).get_connection().set_destination(concat.in_port(0))
concat.in_port(1).connect(blank_values_node.out_port(0))
concat.out_port(0).get_connection().set_destination(node.in_port(i))

@staticmethod
def normalize_slices_attr(node: Node):
# removes negative starts, ends and magic numbers from 'slice' attr which is used by ConvertGroupedStridedSlice
slice_rank = len(node['slices'])
data_shape = node.in_port(0).data.get_shape()

node_name = node.soft_get('name', node.id)
if node.is_in_port_connected(3):
strides = node.in_port(3).data.get_value()
if strides is None:
raise Error('StridedSlice operation for node {} supports only constant strides input'.format(node_name))
else:
strides = np.ones(slice_rank)

num_ellipsis_inserts = len(data_shape) - slice_rank + np.count_nonzero(node.new_axis_mask) + 1
res_slices = []

in_idx = 0
for i, s in enumerate(node['slices']):
if node.new_axis_mask[i]:
res_slices.append(slice(0, 1, 1))
elif node.shrink_axis_mask[i]:
res_slices.append(slice(s, s + 1, strides[i])) # need strides if shrink index is negative
elif node.ellipsis_mask[i]:
for idx in range(num_ellipsis_inserts):
res_slices.append(slice(0, data_shape[in_idx], 1))
in_idx += 1
else:
res_slices.append(s)

if not (node.new_axis_mask[i] or node.ellipsis_mask[i]):
res_slices[-1] = slice(*res_slices[-1].indices(data_shape[in_idx])) # convert negative begins/ends
in_idx += 1
node['slices'] = np.array(res_slices)
pavel-esir marked this conversation as resolved.
Show resolved Hide resolved
Loading