Skip to content

Commit

Permalink
[MO] StridedSlice improvements (#4139)
Browse files Browse the repository at this point in the history
* fix ss

* successfully converted

* successfully run moved infer and normalizer unit-tests

* successfully rewritten StridedSlice infer unittests

* int64 array

* Successfully converter crash-when-loading, xj_feauture and toy nets (cherry-picked maxpoolV4 and tf_broadcast_ext)

* successfully moved PermuteAttrs to general mechanism

* successfully converted xj_feauture and crash when loading with the new rewritten SS infer

* fixed get_shape_from_slice and moved to common utils

* fixed extending masks and some other

* some refactoring

* fixed extending masks in extractor, fixed licence year and some other code clearing

* corrected a couple of unittests

* fox permute for 5 rank slice and 4 rank inputs/

* WIP

* Added comments

* fixed StridedSlice in ProposalMutation.py

* rechecked shape_infer unittests added some new cases

* added shape_infer unit-tests after StridedSliceNormalizer pass and Permute unit-tests

* corrected unittests

* Applied review comments

* general permutations for inputs implemented, corrected ellipsis unrolling when shrink_axis is at the beginning, some other corrections

* removed code duplication in infer and normalizer, moved 'slices' attr normalizing to StridedSliceNormalizer.py

* removed some code duplication and other minor improvements

* Added tests

* minor corrections

* wider range of unittests added (froze the number)

* review comments applied

* enabled skipped unit-test

* comment corrections

* applied review comments: changed op -> type, added some asserts, corrected comments and other minor corrections

* sorted inputs, updated Supported_Frameworks_Layers.md, some minor
  • Loading branch information
pavel-esir authored Feb 16, 2021
1 parent d2548dd commit 22169a0
Show file tree
Hide file tree
Showing 21 changed files with 3,855 additions and 1,154 deletions.
2 changes: 1 addition & 1 deletion docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ Standard TensorFlow\* operations:
| Square| No |
| Squeeze | The case when squeeze axis is not specified is not supported |
| StopGradient | Not needed for shape inference |
| StridedSlice | No |
| StridedSlice | Supported only for constant begin, end, and strides inputs |
| Sub | No |
| Sum | No |
| Swish | No |
Expand Down
2 changes: 1 addition & 1 deletion model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ extensions/middle/SliceLikeToStridedSlice.py
extensions/middle/sparse_reshape.py
extensions/middle/split_tdnn_memoryoffset.py
extensions/middle/SplitConcatPairToInterpolate.py
extensions/middle/StridedSliceNormalizer.py
extensions/middle/SwapAxesMiddleReplacer.py
extensions/middle/TensorIterator_utils.py
extensions/middle/TensorIteratorBackEdge.py
Expand Down Expand Up @@ -800,7 +801,6 @@ mo/front/common/partial_infer/multi_box_prior.py
mo/front/common/partial_infer/random_uniform.py
mo/front/common/partial_infer/reshape.py
mo/front/common/partial_infer/roipooling.py
mo/front/common/partial_infer/slice.py
mo/front/common/partial_infer/utils.py
mo/front/common/register_custom_ops.py
mo/front/common/replacement.py
Expand Down
10 changes: 6 additions & 4 deletions model-optimizer/extensions/back/CropToStridedSlice.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -64,8 +64,10 @@ def replace_pattern(self, graph: Graph, match: [str, Node]):
end_mask = axis_mask.copy()

ss = StridedSlice(graph, {'name': node.soft_get('name', node.id) + '/strided_slice', 'begin_mask': begin_mask,
'end_mask': end_mask, 'new_axis_mask': np.array([0]),
'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
'end_mask': end_mask,
'new_axis_mask': np.zeros(len(end_mask)),
'shrink_axis_mask': np.zeros(len(end_mask)),
'ellipsis_mask': np.zeros(len(end_mask))}).create_node()

if len(node.in_nodes()) == 2 and node.has_valid('offset'):
# Crop Type 1
Expand Down Expand Up @@ -112,7 +114,7 @@ def replace_pattern(self, graph: Graph, match: [str, Node]):
source = node.in_port(0).get_connection().get_source()

stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64),
'name': ss.name + '/stride'}).create_node()
'name': ss.name + '/stride'}).create_node()

source.connect(ss.in_port(0))
begin.out_port(0).connect(ss.in_port(1))
Expand Down
6 changes: 3 additions & 3 deletions model-optimizer/extensions/back/ProposalMutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def replace_pattern(graph: Graph, match: dict):
{'name': 'cropped_im_info',
'begin_mask': int64_array([1, 1]),
'end_mask': int64_array([1, 1]),
'new_axis_mask': int64_array([0]),
'shrink_axis_mask': int64_array([0]),
'ellipsis_mask': int64_array([0]),
'new_axis_mask': int64_array([0, 0]),
'shrink_axis_mask': int64_array([0, 0]),
'ellipsis_mask': int64_array([0, 0]),
'override_output_shape': True,
})

Expand Down
29 changes: 10 additions & 19 deletions model-optimizer/extensions/back/StridedSliceMasksNormalizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -14,32 +14,23 @@
limitations under the License.
"""

from extensions.back.ConvolutionNormalizer import DeconvolutionNormalizer
from extensions.back.CropToStridedSlice import CropToStridedSlice
from mo.back.replacement import BackReplacementPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.graph.graph import Graph


class StridedSliceMasksNormalizer(BackReplacementPattern):
enabled = True
force_clean_up = True

def run_after(self):
from extensions.back.ConvolutionNormalizer import DeconvolutionNormalizer
from extensions.back.CropToStridedSlice import CropToStridedSlice
return [CropToStridedSlice, DeconvolutionNormalizer]

@staticmethod
def pattern():
return dict(
nodes=[
('strided_slice', dict(type='StridedSlice'))
],
edges=[]
)

def replace_pattern(self, graph: Graph, match: [str, Node]):
node = match['strided_slice']
assert node.has_valid('begin_mask')
assert node.has_valid('end_mask')
node.begin_mask = int64_array([1 - i for i in node.begin_mask])
node.end_mask = int64_array([1 - i for i in node.end_mask])
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(type='StridedSlice'):
assert node.has_valid('begin_mask')
assert node.has_valid('end_mask')
node.begin_mask = int64_array([1 - i for i in node.begin_mask])
node.end_mask = int64_array([1 - i for i in node.end_mask])
7 changes: 3 additions & 4 deletions model-optimizer/extensions/middle/ApplyPermutations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -133,15 +133,14 @@ def permute_input_data(graph: Graph):
input_permutations = [(in_port, edge_attrs['input_permutation']) for in_port, edge_attrs in
node.in_edges().items() if edge_attrs.get('input_permutation') is not None]
for in_port, input_perm in input_permutations:
permutation, port_info = input_perm
permutation, port_info, check_shape = input_perm
direction, port = port_info.split(':')
port = int(port)
port_to_check = node.in_port(port) if direction == 'input' else node.out_port(port)
permutation_data_node = get_node_with_permutation(node, port_info)

if permutation_data_node.has_and_set('permutation') and \
not is_input_data_in_correct_layout(node, in_port) and \
len(port_to_check.data.get_shape()) >= 4:
not is_input_data_in_correct_layout(node, in_port) and check_shape(port_to_check):
permutation(node, port_info, in_port)
if node.has_and_set('need_shape_inference'):
node.infer(node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class ConvertGroupedStridedSlice(MiddleReplacementPattern):
enabled = True

def run_after(self):
return [ConvertSlice]
from extensions.middle.StridedSliceNormalizer import StridedSliceNormalizer
return [ConvertSlice, StridedSliceNormalizer]

def run_before(self):
from extensions.middle.pass_separator import MiddleFinish
Expand Down
Loading

0 comments on commit 22169a0

Please sign in to comment.